Skip to content
Snippets Groups Projects
stackedbandinput.py 28.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    # noinspection PyPep8Naming
    """
    ***************************************************************************
        stackedbandinput.py
    
        Sometimes time series data is written out as stacked band images, with one observation per band.
        This module helps to use such data as EOTS input.
        ---------------------
        Date                 : June 2018
        Copyright            : (C) 2018 by Benjamin Jakimow
        Email                : benjamin.jakimow@geo.hu-berlin.de
    ***************************************************************************
    *                                                                         *
    *   This program is free software; you can redistribute it and/or modify  *
    *   it under the terms of the GNU General Public License as published by  *
    *   the Free Software Foundation; either version 2 of the License, or     *
    *   (at your option) any later version.                                   *
    *                                                                         *
    ***************************************************************************
    """
    
    import os, re, tempfile, pickle, copy, shutil, locale, uuid, csv, io
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    from xml.etree import ElementTree
    
    from collections import OrderedDict
    from qgis.core import *
    from qgis.gui import *
    from qgis.utils import qgsfunction
    from qgis.PyQt.QtCore import *
    from qgis.PyQt.QtGui import *
    from qgis.PyQt.QtWidgets import *
    from qgis.core import QgsField, QgsFields, QgsFeature, QgsMapLayer, QgsVectorLayer, QgsConditionalStyle
    from qgis.gui import QgsMapCanvas, QgsDockWidget
    from pyqtgraph.widgets.PlotWidget import PlotWidget
    from pyqtgraph.graphicsItems.PlotDataItem import PlotDataItem
    from pyqtgraph.graphicsItems.PlotItem import PlotItem
    import pyqtgraph.functions as fn
    import numpy as np
    from osgeo import gdal, gdal_array
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    import numpy as np
    
    from eotimeseriesviewer.utils import *
    from eotimeseriesviewer.virtualrasters import *
    from eotimeseriesviewer.models import *
    from eotimeseriesviewer.dateparser import *
    from eotimeseriesviewer.plotstyling import PlotStyle, PlotStyleDialog, MARKERSYMBOLS2QGIS_SYMBOLS
    import eotimeseriesviewer.mimedata as mimedata
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
    def datesFromDataset(dataset:gdal.Dataset)->list:
        """
        Extracts one observation date per band from a GDAL dataset.

        Two sources are searched, in this order:
        1. Metadata items (of any metadata domain) whose key ends with 'acquisition dates',
           'aquisition dates', 'observation dates' or 'wavelength' (case-insensitive).
           The item value is read as comma-separated list, optionally enclosed in curly braces.
        2. The band descriptions.

        A source is accepted only if it provides a parsable date for each band of the dataset.
        :param dataset: gdal.Dataset
        :return: [list-of-np.datetime64] with one date per band, or [] if no dates could be identified
        """
        nb = dataset.RasterCount

        def checkDates(dateList) -> bool:
            # accept complete date lists only: exactly one np.datetime64 per band
            return len(dateList) == nb and all(isinstance(d, np.datetime64) for d in dateList)

        # 'ac?quisition' matches the correct spelling as well as the misspelled 'aquisition'
        searchedKeys = [re.compile(r'ac?quisition[ ]*dates$', re.I),
                        re.compile(r'observation[ ]*dates$', re.I),
                        re.compile(r'wavelength$', re.I)]

        # 1. Check Metadata
        # GetMetadataDomainList() returns None if the dataset has no metadata at all
        for domain in dataset.GetMetadataDomainList() or []:
            domainData = dataset.GetMetadata_Dict(domain)
            if not isinstance(domainData, dict):
                continue

            for key, values in domainData.items():
                if not any(regex.search(key.strip()) for regex in searchedKeys):
                    continue
                # values like '{2018-01-01, 2018-02-01}' -> ['2018-01-01', ' 2018-02-01']
                items = re.sub('[{}]', '', values).split(',')
                dateValues = [extractDateTimeGroup(t) for t in items]
                if checkDates(dateValues):
                    return dateValues

        # 2. Check Band Names
        bandDates = [extractDateTimeGroup(dataset.GetRasterBand(b + 1).GetDescription()) for b in range(nb)]
        if checkDates(bandDates):
            return bandDates

        return []
    
    
    class InputStackInfo(object):
        """
        Describes a raster image that stores observations of different dates as bands (a band stack).
        """

        def __init__(self, dataset):
            """
            :param dataset: gdal.Dataset or str with the path of a raster source
            :raises Exception: if a path is given that can not be opened with GDAL
            """
            if isinstance(dataset, str):
                path = dataset
                ds = None
                # test ENVI header first: if a *.hdr file exists, prefer the ENVI driver
                basename = os.path.splitext(path)[0]
                if os.path.isfile(basename + '.hdr'):
                    ds = gdal.OpenEx(path, allowed_drivers=['ENVI'])
                if not isinstance(ds, gdal.Dataset):
                    ds = gdal.Open(path)
                if not isinstance(ds, gdal.Dataset):
                    raise Exception('Unable to open {}'.format(path))

                dataset = ds
                del ds

            assert isinstance(dataset, gdal.Dataset)

            # GetMetadataDomainList() returns None if the dataset has no metadata
            self.mMetadataDomains = dataset.GetMetadataDomainList() or []
            self.mMetaData = OrderedDict()
            for domain in self.mMetadataDomains:
                self.mMetaData[domain] = dataset.GetMetadata_Dict(domain)
            # ensure the default domain exists, it is used to store the wavelength
            self.mMetaData.setdefault('', {})

            self.ns = dataset.RasterXSize
            self.nl = dataset.RasterYSize
            self.nb = dataset.RasterCount

            self.wkt = dataset.GetProjection()
            self.gt = dataset.GetGeoTransform()

            # color table and class names are taken from the first band
            band1 = dataset.GetRasterBand(1) if self.nb > 0 else None
            self.colorTable = band1.GetColorTable() if band1 is not None else None
            self.classNames = band1.GetCategoryNames() if band1 is not None else None

            # datasets that are not backed by a file (e.g. in-memory datasets) do not return a file list
            fileList = dataset.GetFileList()
            self.path = fileList[0] if fileList else dataset.GetDescription()

            # default name of the band that represents this stack in the output images
            self.outputBandName = os.path.basename(self.path)

            self.bandnames = []
            self.nodatavalues = []
            for b in range(self.nb):
                band = dataset.GetRasterBand(b + 1)
                assert isinstance(band, gdal.Band)
                self.bandnames.append(band.GetDescription())
                self.nodatavalues.append(band.GetNoDataValue())

            self.mDates = datesFromDataset(dataset)

        def __len__(self):
            """Returns the number of identified dates."""
            return len(self.mDates)

        def dates(self)->list:
            """Returns a list of dates"""
            return self.mDates

        def structure(self):
            """
            Returns the spatial structure, which needs to be equal for stacks that are combined.
            :return: (ns, nl, nb, geotransform, wkt)
            """
            return (self.ns, self.nl, self.nb, self.gt, self.wkt)

        def wavelength(self):
            """
            Returns the 'wavelength' item of the default metadata domain, or None if it is undefined.
            """
            return self.mMetaData.get('', {}).get('wavelength')

        def setWavelength(self, wl):
            """
            Sets the 'wavelength' item of the default metadata domain.
            """
            self.mMetaData.setdefault('', {})['wavelength'] = wl
    
    
    class OutputVRTDescription(object):
        """
        Describes an output VRT by its location and the date it is created for.
        """

        def __init__(self, path:str, date:np.datetime64):
            super().__init__()
            # location the VRT is written to
            self.mPath = path
            # date the VRT is created for
            self.mDate = date

        def setPath(self, path:str):
            """Changes the output location."""
            self.mPath = path
    
    
    class InputStackTableModel(QAbstractTableModel):
        """
        Table model that lists the input band stacks (InputStackInfo objects).
        The output band name and the wavelength of each stack can be edited.
        """

        def __init__(self, parent=None):

            super(InputStackTableModel, self).__init__(parent)
            self.mStackImages = []

            self.cn_source = 'Source'
            self.cn_dates = 'Dates'
            self.cn_crs = 'GT + CRS'
            self.cn_ns = 'ns'
            self.cn_nl = 'nl'
            self.cn_nb = 'nb'
            self.cn_name = 'Band Name'
            self.cn_wl = 'Wavelength'

            self.mColumnNames = [self.cn_source, self.cn_dates, self.cn_name, self.cn_wl,
                                 self.cn_ns, self.cn_nl, self.cn_nb, self.cn_crs]

            self.mColumnTooltips = {}
            self.mColumnTooltips[self.cn_source] = 'Stack source uri / file path'
            self.mColumnTooltips[self.cn_crs] = 'Geo-Transformation + Coordinate Reference System'
            self.mColumnTooltips[self.cn_ns] = 'Number of samples / pixel in horizontal direction'
            self.mColumnTooltips[self.cn_nl] = 'Number of lines / pixel in vertical direction'
            self.mColumnTooltips[self.cn_nb] = 'Number of bands'
            self.mColumnTooltips[self.cn_name] = 'Prefix of band name in output image'
            self.mColumnTooltips[self.cn_wl] = 'Wavelength in output image'
            self.mColumnTooltips[self.cn_dates] = 'Identified dates'

        def __len__(self):
            return len(self.mStackImages)

        def __iter__(self):
            return iter(self.mStackImages)

        def columnName(self, i) -> str:
            """
            Returns the column name
            :param i: int column number or QModelIndex
            :return: str
            """
            if isinstance(i, QModelIndex):
                i = i.column()
            return self.mColumnNames[i]

        def dateInfo(self):
            """
            Returns a list with all extracted dates and a list of date in common between all datasets
            :return: [all dates], [dates in common]
            """
            if len(self) == 0:
                return [], []
            datesTotal = set()
            datesInCommon = None
            for info in self.mStackImages:
                assert isinstance(info, InputStackInfo)
                dates = set(info.dates())
                datesTotal |= dates
                datesInCommon = dates if datesInCommon is None else datesInCommon & dates

            return sorted(datesTotal), sorted(datesInCommon)

        def flags(self, index):
            if not index.isValid():
                # Qt expects item flags, not None
                return Qt.NoItemFlags

            flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
            if self.columnName(index) in [self.cn_name, self.cn_wl]:
                # output band name and wavelength can be edited
                flags = flags | Qt.ItemIsEditable
            return flags

        def headerData(self, col, orientation, role):
            if orientation == Qt.Horizontal:
                cname = self.mColumnNames[col]
                if role == Qt.DisplayRole:
                    return cname
                elif role == Qt.ToolTipRole:
                    return self.mColumnTooltips.get(cname)

            elif orientation == Qt.Vertical and role == Qt.DisplayRole:
                return col
            return None

        def rowCount(self, parent=None):
            return len(self.mStackImages)

        def columnCount(self, parent: QModelIndex = None):
            return len(self.mColumnNames)

        def insertSources(self, paths, i=None):
            """
            Inserts new datasources
            :param paths: [list-of-datasources]
            :param i: index where to add the first datasource. Defaults to the end of the list.
            """
            if i is None:
                i = self.rowCount()

            if not isinstance(paths, list):
                paths = [paths]

            # open all sources before changing the model, so that a source that can not be opened leaves it untouched
            infos = [InputStackInfo(p) for p in paths]

            if len(infos) > 0:
                self.beginInsertRows(QModelIndex(), i, i + len(infos) - 1)
                for j, info in enumerate(infos):
                    assert isinstance(info, InputStackInfo)
                    if len(info.outputBandName) == 0:
                        info.outputBandName = 'Band {}'.format(i + j + 1)
                    self.mStackImages.insert(i + j, info)
                self.endInsertRows()

        def removeSources(self, stackInfos:list):
            """
            Removes input stacks
            :param stackInfos: [list-of-InputStackInfo], each needs to be part of this model
            """
            for stackInfo in stackInfos:
                assert stackInfo in self.mStackImages

            for stackInfo in stackInfos:
                assert isinstance(stackInfo, InputStackInfo)
                row = self.info2index(stackInfo).row()
                self.beginRemoveRows(QModelIndex(), row, row)
                self.mStackImages.remove(stackInfo)
                self.endRemoveRows()

        def isValid(self):
            """
            Returns True if the model contains at least one stack and all stacks have the same dates and structure.
            """
            if len(self.mStackImages) == 0:
                return False
            ref = self.mStackImages[0]
            assert isinstance(ref, InputStackInfo)

            # all input stacks need to have the same characteristic
            for stackInfo in self.mStackImages[1:]:
                assert isinstance(stackInfo, InputStackInfo)
                if not ref.dates() == stackInfo.dates():
                    return False
                if not ref.structure() == stackInfo.structure():
                    return False
            return True

        def index2info(self, index:QModelIndex) -> InputStackInfo:
            return self.mStackImages[index.row()]

        def info2index(self, info:InputStackInfo) -> QModelIndex:
            r = self.mStackImages.index(info)
            return self.createIndex(r, 0, info)

        def data(self, index: QModelIndex, role: int):

            if not index.isValid():
                return None

            info = self.mStackImages[index.row()]
            assert isinstance(info, InputStackInfo)
            cname = self.columnName(index)

            if role in [Qt.DisplayRole, Qt.ToolTipRole]:
                if cname == self.cn_source:
                    return info.path

                if cname == self.cn_dates:
                    dates = info.dates()
                    if role == Qt.DisplayRole:
                        return len(dates)
                    # Qt.ToolTipRole
                    if len(dates) == 0:
                        return 'No dates identified. Can not use this image as input'
                    if len(dates) > 11:
                        dates = dates[0:10] + ['...']
                    return '\n'.join([str(d) for d in dates])

                if cname == self.cn_ns:
                    return info.ns
                if cname == self.cn_nl:
                    return info.nl
                if cname == self.cn_nb:
                    return info.nb
                if cname == self.cn_crs:
                    return '{} {}'.format(info.gt, info.wkt)
                if cname == self.cn_wl:
                    return info.wavelength()
                if cname == self.cn_name:
                    return info.outputBandName

            elif role == Qt.EditRole:
                if cname == self.cn_wl:
                    return info.wavelength()
                if cname == self.cn_name:
                    return info.outputBandName

            elif role == Qt.BackgroundRole:
                # highlight editable columns
                if cname in [self.cn_name, self.cn_wl]:
                    return QColor('yellow')

            return None

        def setData(self, index: QModelIndex, value, role: int):
            """
            Sets the output band name or wavelength of an input stack.
            :return: bool, True if the data was changed
            """
            if not index.isValid():
                return False

            info = self.index2info(index)
            cname = self.columnName(index)

            changed = False
            if role == Qt.EditRole:
                if cname == self.cn_name:
                    # an empty band name is not allowed
                    if isinstance(value, str) and len(value) > 0:
                        info.outputBandName = value
                        changed = True

                elif cname == self.cn_wl:
                    info.setWavelength(value)
                    changed = True
            if changed:
                self.dataChanged.emit(index, index)
            return changed
    
    
    class OutputImageModel(QAbstractTableModel):
        """
        Table model that lists the output VRTs, one per date.
        Each output VRT has one band per input stack. A band references the stack band of the same date.
        """

        def __init__(self, parent=None):

            super(OutputImageModel, self).__init__(parent)

            self.cn_uri = 'Path'
            self.cn_date = 'Date'
            self.mOutputImages = []

            self.mColumnNames = [self.cn_date, self.cn_uri]

            self.mColumnTooltips = {}
            self.mColumnTooltips[self.cn_uri] = 'Output location'
            # date -> [(stackIndex, stackBandIndex), ...]
            self.masterVRT_DateLookup = {}
            # output band index -> source element template (SimpleSource / ComplexSource)
            self.masterVRT_SourceBandTemplates = {}
            self.masterVRT_InputStacks = None
            # VRT ElementTree element without band sources
            self.masterVRT_XML = None
            self.mOutputDir = '/vsimem/'
            self.mOutputPrefix = 'date'

        def headerData(self, col, orientation, role):
            if orientation == Qt.Horizontal:
                cname = self.mColumnNames[col]
                if role == Qt.DisplayRole:
                    return cname
                elif role == Qt.ToolTipRole:
                    return self.mColumnTooltips.get(cname)
            elif orientation == Qt.Vertical and role == Qt.DisplayRole:
                return col
            return None

        def createVRTUri(self, date:np.datetime64):
            """
            Returns the output path for a date: <output dir>/<prefix><date>.vrt
            """
            path = os.path.join(self.mOutputDir, self.mOutputPrefix)
            return '{}{}.vrt'.format(path, date)

        def clearOutputs(self):
            """Removes all output VRT definitions."""
            if len(self.mOutputImages) == 0:
                # beginRemoveRows requires first <= last, which can not be fulfilled for an empty model
                return
            self.beginRemoveRows(QModelIndex(), 0, self.rowCount() - 1)
            self.mOutputImages = []
            self.endRemoveRows()

        def setMultiStackSources(self, listOfInputStacks:list, dates:list):
            """
            Defines the output VRTs: one for each date, each with one band per input stack.
            :param listOfInputStacks: [list-of-InputStackInfo]. Stacks without identified dates are ignored.
            :param dates: [list-of-np.datetime64]
            """
            self.clearOutputs()

            if listOfInputStacks is None or len(listOfInputStacks) == 0:
                return

            if dates is None or len(dates) == 0:
                return

            for s in listOfInputStacks:
                assert isinstance(s, InputStackInfo)

            listOfInputStacks = [s for s in listOfInputStacks if len(s) > 0]
            self.masterVRT_DateLookup.clear()
            self.masterVRT_InputStacks = listOfInputStacks
            self.masterVRT_SourceBandTemplates.clear()

            if len(listOfInputStacks) == 0:
                # no stack can contribute a band
                self.masterVRT_XML = None
                return

            # create a LUT to get the stack indices for a related date (not each stack might contain a band for each date)
            for stackIndex, stack in enumerate(listOfInputStacks):
                for bandIndex, bandDate in enumerate(stack.dates()):
                    self.masterVRT_DateLookup.setdefault(bandDate, []).append((stackIndex, bandIndex))

            # create VRT Template XML with one band per input stack
            VRT = VRTRaster()
            wavelength = []
            for stack in listOfInputStacks:
                vrtBand = VRTRasterBand()
                vrtBand.setName(stack.outputBandName)
                vrtBand.addSource(VRTRasterInputSourceBand(stack.path, 0))
                wavelength.append(stack.wavelength())
                VRT.addVirtualBand(vrtBand)

            pathVSITmp = '/vsimem/temp.vrt'
            dsVRT = VRT.saveVRT(pathVSITmp)
            # placeholder, replaced by the date of each output VRT in vrtXML()
            dsVRT.SetMetadataItem('acquisition date', 'XML_REPLACE_DATE')

            if None not in wavelength:
                dsVRT.SetMetadataItem('wavelength', ','.join(str(wl) for wl in wavelength))
                dsVRT.SetMetadataItem('wavelength units', 'Nanometers')

            for stackIndex, stack in enumerate(listOfInputStacks):
                band = dsVRT.GetRasterBand(stackIndex + 1)
                assert isinstance(band, gdal.Band)

                if isinstance(stack.colorTable, gdal.ColorTable) and stack.colorTable.GetCount() > 0:
                    band.SetColorTable(stack.colorTable)
                if stack.classNames:
                    band.SetCategoryNames(stack.classNames)

            dsVRT.FlushCache()
            drv = dsVRT.GetDriver()
            masterVRT_XML = read_vsimem(pathVSITmp).decode('utf-8')
            drv.Delete(pathVSITmp)

            # keep the band sources as templates and remove them from the master VRT.
            # vrtXML() inserts a copy of a template for each band that exists for a given date.
            eTree = ElementTree.fromstring(masterVRT_XML)
            for iBand, elemBand in enumerate(eTree.findall('VRTRasterBand')):
                sourceElements = elemBand.findall('ComplexSource') + elemBand.findall('SimpleSource')
                assert len(sourceElements) == 1
                self.masterVRT_SourceBandTemplates[iBand] = copy.deepcopy(sourceElements[0])
                elemBand.remove(sourceElements[0])
            self.masterVRT_XML = eTree

            outputVRTs = []
            for date in dates:
                assert isinstance(date, np.datetime64)
                outputVRTs.append(OutputVRTDescription(self.createVRTUri(date), date))

            self.beginInsertRows(QModelIndex(), 0, len(outputVRTs) - 1)
            self.mOutputImages = outputVRTs
            self.endInsertRows()

        def setOutputDir(self, path:str):
            self.mOutputDir = path
            self.updateOutputURIs()

        def setOutputPrefix(self, basename:str):
            self.mOutputPrefix = basename
            self.updateOutputURIs()

        def updateOutputURIs(self):
            """
            Recreates the output path of each output VRT and emits dataChanged for the path column.
            """
            c = self.mColumnNames.index(self.cn_uri)
            ul = self.createIndex(0, c)
            lr = self.createIndex(self.rowCount() - 1, c)

            for outputVRT in self:
                assert isinstance(outputVRT, OutputVRTDescription)
                outputVRT.setPath(self.createVRTUri(outputVRT.mDate))
            self.dataChanged.emit(ul, lr)

        def __len__(self):
            return len(self.mOutputImages)

        def __iter__(self):
            return iter(self.mOutputImages)

        def rowCount(self, parent=None) -> int:
            return len(self.mOutputImages)

        def columnCount(self, parent=None) -> int:
            return len(self.mColumnNames)

        def columnName(self, i) -> str:
            if isinstance(i, QModelIndex):
                i = i.column()
            return self.mColumnNames[i]

        def columnIndex(self, columnName:str)-> QModelIndex:
            c = self.mColumnNames.index(columnName)
            return self.createIndex(0, c)

        def index2vrt(self, index:QModelIndex) -> OutputVRTDescription:
            return self.mOutputImages[index.row()]

        def vrt2index(self, vrt:OutputVRTDescription) -> QModelIndex:
            i = self.mOutputImages.index(vrt)
            return self.createIndex(i, 0, vrt)

        def data(self, index: QModelIndex, role: int):

            if not index.isValid():
                return None

            cname = self.columnName(index)
            vrt = self.index2vrt(index)
            if role in [Qt.DisplayRole, Qt.ToolTipRole]:
                if cname == self.cn_uri:
                    return vrt.mPath
                if cname == self.cn_date:
                    return str(vrt.mDate)
            return None

        def vrtXML(self, outputDefinition:OutputVRTDescription, asElementTree=False) -> str:
            """
            Create the VRT XML related to an outputDefinition
            :param outputDefinition: OutputVRTDescription
            :param asElementTree: set True to return the xml.etree.ElementTree.Element instead of a string
            :return: str (or Element), None if setMultiStackSources() has not defined a VRT template yet
            """
            if self.masterVRT_XML is None:
                return None
            xmlTree = copy.deepcopy(self.masterVRT_XML)

            # set metadata
            for elem in xmlTree.findall('Metadata/MDI'):
                if elem.get('key') == 'acquisition date':
                    elem.text = str(outputDefinition.mDate)

            # insert required rasterbands
            xmlVRTBands = xmlTree.findall('VRTRasterBand')
            for stackIndex, stackBandIndex in self.masterVRT_DateLookup[outputDefinition.mDate]:
                source = copy.deepcopy(self.masterVRT_SourceBandTemplates[stackIndex])
                # GDAL band numbers start with 1
                source.find('SourceBand').text = str(stackBandIndex + 1)
                xmlVRTBands[stackIndex].append(source)

            if asElementTree:
                return xmlTree
            return ElementTree.tostring(xmlTree).decode('utf-8')
    
    
    
    
    
    
    class StackedBandInputDialog(QDialog, loadUI('stackedinputdatadialog.ui')):
    
        def __init__(self, parent=None):

            super(StackedBandInputDialog, self).__init__(parent=parent)
            self.setupUi(self)

            self.setWindowTitle('Stacked Time Series Data Input')

            # files written after pressing "Save", see writtenFiles()
            self.mWrittenFiles = []

            # input stacks: any change requires to update the outputs
            self.tableModelInputStacks = InputStackTableModel()
            self.tableModelInputStacks.rowsInserted.connect(self.updateOutputs)
            self.tableModelInputStacks.dataChanged.connect(self.updateOutputs)
            self.tableModelInputStacks.rowsRemoved.connect(self.updateOutputs)

            self.tableModelInputStacks.rowsInserted.connect(self.updateInputInfo)
            self.tableModelInputStacks.rowsRemoved.connect(self.updateInputInfo)

            self.tableViewSourceStacks.setModel(self.tableModelInputStacks)

            # output images
            self.tableModelOutputImages = OutputImageModel()
            self.tableModelOutputImages.rowsInserted.connect(self.updateOutputInfo)
            self.tableModelOutputImages.rowsRemoved.connect(self.updateOutputInfo)
            self.tableModelOutputImages.dataChanged.connect(self.updateOutputInfo)

            self.tableViewOutputImages.setModel(self.tableModelOutputImages)
            self.tableViewOutputImages.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)

            self.buttonGroupDateMode.buttonClicked.connect(self.updateOutputs)
            self.buttonGroupOutputLocation.buttonClicked.connect(self.updateOutputs)

            self.cbOpenInQGIS.setEnabled(isinstance(qgis.utils.iface, QgisInterface))

            self.tbFilePrefix.textChanged.connect(self.tableModelOutputImages.setOutputPrefix)
            self.tbFilePrefix.setText('img')

            self.fileWidgetOutputDir.setStorageMode(QgsFileWidget.GetDirectory)
            # updateOutputs() decides whether the directory is used or the outputs stay in memory
            self.fileWidgetOutputDir.fileChanged.connect(self.updateOutputs)

            sm = self.tableViewSourceStacks.selectionModel()
            assert isinstance(sm, QItemSelectionModel)
            sm.selectionChanged.connect(self.onSourceStackSelectionChanged)
            self.onSourceStackSelectionChanged([], [])

            sm = self.tableViewOutputImages.selectionModel()
            assert isinstance(sm, QItemSelectionModel)
            sm.selectionChanged.connect(self.onOutputImageSelectionChanged)

            self.initActions()
    
    
        def writtenFiles(self):
            """
            Returns a copy of the list of files that were written after the "Save" button was pressed.
            :return: [list-of-written-file-paths]
            """
            return list(self.mWrittenFiles)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
        def updateOutputs(self, *args):
            """
            Recreates the output image definitions from the input stacks, the selected
            date mode (all dates / intersection) and the selected output location.
            """
            outputModel = self.tableModelOutputImages
            outputModel.clearOutputs()

            stacks = self.tableModelInputStacks.mStackImages
            allDates, commonDates = self.tableModelInputStacks.dateInfo()

            if self.rbDatesAll.isChecked():
                outputModel.setMultiStackSources(stacks, allDates)
            elif self.rbDatesIntersection.isChecked():
                outputModel.setMultiStackSources(stacks, commonDates)

            if self.rbSaveInMemory.isChecked():
                outputModel.setOutputDir(r'/vsimem/')
            elif self.rbSaveInDirectory.isChecked():
                outputModel.setOutputDir(self.fileWidgetOutputDir.filePath())
    
        def updateInputInfo(self):
            """
            Shows the number of input images and of their dates (total / in common) in the input info box.
            """
            stackModel = self.tableModelInputStacks
            nImages = len(stackModel)
            allDates, commonDates = stackModel.dateInfo()
            text = '{} Input Images with {} dates in total, {} in intersection'.format(
                nImages, len(allDates), len(commonDates)) if nImages > 0 else None
            self.tbInfoInputImages.setText(text)
    
        def updateOutputInfo(self):
            """
            Shows the number of output images in the output info box and enables the "Save"
            button only if there is at least one output image.
            """
            outputModel = self.tableModelOutputImages
            nOutputs = len(outputModel)
            hasOutputs = nOutputs > 0

            text = None
            if hasOutputs:
                nBands = len(outputModel.masterVRT_InputStacks)
                text = '{} output images with {} bands to {}'.format(nOutputs, nBands, outputModel.mOutputDir)

            self.buttonBox.button(QDialogButtonBox.Save).setEnabled(hasOutputs)
            self.tbInfoOutputImages.setText(text)
    
    
        def initActions(self):
            """
            Connects the QActions, tool buttons and dialog buttons to their handlers.

            - actionAddSourceStack triggers onAddSource and is the default action of btnAddSourceStack.
            - actionRemoveSourceStack triggers onRemoveSources and is the default action of btnRemoveSourceStack.
            - The Save button calls accept(), the Cancel button calls close().
            """
            addAction = self.actionAddSourceStack
            removeAction = self.actionRemoveSourceStack

            addAction.triggered.connect(self.onAddSource)
            removeAction.triggered.connect(self.onRemoveSources)

            self.btnAddSourceStack.setDefaultAction(addAction)
            self.btnRemoveSourceStack.setDefaultAction(removeAction)

            buttons = self.buttonBox
            buttons.button(QDialogButtonBox.Save).clicked.connect(self.accept)
            buttons.button(QDialogButtonBox.Cancel).clicked.connect(self.close)
    
    
        def onAddSource(self, *args):
            """
            Lets the user select one or more stacked-band raster images and adds them as input stacks.

            The file dialog starts in the directory stored under the RasterSourceDirectory setting
            and offers the raster file filters known to the QGIS provider registry.
            :param args: ignored, allows the method to be connected to arbitrary Qt signals
            """
            import eotimeseriesviewer.settings
            defDir = eotimeseriesviewer.settings.value(eotimeseriesviewer.settings.Keys.RasterSourceDirectory)

            # input stacks are raster images, so raster (not vector) file filters are required
            filters = QgsProviderRegistry.instance().fileRasterFilters()
            # parent=self keeps the file dialog on top of / centered over this dialog
            files, _selectedFilter = QFileDialog.getOpenFileNames(parent=self, directory=defDir, filter=filters)

            if len(files) > 0:
                self.tableModelInputStacks.insertSources(files)
    
    
    
        def addSources(self, paths):
            """
            Adds data sources to the input stack model.

            :param paths: list of data sources, passed unchanged to tableModelInputStacks.insertSources
            """
            inputModel = self.tableModelInputStacks
            inputModel.insertSources(paths)
    
        def onRemoveSources(self, *args):
            """
            Removes the input stacks whose rows are selected in tableViewSourceStacks.

            :param args: ignored, allows the method to be connected to arbitrary Qt signals
            """
            selectionModel = self.tableViewSourceStacks.selectionModel()
            assert isinstance(selectionModel, QItemSelectionModel)

            inputModel = self.tableModelInputStacks
            selectedInfos = [inputModel.index2info(rowIndex) for rowIndex in selectionModel.selectedRows()]
            inputModel.removeSources(selectedInfos)
    
    
        def onSourceStackSelectionChanged(self, selected, deselected):
            """
            Enables the remove action only while at least one source stack is selected.

            `selected` only holds the items that became selected with this change. After a pure
            deselection it is empty even if other rows are still selected. The current selection
            state is therefore taken from the selection model of tableViewSourceStacks.
            The signal argument is used only as a fallback if the view has no selection model.
            :param selected: QItemSelection with the newly selected items
            :param deselected: QItemSelection with the newly deselected items (unused)
            """
            selectionModel = self.tableViewSourceStacks.selectionModel()
            if selectionModel is not None:
                hasSelection = selectionModel.hasSelection()
            else:
                hasSelection = len(selected) > 0
            self.actionRemoveSourceStack.setEnabled(hasSelection)
    
    
        def onOutputImageSelectionChanged(self, selected, deselected):
            """
            Shows the VRT XML of the first newly selected output image in tbXMLPreview.

            If this selection change did not select anything, the preview is cleared.
            :param selected: QItemSelection with the newly selected items
            :param deselected: QItemSelection with the newly deselected items (unused)
            """
            if len(selected) == 0:
                self.tbXMLPreview.setPlainText(None)
                return

            firstIndex = selected.indexes()[0]
            outputModel = self.tableModelOutputImages
            vrtDescription = outputModel.index2vrt(firstIndex)
            assert isinstance(vrtDescription, OutputVRTDescription)
            self.tbXMLPreview.setPlainText(outputModel.vrtXML(vrtDescription))
    
    
        def saveImages(self):
            """
            Writes the VRT images described by the output image model.

            Paths starting with '/vsimem/' are written with write_vsimem; all other paths are written
            as UTF-8 text files. The progress bar is advanced after each image and reset to 0 about
            500 ms after the last one. If cbOpenInQGIS is enabled and checked, every written file is
            added as a QgsRasterLayer to the current QgsProject. The written paths are also appended
            to self.mWrittenFiles.
            :return: [list-of-written-file-paths]; an empty list if there is nothing to write
            """
            nTotal = len(self.tableModelOutputImages)
            if nTotal == 0:
                # keep the documented return type (a list) even if nothing was written
                return []

            writtenFiles = []
            self.progressBar.setValue(0)

            from eotimeseriesviewer.virtualrasters import write_vsimem

            for i, outVRT in enumerate(self.tableModelOutputImages):
                assert isinstance(outVRT, OutputVRTDescription)
                xml = self.tableModelOutputImages.vrtXML(outVRT)

                if outVRT.mPath.startswith('/vsimem/'):
                    write_vsimem(outVRT.mPath, xml)
                else:
                    # the context manager closes the file even if writing fails
                    with open(outVRT.mPath, 'w', encoding='utf-8') as f:
                        f.write(xml)

                writtenFiles.append(outVRT.mPath)

                # i + 1: the bar reaches 100 % once the last image has been written
                self.progressBar.setValue(int(100. * (i + 1) / nTotal))

            QTimer.singleShot(500, lambda: self.progressBar.setValue(0))

            if self.cbOpenInQGIS.isEnabled() and self.cbOpenInQGIS.isChecked():
                mapLayers = [QgsRasterLayer(p) for p in writtenFiles]
                QgsProject.instance().addMapLayers(mapLayers, addToLegend=True)

            self.mWrittenFiles.extend(writtenFiles)
            return writtenFiles