Skip to content
Snippets Groups Projects
stackedbandinput.py 28.6 KiB
Newer Older
# -*- coding: utf-8 -*-
# noinspection PyPep8Naming
"""
***************************************************************************
    stackedbandinput.py

    Sometimes time-series-data is written out as stacked band images, having one observation per band.
    This module helps to use such data as EOTS input.
    ---------------------
    Date                 : June 2018
    Copyright            : (C) 2018 by Benjamin Jakimow
    Email                : benjamin.jakimow@geo.hu-berlin.de
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""

import os, re, tempfile, pickle, copy, shutil, locale, uuid, csv, io
Benjamin Jakimow's avatar
Benjamin Jakimow committed
from xml.etree import ElementTree
from collections import OrderedDict
from qgis.core import *
from qgis.gui import *
from qgis.utils import qgsfunction
from qgis.PyQt.QtCore import *
from qgis.PyQt.QtGui import *
from qgis.PyQt.QtWidgets import *
from qgis.core import QgsField, QgsFields, QgsFeature, QgsMapLayer, QgsVectorLayer, QgsConditionalStyle
from qgis.gui import QgsMapCanvas, QgsDockWidget
from pyqtgraph.widgets.PlotWidget import PlotWidget
from pyqtgraph.graphicsItems.PlotDataItem import PlotDataItem
from pyqtgraph.graphicsItems.PlotItem import PlotItem
import pyqtgraph.functions as fn
import numpy as np
from osgeo import gdal, gdal_array
Benjamin Jakimow's avatar
Benjamin Jakimow committed
import numpy as np
from eotimeseriesviewer.utils import *
from eotimeseriesviewer.virtualrasters import *
from eotimeseriesviewer.models import *
from eotimeseriesviewer.dateparser import *
from eotimeseriesviewer.plotstyling import PlotStyle, PlotStyleDialog, MARKERSYMBOLS2QGIS_SYMBOLS
import eotimeseriesviewer.mimedata as mimedata
Benjamin Jakimow's avatar
Benjamin Jakimow committed

def datesFromDataset(dataset:gdal.Dataset)->list:
    """
    Extracts one observation date per raster band of a GDAL dataset.

    Two sources are checked, in this order:

    1. Metadata items of any metadata domain whose key (case-insensitive) ends with
       'acquisition dates' (also the spelling 'aquisition dates'), 'observation dates'
       or 'wavelength'. The value is read as a comma-separated list of date strings;
       curly braces are removed before splitting.
    2. The band descriptions (band names).

    A candidate list is only accepted if it provides exactly one np.datetime64 per band.

    :param dataset: gdal.Dataset
    :return: [list-of-np.datetime64] with one date per band, or an empty list if no
             complete set of dates could be identified.
    """
    nb = dataset.RasterCount

    def checkDates(dateList) -> bool:
        # a valid result has exactly one np.datetime64 for each band
        return len(dateList) == nb and all(isinstance(d, np.datetime64) for d in dateList)

    searchedKeys = [
        re.compile('ac?quisition[ ]*dates$', re.I),  # 'acquisition' and the misspelled 'aquisition'
        re.compile('observation[ ]*dates$', re.I),
        re.compile('wavelength$', re.I),
    ]

    # 1. Check Metadata
    # GetMetadataDomainList() returns None if the dataset has no metadata at all
    for domain in dataset.GetMetadataDomainList() or []:
        domainData = dataset.GetMetadata_Dict(domain)
        assert isinstance(domainData, dict)

        for key, values in domainData.items():
            if any(regex.search(key.strip()) for regex in searchedKeys):
                values = re.sub('[{}]', '', values).split(',')
                dateValues = [extractDateTimeGroup(t) for t in values]
                if checkDates(dateValues):
                    return dateValues

    # 2. Check Band Names
    bandDates = [extractDateTimeGroup(dataset.GetRasterBand(b + 1).GetDescription()) for b in range(nb)]
    bandDates = [d for d in bandDates if isinstance(d, np.datetime64)]
    if checkDates(bandDates):
        return bandDates

    return []

class InputStackInfo(object):
    """
    Collects the properties of a stacked band image, i.e. an image with one observation per band:
    raster size, band count, geo-transformation, projection, metadata, band names, no-data values,
    color table and category names of the first band, and the identified observation dates.
    """

    def __init__(self, dataset):
        """
        :param dataset: gdal.Dataset or a path that can be opened with GDAL.
                        If a file <basename>.hdr exists next to the path, the ENVI driver is tried first.
        :raises Exception: if the path can not be opened.
        """
        if isinstance(dataset, str):
            ds = None
            # test ENVI header first
            basename = os.path.splitext(dataset)[0]
            if os.path.isfile(basename + '.hdr'):
                ds = gdal.OpenEx(dataset, allowed_drivers=['ENVI'])
            if not isinstance(ds, gdal.Dataset):
                ds = gdal.Open(dataset)
            if not isinstance(ds, gdal.Dataset):
                raise Exception('Unable to open {}'.format(dataset))

            dataset = ds
            del ds

        assert isinstance(dataset, gdal.Dataset)

        # GetMetadataDomainList() returns None if the dataset has no metadata
        self.mMetadataDomains = dataset.GetMetadataDomainList() or []
        self.mMetaData = OrderedDict()
        for domain in self.mMetadataDomains:
            self.mMetaData[domain] = dataset.GetMetadata_Dict(domain)

        self.ns = dataset.RasterXSize
        self.nl = dataset.RasterYSize
        self.nb = dataset.RasterCount

        self.wkt = dataset.GetProjection()
        self.gt = dataset.GetGeoTransform()

        # color table and category names are taken from the first band (if there is any band)
        firstBand = dataset.GetRasterBand(1) if self.nb > 0 else None
        self.colorTable = firstBand.GetColorTable() if firstBand is not None else None
        self.classNames = firstBand.GetCategoryNames() if firstBand is not None else None

        # GetFileList() returns None for datasets that are not backed by files
        fileList = dataset.GetFileList()
        self.path = fileList[0] if fileList else dataset.GetDescription()

        # default output band name, might be an empty string
        self.outputBandName = os.path.basename(self.path)

        self.bandnames = []
        self.nodatavalues = []
        for b in range(self.nb):
            band = dataset.GetRasterBand(b + 1)
            assert isinstance(band, gdal.Band)
            self.bandnames.append(band.GetDescription())
            self.nodatavalues.append(band.GetNoDataValue())

        self.mDates = datesFromDataset(dataset)

    def __len__(self):
        """Returns the number of identified dates."""
        return len(self.mDates)

    def dates(self) -> list:
        """Returns the list of identified dates"""
        return self.mDates

    def structure(self):
        """
        Returns the properties used to compare the structure of input stacks:
        (ns, nl, nb, geo-transformation, projection wkt)
        """
        return (self.ns, self.nl, self.nb, self.gt, self.wkt)

    def wavelength(self):
        """
        Returns the 'wavelength' item of the default ('') metadata domain, or None if it is not defined.
        """
        return self.mMetaData.get('', {}).get('wavelength')

    def setWavelength(self, wl):
        """
        Sets the 'wavelength' item of the default ('') metadata domain.
        The domain is created if it does not exist yet.
        """
        self.mMetaData.setdefault('', {})['wavelength'] = wl


class OutputVRTDescription(object):
    """
    Describes an output VRT: the path it is written to and the observation date it represents.
    """

    def __init__(self, path:str, date:np.datetime64):
        """
        :param path: output path
        :param date: observation date of the output image
        """
        super().__init__()
        self.mPath = path
        self.mDate = date

    def __repr__(self) -> str:
        return '{}({!r}, {!r})'.format(type(self).__name__, self.mPath, self.mDate)

    def setPath(self, path:str):
        """Sets the output path."""
        self.mPath = path

class InputStackTableModel(QAbstractTableModel):
    """
    Table model that lists the stacked band images (InputStackInfo) used as input, one per row.
    The output band name and the wavelength columns are editable.
    """

    def __init__(self, parent=None):

        super(InputStackTableModel, self).__init__(parent)
        self.mStackImages = []

        self.cn_source = 'Source'
        self.cn_dates = 'Dates'
        self.cn_crs = 'GT + CRS'
        self.cn_ns = 'ns'
        self.cn_nl = 'nl'
        self.cn_nb = 'nb'
        self.cn_name = 'Band Name'
        self.cn_wl = 'Wavelength'
        self.mColumnNames = [self.cn_source, self.cn_dates, self.cn_name, self.cn_wl,
                             self.cn_ns, self.cn_nl, self.cn_nb, self.cn_crs]

        self.mColumnTooltips = {}
        self.mColumnTooltips[self.cn_source] = 'Stack source uri / file path'
        self.mColumnTooltips[self.cn_crs] = 'Geo-Transformation + Coordinate Reference System'
        self.mColumnTooltips[self.cn_ns] = 'Number of samples / pixel in horizontal direction'
        self.mColumnTooltips[self.cn_nl] = 'Number of lines / pixel in vertical direction'
        self.mColumnTooltips[self.cn_nb] = 'Number of bands'
        self.mColumnTooltips[self.cn_name] = 'Prefix of band name in output image'
        self.mColumnTooltips[self.cn_wl] = 'Wavelength in output image'
        self.mColumnTooltips[self.cn_dates] = 'Identified dates'

    def __len__(self):
        return len(self.mStackImages)

    def __iter__(self):
        return iter(self.mStackImages)

    def columnName(self, i) -> str:
        """
        Returns the column name
        :param i: int column index or QModelIndex
        :return: str
        """
        if isinstance(i, QModelIndex):
            i = i.column()
        return self.mColumnNames[i]

    def dateInfo(self):
        """
        Returns a sorted list with all extracted dates and a sorted list of the dates in common between all datasets
        :return: [all dates], [dates in common]
        """
        if len(self) == 0:
            return [], []
        datesTotal = set()
        datesInCommon = None
        for info in self.mStackImages:
            assert isinstance(info, InputStackInfo)
            dates = info.dates()
            if datesInCommon is None:
                datesInCommon = set(dates)
            else:
                datesInCommon = datesInCommon.intersection(dates)
            datesTotal.update(dates)

        return sorted(datesTotal), sorted(datesInCommon)

    def flags(self, index):
        """
        Returns the item flags. Output band name and wavelength are editable.
        """
        if not index.isValid():
            # Qt expects an ItemFlags value, returning None is not allowed
            return Qt.NoItemFlags
        flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
        if self.columnName(index) in [self.cn_name, self.cn_wl]:
            flags = flags | Qt.ItemIsEditable
        return flags

    def headerData(self, col, orientation, role):
        """
        Returns the column names / tooltips and the row number as vertical header.
        """
        if orientation == Qt.Horizontal:
            cname = self.mColumnNames[col]
            if role == Qt.DisplayRole:
                return cname
            elif role == Qt.ToolTipRole:
                return self.mColumnTooltips.get(cname)
        elif orientation == Qt.Vertical and role == Qt.DisplayRole:
            return col
        return None

    def rowCount(self, parent=None):
        return len(self.mStackImages)

    def columnCount(self, parent=None):
        return len(self.mColumnNames)

    def insertSources(self, paths, i=None):
        """
        Inserts new datasources
        :param paths: a datasource or [list-of-datasources]. A datasource is a path or gdal.Dataset
        :param i: index where to add the first datasource. Defaults to the end of the list.
        """
        if i is None:
            i = self.rowCount()

        if not isinstance(paths, (list, tuple)):
            paths = [paths]

        # open all sources before inserting, so that a failing source does not lead to a partial insert
        infos = [InputStackInfo(p) for p in paths]
        if len(infos) == 0:
            return

        self.beginInsertRows(QModelIndex(), i, i + len(infos) - 1)
        for j, info in enumerate(infos):
            assert isinstance(info, InputStackInfo)
            if len(info.outputBandName) == 0:
                info.outputBandName = 'Band {}'.format(i + j + 1)
            self.mStackImages.insert(i + j, info)
        self.endInsertRows()

    def removeSources(self, stackInfos:list):
        """
        Removes input stacks
        :param stackInfos: [list-of-InputStackInfo], each must be part of this model
        """
        for stackInfo in stackInfos:
            assert stackInfo in self.mStackImages

        # drop duplicates (keeping the order), each stack can be removed only once
        for stackInfo in OrderedDict.fromkeys(stackInfos):
            assert isinstance(stackInfo, InputStackInfo)
            row = self.mStackImages.index(stackInfo)
            self.beginRemoveRows(QModelIndex(), row, row)
            del self.mStackImages[row]
            self.endRemoveRows()

    def isValid(self):
        """
        Returns True if there is at least one input stack and all input stacks
        have the same dates and the same structure (InputStackInfo.structure()).
        """
        if len(self.mStackImages) == 0:
            return False
        ref = self.mStackImages[0]
        assert isinstance(ref, InputStackInfo)

        # all input stacks need to have the same characteristic
        for stackInfo in self.mStackImages[1:]:
            assert isinstance(stackInfo, InputStackInfo)
            if ref.dates() != stackInfo.dates():
                return False
            if ref.structure() != stackInfo.structure():
                return False
        return True

    def index2info(self, index:QModelIndex) -> InputStackInfo:
        return self.mStackImages[index.row()]

    def info2index(self, info:InputStackInfo) -> QModelIndex:
        r = self.mStackImages.index(info)
        return self.createIndex(r, 0, info)

    def data(self, index: QModelIndex, role: int):
        if not index.isValid():
            return None

        info = self.mStackImages[index.row()]
        assert isinstance(info, InputStackInfo)
        cname = self.columnName(index)

        if role in [Qt.DisplayRole, Qt.ToolTipRole]:
            if cname == self.cn_source:
                return info.path
            if cname == self.cn_dates:
                dates = info.dates()
                if role == Qt.DisplayRole:
                    return len(dates)
                # tooltip
                if len(dates) == 0:
                    return 'No dates identified. Can not use this image as input'
                if len(dates) > 11:
                    dates = dates[0:10] + ['...']
                return '\n'.join([str(d) for d in dates])
            if cname == self.cn_ns:
                return info.ns
            if cname == self.cn_nl:
                return info.nl
            if cname == self.cn_nb:
                return info.nb
            if cname == self.cn_crs:
                return '{} {}'.format(info.gt, info.wkt)
            if cname == self.cn_wl:
                return info.wavelength()
            if cname == self.cn_name:
                return info.outputBandName

        elif role == Qt.EditRole:
            if cname == self.cn_wl:
                return info.wavelength()
            if cname == self.cn_name:
                return info.outputBandName

        elif role == Qt.BackgroundRole:
            # highlight the editable columns
            if cname in [self.cn_name, self.cn_wl]:
                return QColor('yellow')

        return None

    def setData(self, index: QModelIndex, value, role: int = Qt.EditRole):
        """
        Sets the output band name (non-empty strings only) or the wavelength.
        :return: True if the data was changed
        """
        if not index.isValid():
            return False

        info = self.index2info(index)
        cname = self.columnName(index)

        changed = False
        if role == Qt.EditRole:
            if cname == self.cn_name:
                if isinstance(value, str) and len(value) > 0:
                    info.outputBandName = value
                    changed = True
            elif cname == self.cn_wl:
                info.setWavelength(value)
                changed = True
        if changed:
            self.dataChanged.emit(index, index)
        return changed

class OutputImageModel(QAbstractTableModel):
    """
    Table model that lists the output images (OutputVRTDescription), one per observation date.

    setMultiStackSources() creates a master VRT with one virtual band per input stack and without
    band sources. vrtXML() derives the VRT XML of a single output image from it, adding a band
    source for each input stack that has a band with the output image's date.
    """

    def __init__(self, parent=None):
        super(OutputImageModel, self).__init__(parent)

        self.cn_uri = 'Path'
        self.cn_date = 'Date'
        self.mOutputImages = []
        self.mColumnNames = [self.cn_date, self.cn_uri]

        self.mColumnTooltips = {}
        self.mColumnTooltips[self.cn_uri] = 'Output location'

        # date -> [(stack index, band index), ...]
        self.masterVRT_DateLookup = {}
        # virtual band index -> band source XML element, used as template
        self.masterVRT_SourceBandTemplates = {}
        # input stacks the master VRT was created from
        self.masterVRT_InputStacks = None
        # master VRT as ElementTree element, without band sources
        self.masterVRT_XML = None
        self.mOutputDir = '/vsimem/'
        self.mOutputPrefix = 'date'

    def headerData(self, col, orientation, role):
        """
        Returns the column names / tooltips and the row number as vertical header.
        """
        if orientation == Qt.Horizontal:
            cname = self.mColumnNames[col]
            if role == Qt.DisplayRole:
                return cname
            elif role == Qt.ToolTipRole:
                return self.mColumnTooltips.get(cname)
        elif orientation == Qt.Vertical and role == Qt.DisplayRole:
            return col
        return None

    def createVRTUri(self, date:np.datetime64) -> str:
        """
        Returns the output path of an output image: <output dir>/<prefix><date>.vrt
        """
        path = os.path.join(self.mOutputDir, self.mOutputPrefix)
        return '{}{}.vrt'.format(path, date)

    def clearOutputs(self):
        """Removes all output image descriptions."""
        if len(self.mOutputImages) == 0:
            # nothing to remove. beginRemoveRows(parent, 0, -1) would describe an invalid row range
            return
        self.beginRemoveRows(QModelIndex(), 0, self.rowCount() - 1)
        self.mOutputImages = []
        self.endRemoveRows()

    def setMultiStackSources(self, listOfInputStacks:list, dates:list):
        """
        Creates the master VRT and one output image description per date.
        Each output image has one band per input stack. Input stacks without identified dates are ignored.
        :param listOfInputStacks: [list-of-InputStackInfo]
        :param dates: [list-of-np.datetime64]
        """
        self.clearOutputs()

        if listOfInputStacks is None or len(listOfInputStacks) == 0:
            return
        if dates is None or len(dates) == 0:
            return

        for s in listOfInputStacks:
            assert isinstance(s, InputStackInfo)

        # stacks without dates can not contribute a band to any output image
        listOfInputStacks = [s for s in listOfInputStacks if len(s) > 0]
        if len(listOfInputStacks) == 0:
            return

        self.masterVRT_DateLookup.clear()
        self.masterVRT_InputStacks = listOfInputStacks
        self.masterVRT_SourceBandTemplates.clear()

        # create a LUT to get the stack indices for a related date (not each stack might contain a band for each date)
        for stackIndex, s in enumerate(listOfInputStacks):
            for bandIndex, bandDate in enumerate(s.dates()):
                self.masterVRT_DateLookup.setdefault(bandDate, []).append((stackIndex, bandIndex))

        # create VRT Template XML: one virtual band per input stack
        VRT = VRTRaster()
        wavelength = []
        for stack in listOfInputStacks:
            assert isinstance(stack, InputStackInfo)
            vrtBand = VRTRasterBand()
            vrtBand.setName(stack.outputBandName)
            vrtSrc = VRTRasterInputSourceBand(stack.path, 0)
            vrtBand.addSource(vrtSrc)
            wavelength.append(stack.wavelength())
            VRT.addVirtualBand(vrtBand)

        pathVSITmp = '/vsimem/temp.vrt'
        dsVRT = VRT.saveVRT(pathVSITmp)
        # placeholder, vrtXML() sets the date of each output image
        dsVRT.SetMetadataItem('acquisition date', 'XML_REPLACE_DATE')

        if None not in wavelength:
            dsVRT.SetMetadataItem('wavelength', ','.join(str(wl) for wl in wavelength))
            dsVRT.SetMetadataItem('wavelength units', 'Nanometers')

        for stackIndex, stack in enumerate(listOfInputStacks):
            band = dsVRT.GetRasterBand(stackIndex + 1)
            assert isinstance(band, gdal.Band)
            assert isinstance(stack, InputStackInfo)
            if isinstance(stack.colorTable, gdal.ColorTable) and stack.colorTable.GetCount() > 0:
                band.SetColorTable(stack.colorTable)
            if stack.classNames:
                band.SetCategoryNames(stack.classNames)

        dsVRT.FlushCache()
        drv = dsVRT.GetDriver()
        masterVRT_XML = read_vsimem(pathVSITmp).decode('utf-8')
        # close the dataset before its file gets deleted
        del dsVRT
        drv.Delete(pathVSITmp)

        # keep the single band source of each virtual band as template and remove it from the master XML
        eTree = ElementTree.fromstring(masterVRT_XML)
        for iBand, elemBand in enumerate(eTree.findall('VRTRasterBand')):
            sourceElements = elemBand.findall('ComplexSource') + elemBand.findall('SimpleSource')
            assert len(sourceElements) == 1
            self.masterVRT_SourceBandTemplates[iBand] = copy.deepcopy(sourceElements[0])
            elemBand.remove(sourceElements[0])

        outputVRTs = []
        for date in dates:
            assert isinstance(date, np.datetime64)
            outputVRTs.append(OutputVRTDescription(self.createVRTUri(date), date))

        self.masterVRT_XML = eTree

        self.beginInsertRows(QModelIndex(), 0, len(outputVRTs) - 1)
        self.mOutputImages = outputVRTs
        self.endInsertRows()

    def setOutputDir(self, path:str):
        """Sets the output directory and updates the output paths."""
        self.mOutputDir = path
        self.updateOutputURIs()

    def setOutputPrefix(self, basename:str):
        """Sets the output file prefix and updates the output paths."""
        self.mOutputPrefix = basename
        self.updateOutputURIs()

    def updateOutputURIs(self):
        """
        Recalculates the output paths from output directory, prefix and date.
        NOTE: dataChanged is emitted even if there are no rows.
        """
        c = self.mColumnNames.index(self.cn_uri)
        ul = self.createIndex(0, c)
        lr = self.createIndex(self.rowCount() - 1, c)

        for outputVRT in self:
            assert isinstance(outputVRT, OutputVRTDescription)
            outputVRT.setPath(self.createVRTUri(outputVRT.mDate))
        self.dataChanged.emit(ul, lr)

    def __len__(self):
        return len(self.mOutputImages)

    def __iter__(self):
        return iter(self.mOutputImages)

    def rowCount(self, parent=None) -> int:
        return len(self.mOutputImages)

    def columnCount(self, parent=None) -> int:
        return len(self.mColumnNames)

    def columnName(self, i) -> str:
        """
        Returns the column name
        :param i: int column index or QModelIndex
        """
        if isinstance(i, QModelIndex):
            i = i.column()
        return self.mColumnNames[i]

    def columnIndex(self, columnName:str)-> QModelIndex:
        c = self.mColumnNames.index(columnName)
        return self.createIndex(0, c)

    def index2vrt(self, index:QModelIndex) -> OutputVRTDescription:
        return self.mOutputImages[index.row()]

    def vrt2index(self, vrt:OutputVRTDescription) -> QModelIndex:
        """Returns the QModelIndex (column 0) of an output VRT description."""
        i = self.mOutputImages.index(vrt)
        return self.createIndex(i, 0, vrt)

    def data(self, index: QModelIndex, role: int):
        if not index.isValid():
            return None

        cname = self.columnName(index)
        vrt = self.index2vrt(index)
        if role in [Qt.DisplayRole, Qt.ToolTipRole]:
            if cname == self.cn_uri:
                return vrt.mPath
            if cname == self.cn_date:
                return str(vrt.mDate)
        return None

    def vrtXML(self, outputDefinition:OutputVRTDescription, asElementTree=False) -> str:
        """
        Create the VRT XML related to an outputDefinition
        :param outputDefinition: OutputVRTDescription
        :param asElementTree: set True to return an ElementTree element instead of a string
        :return: str (or ElementTree element), None if no master VRT has been created yet
        """
        if self.masterVRT_XML is None:
            return None
        xmlTree = copy.deepcopy(self.masterVRT_XML)

        # set metadata
        for elem in xmlTree.findall('Metadata/MDI'):
            if elem.attrib.get('key') == 'acquisition date':
                elem.text = str(outputDefinition.mDate)

        # insert the sources of all stack bands related to this date.
        # Virtual bands of stacks without a band for this date stay without source.
        requiredBands = self.masterVRT_DateLookup.get(outputDefinition.mDate, [])
        xmlVRTBands = xmlTree.findall('VRTRasterBand')
        for stackIndex, stackBandIndex in requiredBands:
            stackSourceXML = copy.deepcopy(self.masterVRT_SourceBandTemplates[stackIndex])
            stackSourceXML.find('SourceBand').text = str(stackBandIndex + 1)
            xmlVRTBands[stackIndex].append(stackSourceXML)

        if asElementTree:
            return xmlTree
        return ElementTree.tostring(xmlTree).decode('utf-8')





class StackedBandInputDialog(QDialog, loadUI('stackedinputdatadialog.ui')):

    def __init__(self, parent=None):
        """
        Creates the dialog, its input and output table models and connects the widget signals.
        :param parent: parent widget
        """
        # qgis.utils.iface is used below; the module-level imports only import names from qgis sub-modules
        import qgis.utils

        super(StackedBandInputDialog, self).__init__(parent=parent)
        self.setupUi(self)
        self.setWindowTitle('Stacked Time Series Data Input')
        self.mWrittenFiles = []

        # input stacks: changes require recalculating the outputs
        self.tableModelInputStacks = InputStackTableModel()
        self.tableModelInputStacks.rowsInserted.connect(self.updateOutputs)
        self.tableModelInputStacks.dataChanged.connect(self.updateOutputs)
        self.tableModelInputStacks.rowsRemoved.connect(self.updateOutputs)
        self.tableModelInputStacks.rowsInserted.connect(self.updateInputInfo)
        self.tableModelInputStacks.rowsRemoved.connect(self.updateInputInfo)
        self.tableViewSourceStacks.setModel(self.tableModelInputStacks)

        # output images
        self.tableModelOutputImages = OutputImageModel()
        self.tableModelOutputImages.rowsInserted.connect(self.updateOutputInfo)
        self.tableModelOutputImages.rowsRemoved.connect(self.updateOutputInfo)
        self.tableModelOutputImages.dataChanged.connect(self.updateOutputInfo)
        self.tableViewOutputImages.setModel(self.tableModelOutputImages)
        self.tableViewOutputImages.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
        self.buttonGroupDateMode.buttonClicked.connect(self.updateOutputs)
        self.buttonGroupOutputLocation.buttonClicked.connect(self.updateOutputs)
        # only enabled if a QGIS interface is available
        self.cbOpenInQGIS.setEnabled(isinstance(qgis.utils.iface, QgisInterface))

        self.tbFilePrefix.textChanged.connect(self.tableModelOutputImages.setOutputPrefix)
        self.tbFilePrefix.setText('img')

        self.fileWidgetOutputDir.setStorageMode(QgsFileWidget.GetDirectory)
        self.fileWidgetOutputDir.fileChanged.connect(self.tableModelOutputImages.setOutputDir)

        sm = self.tableViewSourceStacks.selectionModel()
        assert isinstance(sm, QItemSelectionModel)
        sm.selectionChanged.connect(self.onSourceStackSelectionChanged)
        self.onSourceStackSelectionChanged([], [])

        sm = self.tableViewOutputImages.selectionModel()
        assert isinstance(sm, QItemSelectionModel)
        sm.selectionChanged.connect(self.onOutputImageSelectionChanged)

        self.initActions()

    def writtenFiles(self):
        """
        Returns a copy of the list of files written by saveImages().
        :return: [list-of-written-file-paths]
        """
        return list(self.mWrittenFiles)

    def updateOutputs(self, *args):
        """
        Recalculates the output images from the input stacks, the selected date mode
        (all dates / dates in common) and the selected output location (memory / directory).
        """
        outputModel = self.tableModelOutputImages
        outputModel.clearOutputs()

        inputStacks = self.tableModelInputStacks.mStackImages
        datesTotal, datesIntersection = self.tableModelInputStacks.dateInfo()
        if self.rbDatesAll.isChecked():
            outputModel.setMultiStackSources(inputStacks, datesTotal)
        elif self.rbDatesIntersection.isChecked():
            outputModel.setMultiStackSources(inputStacks, datesIntersection)

        if self.rbSaveInMemory.isChecked():
            outputModel.setOutputDir(r'/vsimem/')
        elif self.rbSaveInDirectory.isChecked():
            outputModel.setOutputDir(self.fileWidgetOutputDir.filePath())

    def updateInputInfo(self):
        """
        Shows the number of input images together with the total number of
        dates and the number of dates they have in common.
        """
        nImages = len(self.tableModelInputStacks)
        if nImages == 0:
            self.tbInfoInputImages.setText(None)
            return

        allDates, commonDates = self.tableModelInputStacks.dateInfo()
        text = '{} Input Images with {} dates in total, {} in intersection'.format(
            nImages, len(allDates), len(commonDates))
        self.tbInfoInputImages.setText(text)

    def updateOutputInfo(self):
        """
        Shows a summary of the output images and enables the Save button
        only if there is at least one output image.
        """
        outputModel = self.tableModelOutputImages
        nOutputs = len(outputModel)
        hasOutputs = nOutputs > 0

        text = None
        if hasOutputs:
            nBands = len(outputModel.masterVRT_InputStacks)
            text = '{} output images with {} bands to {}'.format(nOutputs, nBands, outputModel.mOutputDir)

        self.buttonBox.button(QDialogButtonBox.Save).setEnabled(hasOutputs)
        self.tbInfoOutputImages.setText(text)

    def initActions(self):
        """
        Connects the QActions and the dialog buttons to the methods they trigger.
        """
        actionSlots = [
            (self.actionAddSourceStack, self.onAddSource),
            (self.actionRemoveSourceStack, self.onRemoveSources),
        ]
        for action, slot in actionSlots:
            action.triggered.connect(slot)

        self.btnAddSourceStack.setDefaultAction(self.actionAddSourceStack)
        self.btnRemoveSourceStack.setDefaultAction(self.actionRemoveSourceStack)

        buttonBox = self.buttonBox
        buttonBox.button(QDialogButtonBox.Save).clicked.connect(self.accept)
        buttonBox.button(QDialogButtonBox.Cancel).clicked.connect(self.close)

    def onAddSource(self, *args):
        """
        Opens a file dialog to select stacked band images and adds them as input stacks.
        """
        import eotimeseriesviewer.settings
        defDir = eotimeseriesviewer.settings.value(eotimeseriesviewer.settings.Keys.RasterSourceDirectory)
        # input stacks are opened with GDAL as raster images, so offer the raster file filters
        filters = QgsProviderRegistry.instance().fileRasterFilters()
        files, _ = QFileDialog.getOpenFileNames(self, directory=defDir, filter=filters)

        if len(files) > 0:
            self.tableModelInputStacks.insertSources(files)


    def addSources(self, paths):
        """
        Adds input stacks to the dialog.
        :param paths: a datasource or [list-of-datasources] (paths or gdal.Dataset objects)
        """
        inputModel = self.tableModelInputStacks
        inputModel.insertSources(paths)

    def onRemoveSources(self, *args):
        """
        Removes the input stacks of the selected rows.
        """
        selectionModel = self.tableViewSourceStacks.selectionModel()
        assert isinstance(selectionModel, QItemSelectionModel)

        infos = [self.tableModelInputStacks.index2info(idx) for idx in selectionModel.selectedRows()]
        self.tableModelInputStacks.removeSources(infos)

    def onSourceStackSelectionChanged(self, selected, deselected):
        """
        Enables the action to remove input stacks if any input stack item is selected.
        :param selected: items selected by this change (not used)
        :param deselected: items deselected by this change (not used)
        """
        # `selected` only holds the items that became selected with this change. Deselecting one of
        # several selected rows would disable the action, so the complete selection is checked instead.
        sm = self.tableViewSourceStacks.selectionModel()
        self.actionRemoveSourceStack.setEnabled(isinstance(sm, QItemSelectionModel) and sm.hasSelection())

    def onOutputImageSelectionChanged(self, selected, deselected):
        """
        Shows the VRT XML of the first newly selected output image, or clears the preview.
        """
        indexes = selected.indexes() if len(selected) > 0 else []
        if len(indexes) > 0:
            vrtOutput = self.tableModelOutputImages.index2vrt(indexes[0])
            assert isinstance(vrtOutput, OutputVRTDescription)
            xml = self.tableModelOutputImages.vrtXML(vrtOutput)
            self.tbXMLPreview.setPlainText(xml if xml is not None else '')
        else:
            self.tbXMLPreview.clear()

    def saveImages(self):
        """
        Write the VRT images described in the output image table.
        Paths starting with '/vsimem/' are written with write_vsimem, all others as text files (UTF-8).
        If the "open in QGIS" checkbox is enabled and checked, the written images are added to the QGIS project.
        :return: [list-of-written-file-paths]
        """
        nTotal = len(self.tableModelOutputImages)
        if nTotal == 0:
            return []

        from eotimeseriesviewer.virtualrasters import write_vsimem

        writtenFiles = []
        self.progressBar.setValue(0)
        for i, outVRT in enumerate(self.tableModelOutputImages):
            assert isinstance(outVRT, OutputVRTDescription)
            xml = self.tableModelOutputImages.vrtXML(outVRT)
            if outVRT.mPath.startswith('/vsimem/'):
                write_vsimem(outVRT.mPath, xml)
            else:
                with open(outVRT.mPath, 'w', encoding='utf-8') as f:
                    f.write(xml)
            writtenFiles.append(outVRT.mPath)
            # count the file that has just been written, so the last one reaches 100 %
            self.progressBar.setValue(int(100. * (i + 1) / nTotal))

        # reset the progress bar after 500 ms
        QTimer.singleShot(500, lambda: self.progressBar.setValue(0))

        if self.cbOpenInQGIS.isEnabled() and self.cbOpenInQGIS.isChecked():
            mapLayers = [QgsRasterLayer(p) for p in writtenFiles]
            QgsProject.instance().addMapLayers(mapLayers, addToLegend=True)

        self.mWrittenFiles.extend(writtenFiles)
        return writtenFiles