Skip to content
Snippets Groups Projects
virtualrasters.py 67.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    """
    /***************************************************************************
                                  HUB TimeSeriesViewer
                                  -------------------
            begin                : 2015-08-20
            git sha              : $Format:%H$
            copyright            : (C) 2017 by HU-Berlin
            email                : benjamin.jakimow@geo.hu-berlin.de
     ***************************************************************************/
    
    /***************************************************************************
     *                                                                         *
     *   This program is free software; you can redistribute it and/or modify  *
     *   it under the terms of the GNU General Public License as published by  *
     *   the Free Software Foundation; either version 2 of the License, or     *
     *   (at your option) any later version.                                   *
     *                                                                         *
     ***************************************************************************/
    """
    # noinspection PyPep8Naming
    from __future__ import absolute_import
    
    import os, sys, re, pickle, tempfile
    from collections import OrderedDict
    
    from qgis.core import *
    
    from PyQt4.QtCore import *
    from PyQt4.QtGui import *
    
    def px2geo(px, gt):
        """
        Converts a pixel position into a geo-coordinate by applying a GDAL geo-transform.
        :param px: pixel position; must provide x() (column) and y() (row), e.g. a QPoint
        :param gt: GDAL geo-transform, indexable with at least 6 elements (as from GetGeoTransform())
        :return: QgsPoint with the geo-coordinate
        """
        # affine transform, see http://www.gdal.org/gdal_datamodel.html
        column, row = px.x(), px.y()
        geoX = gt[0] + column * gt[1] + row * gt[2]
        geoY = gt[3] + column * gt[4] + row * gt[5]
        return QgsPoint(geoX, geoY)
    
    
    class VRTRasterInputSourceBand(object):
        """
        Describes one band of a source raster file that contributes to a virtual (VRT) band.
        """

        def __init__(self, path, bandIndex, bandName=''):
            # normalized so that comparisons do not depend on redundant separators / '.' parts
            self.mPath = os.path.normpath(path)
            # 0-based band index (saveVRT writes it as <SourceBand> bandIndex + 1)
            self.mBandIndex = bandIndex
            self.mBandName = bandName
            self.mNoData = None
            # back-reference to the owning VRTRasterBand; set by VRTRasterBand.insertSource()
            self.mVirtualBand = None

        def isEqual(self, other):
            """
            Returns True if `other` is a VRTRasterInputSourceBand that refers to the same file and band index.
            """
            if not isinstance(other, VRTRasterInputSourceBand):
                return False
            samePath = self.mPath == other.mPath
            return samePath and self.mBandIndex == other.mBandIndex

        def __reduce_ex__(self, protocol):
            # re-create the object via __init__, then restore all other attributes via __setstate__
            initArgs = (self.mPath, self.mBandIndex, self.mBandName)
            return self.__class__, initArgs, self.__getstate__()

        def __getstate__(self):
            # the owning virtual band is not part of the pickled state
            state = dict(self.__dict__)
            del state['mVirtualBand']
            return state

        def __setstate__(self, state):
            vars(self).update(state)

        def virtualBand(self):
            """Returns the VRTRasterBand this source band is attached to, or None."""
            return self.mVirtualBand
    
    class VRTRasterBand(QObject):
        # emitted with the new name when setName() changes the band name
        sigNameChanged = pyqtSignal(str)
        # emitted as (index, source) after insertSource() inserted a source band
        sigSourceInserted = pyqtSignal(int, VRTRasterInputSourceBand)
        # emitted as (index before removal, source) after a source band was removed
        sigSourceRemoved = pyqtSignal(int, VRTRasterInputSourceBand)
    
        def __init__(self, name='', parent=None):
    
        def setName(self, name):
            """
            Sets the band name and emits sigNameChanged(name) if it differs from the previous name.
            """
            previousName, self.mName = self.mName, name
            if previousName != self.mName:
                self.sigNameChanged.emit(name)
    
        def name(self):
            """Returns the name of this virtual band."""
            bandName = self.mName
            return bandName
    
        def addSource(self, virtualBandInputSource):
            """
            Appends a VRTRasterInputSourceBand after the last source of this virtual band.
            """
            assert isinstance(virtualBandInputSource, VRTRasterInputSourceBand)
            endPosition = len(self.sources)
            self.insertSource(endPosition, virtualBandInputSource)
    
        def insertSource(self, index, virtualBandInputSource):
            """
            Inserts a VRTRasterInputSourceBand at position `index` of this band's sources and
            emits sigSourceInserted(index, source). The source's mVirtualBand back-reference is set to this band.
            """
            source = virtualBandInputSource
            assert isinstance(source, VRTRasterInputSourceBand)

            source.mVirtualBand = self
            assert index <= len(self.sources)
            self.sources.insert(index, source)

            self.sigSourceInserted.emit(index, source)
    
    
        def bandIndex(self):
            """
            Returns the position of this band within its VRTRaster, or None if it is not attached to a VRTRaster.
            """
            owner = self.mVRT
            if not isinstance(owner, VRTRaster):
                return None
            return owner.mBands.index(self)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
    
            Removes a VRTRasterInputSourceBand
            :param vrtRasterInputSourceBand: band index| VRTRasterInputSourceBand
            :return: The VRTRasterInputSourceBand that was removed
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
    
            if not isinstance(vrtRasterInputSourceBand, VRTRasterInputSourceBand):
                vrtRasterInputSourceBand = self.sources[vrtRasterInputSourceBand]
            if vrtRasterInputSourceBand in self.sources:
                i = self.sources.index(vrtRasterInputSourceBand)
                self.sources.remove(vrtRasterInputSourceBand)
                self.sigSourceRemoved.emit(i, vrtRasterInputSourceBand)
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
            :return: list of file-paths to all source files
            """
    
            files = set([inputSource.mPath for inputSource in self.sources])
    
            return sorted(list(files))
    
        def __repr__(self):
            """
            Returns a multi-line description: the band name followed by one line per input source.
            """
            infos = ['VirtualBand name="{}"'.format(self.mName)]
            for i, info in enumerate(self.sources):
                assert isinstance(info, VRTRasterInputSourceBand)
                infos.append('\t{} SourceFileName {} SourceBand {}'.format(i + 1, info.mPath, info.mBandIndex))
            # __repr__ must return a string; without this return repr() raised a TypeError
            return '\n'.join(infos)
    
    # Maps resampling-method names to GDAL resampling algorithm constants.
    LUT_ReampleAlg = {'nearest': gdal.GRA_NearestNeighbour,
                      'bilinear': gdal.GRA_Bilinear,
                      'mode': gdal.GRA_Mode,
                      'lanczos': gdal.GRA_Lanczos,
                      'average': gdal.GRA_Average,
                      'cubic': gdal.GRA_Cubic,
                      'cubic_spline': gdal.GRA_CubicSpline,
                      # misspelled legacy key, kept so that existing lookups keep working
                      'cubic_splie': gdal.GRA_CubicSpline}
    
    class VRTRasterPreviewMapCanvas(QgsMapCanvas):
        """
        Map canvas used to preview the source rasters of a virtual raster.
        """

        def __init__(self, parent=None, *args, **kwds):
            super(VRTRasterPreviewMapCanvas, self).__init__(parent, *args, **kwds)

        def contextMenuEvent(self, event):
            """Shows a context menu with 'Refresh' and 'Reset' actions."""
            menu = QMenu()
            for title, slot in [('Refresh', self.refresh), ('Reset', self.reset)]:
                menu.addAction(title).triggered.connect(slot)
            menu.exec_(event.globalPos())

        def setLayerSet(self, layers):
            # disabled on purpose: use setLayers() instead
            raise DeprecationWarning()

        def setLayers(self, layers):
            """
            Shows `layers`, ordered by descending extent area, and adds them to the QgsMapLayerRegistry.
            :param layers: list of QgsMapLayers
            """
            assert isinstance(layers, list)

            def area(layer):
                extent = layer.extent()
                return extent.width() * extent.height()

            byAreaDescending = sorted(layers, key=area, reverse=True)
            QgsMapLayerRegistry.instance().addMapLayers(byAreaDescending)

            canvasLayers = [QgsMapCanvasLayer(lyr) for lyr in byAreaDescending]
            # this class overrides setLayerSet() to raise, so call the base implementation
            super(VRTRasterPreviewMapCanvas, self).setLayerSet(canvasLayers)

        def reset(self):
            """Zooms to the full extent enlarged by 5 % and refreshes the canvas."""
            fullExtent = self.fullExtent()
            fullExtent.scale(1.05)
            self.setExtent(fullExtent)
            self.refresh()
    
    
        # emitted as (virtual band, source band) when a source band was inserted into a virtual band
        sigSourceBandInserted = pyqtSignal(VRTRasterBand, VRTRasterInputSourceBand)
        # intended to be emitted as (virtual band, source band) when a source band was removed — TODO confirm emitters
        sigSourceBandRemoved = pyqtSignal(VRTRasterBand, VRTRasterInputSourceBand)
        # emitted with the list of source file paths that became referenced by the virtual bands
        sigSourceRasterAdded = pyqtSignal(list)
        # emitted with the list of source file paths that are no longer referenced
        sigSourceRasterRemoved = pyqtSignal(list)
        # emitted as (index, band) after a virtual band was inserted
        sigBandInserted = pyqtSignal(int, VRTRasterBand)
        # emitted as (index before removal, band) after a virtual band was removed
        sigBandRemoved = pyqtSignal(int, VRTRasterBand)
        # emitted with the new CRS when setCrs() changes it
        sigCrsChanged = pyqtSignal(QgsCoordinateReferenceSystem)
    
        def __init__(self, parent=None):
            super(VRTRaster, self).__init__(parent)
            self.mBands = []                    # VRTRasterBands in output band order
            self.mCrs = None
            self.mResampleAlg = gdal.GRA_NearestNeighbour
            self.mMetadata = {}
            self.mSourceRasterBounds = {}       # source file path -> RasterBounds
            self.mOutputBounds = None           # optional RasterBounds used by saveVRT

            # keep the source raster bounds in sync with every structural change
            for signal in (self.sigSourceBandRemoved, self.sigSourceBandInserted,
                           self.sigBandRemoved, self.sigBandInserted):
                signal.connect(self.updateSourceRasterBounds)
    
        def setCrs(self, crs):
            """
            Sets the CRS of the virtual raster and emits sigCrsChanged if it changed.
            :param crs: QgsCoordinateReferenceSystem, osr.SpatialReference (converted via its AUTHORITY code),
                        or None to reset the CRS
            """
            if isinstance(crs, osr.SpatialReference):
                auth = '{}:{}'.format(crs.GetAttrValue('AUTHORITY', 0), crs.GetAttrValue('AUTHORITY', 1))
                crs = QgsCoordinateReferenceSystem(auth)

            if crs is None:
                # updateSourceRasterBounds() calls setCrs(None) when no source is left. Ignoring None kept the
                # old CRS, so the next source raster could never define a new one.
                if self.mCrs is not None:
                    self.mCrs = None
                    # the signal is typed QgsCoordinateReferenceSystem: announce the reset with an empty CRS
                    self.sigCrsChanged.emit(QgsCoordinateReferenceSystem())
            elif isinstance(crs, QgsCoordinateReferenceSystem):
                if crs != self.mCrs:
                    self.mCrs = crs
                    self.sigCrsChanged.emit(self.mCrs)
    
    
        def crs(self):
            """Returns the CRS of the virtual raster, or None if it is not set."""
            currentCrs = self.mCrs
            return currentCrs
    
    
        def addVirtualBand(self, virtualBand):
            """
            Appends a virtual band after the last existing one.
            :param virtualBand: the VRTRasterBand to append
            :return: the appended VRTRasterBand
            """
            endPosition = len(self)
            return self.insertVirtualBand(endPosition, virtualBand)
    
        def insertSourceBand(self, virtualBandIndex, pathSource, sourceBandIndex):
            """
            Inserts a source band into the VRT stack.
            Virtual bands missing up to `virtualBandIndex` are created on the fly.
            :param virtualBandIndex: target virtual band index (position in self.mBands)
            :param pathSource: path of source file
            :param sourceBandIndex: 0-based source file band index
            """
            # create empty virtual bands until the target index exists
            while virtualBandIndex > len(self.mBands) - 1:
                self.insertVirtualBand(len(self.mBands), VRTRasterBand())

            # `vBand` was undefined here and VRTRasterBand has no addSourceBand(); use addSource() instead
            vBand = self.mBands[virtualBandIndex]
            vBand.addSource(VRTRasterInputSourceBand(pathSource, sourceBandIndex))
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
            Inserts a VirtualBand
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            :param virtualBand: the VirtualBand to be inserted
            :return: the VirtualBand
            """
    
            assert isinstance(virtualBand, VRTRasterBand)
            assert index <= len(self.mBands)
            if len(virtualBand.name()) == 0:
                virtualBand.setName('Band {}'.format(index+1))
    
            virtualBand.mVRT = self
    
            virtualBand.sigSourceInserted.connect(
                lambda _, sourceBand: self.sigSourceBandInserted.emit(virtualBand, sourceBand))
            virtualBand.sigSourceRemoved.connect(
                lambda _, sourceBand: self.sigSourceBandInserted.emit(virtualBand, sourceBand))
    
            self.mBands.insert(index, virtualBand)
            self.sigBandInserted.emit(index, virtualBand)
    
            return self[index]
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
    
        def removeVirtualBands(self, bandsOrIndices):
            """
            Removes virtual bands and emits sigBandRemoved(index, band) for each removed band.
            :param bandsOrIndices: list of VRTRasterBands and/or band indices
            """
            assert isinstance(bandsOrIndices, list)
            to_remove = []
            seenIndices = set()

            for virtualBand in bandsOrIndices:
                if not isinstance(virtualBand, VRTRasterBand):
                    virtualBand = self.mBands[virtualBand]
                index = self.mBands.index(virtualBand)
                # the same band may be given twice (e.g. as object and as index); removing it twice
                # raised a ValueError after the first removal had already been applied
                if index in seenIndices:
                    continue
                seenIndices.add(index)
                to_remove.append((index, virtualBand))

            # remove from the highest index downwards so that each emitted index is still correct
            for index, virtualBand in sorted(to_remove, key=lambda t: t[0], reverse=True):
                self.mBands.remove(virtualBand)
                self.sigBandRemoved.emit(index, virtualBand)
    
        def removeInputSource(self, path):
            """
            Removes every input source band that reads from file `path` from all virtual bands.
            :param path: source file path, as listed by sourceRaster()
            """
            assert path in self.sourceRaster()
            for vBand in self.mBands:
                assert isinstance(vBand, VRTRasterBand)
                # `sources` is a list of VRTRasterInputSourceBands: it is not callable and never contains
                # the plain path, so match on each source's mPath. Collect first, then remove.
                matches = [src for src in vBand.sources if src.mPath == path]
                for src in matches:
                    vBand.removeSource(src)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
        def removeVirtualBand(self, bandOrIndex):
    
        def addFilesAsMosaic(self, files):
            """
            Shortcut to mosaic all input files. All bands will maintain their band position in the virtual file,
            i.e. band b of every file becomes a source of virtual band b. Missing virtual bands are created.
            :param files: [list-of-file-paths]
            :return: self
            """
            for file in files:
                ds = gdal.Open(file)
                assert isinstance(ds, gdal.Dataset), 'Can not open {}'.format(file)
                nb = ds.RasterCount
                ds = None
                for b in range(nb):
                    if b + 1 > len(self.mBands):
                        # add new virtual band
                        self.addVirtualBand(VRTRasterBand())
                    vBand = self.mBands[b]
                    vBand.addSource(VRTRasterInputSourceBand(file, b))
            return self
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
            Shortcut to stack all input files, i.e. each band of an input file will be a new virtual band.
            Bands in the virtual file will be ordered as file1-band1, file1-band n, file2-band1, file2-band,...
            :param files: [list-of-file-paths]
            :return: self
            """
            for file in files:
                ds = gdal.Open(file)
    
                assert isinstance(ds, gdal.Dataset), 'Can not open {}'.format(file)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
                nb = ds.RasterCount
                ds = None
                for b in range(nb):
                    #each new band is a new virtual band
    
                    vBand = self.addVirtualBand(VRTRasterBand())
                    assert isinstance(vBand, VRTRasterBand)
                    vBand.addSource(VRTRasterInputSourceBand(file, b))
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            return self
    
            for vBand in self.mBands:
                assert isinstance(vBand, VRTRasterBand)
    
                files.update(set(vBand.sourceFiles()))
            return sorted(list(files))
    
    
        def sourceRasterBounds(self):
            """Returns the dict mapping source file paths to their RasterBounds (the internal dict, not a copy)."""
            boundsByPath = self.mSourceRasterBounds
            return boundsByPath
    
        def outputBounds(self):
            """
            Returns the explicitly set output bounds (RasterBounds), or None if none were set.
            """
            # NOTE(review): a default derived from the source rasters was planned here but is not implemented
            if isinstance(self.mOutputBounds, RasterBounds):
                return self.mOutputBounds
            return None
    
        def setOutputBounds(self, bounds):
            """
            Sets the output bounds used by saveVRT().
            :param bounds: RasterBounds, or None to let saveVRT() use gdal.BuildVRT's default extent
            """
            # validate the argument; the check used to test `self`, so it always failed
            assert bounds is None or isinstance(bounds, RasterBounds)
            self.mOutputBounds = bounds
    
    
        def updateSourceRasterBounds(self):
            """
            Synchronizes mSourceRasterBounds with the files referenced by the virtual bands, initializes
            or resets the CRS, and emits sigSourceRasterRemoved / sigSourceRasterAdded with the changed paths.
            """
            srcFiles = self.sourceRaster()
            known = self.mSourceRasterBounds

            removed = [path for path in known.keys() if path not in srcFiles]
            added = [path for path in srcFiles if path not in known]

            for path in removed:
                known.pop(path)
            for path in added:
                known[path] = RasterBounds(path)

            if not srcFiles:
                self.setCrs(None)
            elif self.crs() is None:
                # the first source raster defines the CRS
                self.setCrs(known[srcFiles[0]].crs)

            if removed:
                self.sigSourceRasterRemoved.emit(removed)
            if added:
                self.sigSourceRasterAdded.emit(added)
    
    
        def saveVRT(self, pathVRT, resampleAlg=gdal.GRA_NearestNeighbour, **kwds):
            """
            Writes this virtual raster to a GDAL VRT file.

            Source rasters whose CRS differs from this VRT's CRS are first wrapped into warped VRTs,
            which are written to the sub-folder 'WarpedVRTs' next to pathVRT.

            :param pathVRT: path to VRT that is created
            :param options --- can be an array of strings, a string, or left empty to be filled from the other keywords.
            :param resolution --- 'highest', 'lowest', 'average', 'user'.
            :param outputBounds --- output bounds as (minX, minY, maxX, maxY) in target SRS.
            :param xRes, yRes --- output resolution in target SRS.
            :param targetAlignedPixels --- whether to force output bounds to be multiple of output resolution.
            :param bandList --- ignored (not among the supported keywords); band order is defined by the virtual bands.
            :param addAlpha --- whether to add an alpha mask band to the VRT when the source raster have none.
            :param resampleAlg --- resampling mode.
            :param outputSRS --- assigned output SRS.
            :param allowProjectionDifference --- whether to accept input datasets have not the same projection. Note: they will *not* be reprojected.
            :param srcNodata --- source nodata value(s).
            :param VRTNodata, hideNodata --- passed on to gdal.BuildVRTOptions.
            :param callback --- callback method.
            :param callback_data --- user data for callback.
            :return: gdal.Dataset(pathVRT), or None if no virtual bands are defined
            """
            if len(self.mBands) == 0:
                print('No VRT Inputs defined.')
                return None

            # keep only the supported keywords; they are passed on to gdal.BuildVRTOptions
            supported = ['options', 'resolution', 'outputBounds', 'xRes', 'yRes', 'targetAlignedPixels', 'addAlpha',
                         'resampleAlg', 'outputSRS', 'allowProjectionDifference', 'srcNodata', 'VRTNodata',
                         'hideNodata', 'callback', 'callback_data']
            _kwds = {k: v for k, v in kwds.items() if k in supported}

            if 'resampleAlg' not in _kwds:
                _kwds['resampleAlg'] = resampleAlg

            if isinstance(self.mOutputBounds, RasterBounds):
                xmin, ymin, xmax, ymax = self.mOutputBounds.polygon
                _kwds['outputBounds'] = (xmin, ymin, xmax, ymax)

            dirVrt = os.path.dirname(pathVRT)
            dirWarpedVRT = os.path.join(dirVrt, 'WarpedVRTs')
            # dirname() is '' for a bare file name; makedirs() also creates missing parent folders
            if dirVrt and not os.path.isdir(dirVrt):
                os.makedirs(dirVrt)

            # source path -> file actually referenced by the VRT (the source itself or a warped VRT)
            # (the former nodata lookup here referenced undefined names and its result was never used)
            srcLookup = dict()
            for pathSrc in self.sourceRaster():
                dsSrc = gdal.Open(pathSrc)
                assert isinstance(dsSrc, gdal.Dataset)
                srcCrs = QgsCoordinateReferenceSystem(dsSrc.GetProjection())

                if srcCrs == self.mCrs:
                    srcLookup[pathSrc] = pathSrc
                else:
                    # reproject via a warped VRT
                    if not os.path.isdir(dirWarpedVRT):
                        os.mkdir(dirWarpedVRT)
                    pathVRT2 = os.path.join(dirWarpedVRT, 'warped.{}.vrt'.format(os.path.basename(pathSrc)))
                    wops = gdal.WarpOptions(format='VRT',
                                            dstSRS=self.mCrs.toWkt())
                    tmp = gdal.Warp(pathVRT2, dsSrc, options=wops)
                    assert isinstance(tmp, gdal.Dataset)
                    tmp = None
                    srcLookup[pathSrc] = pathVRT2
                dsSrc = None

            srcFiles = [srcLookup[src] for src in self.sourceRaster()]

            # 1. build a temporary VRT (one band per source file) that describes the spatial shifts of all input sources
            vro = gdal.BuildVRTOptions(separate=True, **_kwds)
            gdal.BuildVRT(pathVRT, srcFiles, options=vro)
            dsVRTDst = gdal.Open(pathVRT)
            assert isinstance(dsVRTDst, gdal.Dataset)

            ns, nl = dsVRTDst.RasterXSize, dsVRTDst.RasterYSize
            gt = dsVRTDst.GetGeoTransform()
            wkt = dsVRTDst.GetProjectionRef()
            eType = dsVRTDst.GetRasterBand(1).DataType
            # srcFile -> source XML of its band in the temporary VRT
            SOURCE_TEMPLATES = dict()
            for i, srcFile in enumerate(srcFiles):
                vrt_sources = dsVRTDst.GetRasterBand(i + 1).GetMetadata('vrt_sources')
                assert len(vrt_sources) == 1
                # list() works for Python 2 lists and Python 3 dict views alike
                srcXML = list(vrt_sources.values())[0]
                assert os.path.basename(srcFile) + '</SourceFilename>' in srcXML
                assert '<SourceBand>1</SourceBand>' in srcXML
                SOURCE_TEMPLATES[srcFile] = srcXML
            dsVRTDst = None
            # remove the temporary VRT, we don't need it any more
            os.remove(pathVRT)

            # 2. build final VRT from scratch
            drvVRT = gdal.GetDriverByName('VRT')
            assert isinstance(drvVRT, gdal.Driver)

            dsVRTDst = drvVRT.Create(pathVRT, ns, nl, 0, eType=eType)

            # 2.1. set general properties
            assert isinstance(dsVRTDst, gdal.Dataset)
            dsVRTDst.SetProjection(wkt)
            dsVRTDst.SetGeoTransform(gt)

            # 2.2. add virtual bands
            for i, vBand in enumerate(self.mBands):
                assert isinstance(vBand, VRTRasterBand)

                # call AddBand outside the assert, otherwise `python -O` would skip it
                err = dsVRTDst.AddBand(eType, options=['subClass=VRTSourcedRasterBand'])
                assert err == 0
                vrtBandDst = dsVRTDst.GetRasterBand(i + 1)
                assert isinstance(vrtBandDst, gdal.Band)

                vrtBandDst.SetDescription(str(vBand.name()))

                # add all input sources for this virtual band: reuse the source XML of the temporary VRT
                # and point it to the requested band (<SourceBand> is 1-based, mBandIndex is 0-based)
                md = {}
                for iSrc, sourceInfo in enumerate(vBand.sources):
                    assert isinstance(sourceInfo, VRTRasterInputSourceBand)
                    xml = SOURCE_TEMPLATES[srcLookup[sourceInfo.mPath]]
                    xml = re.sub('<SourceBand>1</SourceBand>',
                                 '<SourceBand>{}</SourceBand>'.format(sourceInfo.mBandIndex + 1), xml)
                    md['source_{}'.format(iSrc)] = xml
                vrtBandDst.SetMetadata(md, 'vrt_sources')
                vrtBandDst.ComputeBandStats(1)

            # release the dataset so the VRT is flushed to disk
            dsVRTDst = None

            # check if we get what we like to get
            dsCheck = gdal.Open(pathVRT)
            assert isinstance(dsCheck, gdal.Dataset)
            return dsCheck
    
    
        def __repr__(self):
            """Returns a summary line followed by the description of every virtual band."""
            header = 'VirtualRasterBuilder: {} bands, {} source files'.format(
                len(self.mBands), len(self.sourceRaster()))
            lines = [header] + [str(vBand) for vBand in self.mBands]
            return '\n'.join(lines)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
        def __len__(self):
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
        def __getitem__(self, slice):
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
        def __delitem__(self, slice):
            """
            Removes the virtual band(s) addressed by an index or a slice.
            """
            bands = self[slice]
            # presumably an integer index yields a single band, but removeVirtualBands() requires a list
            if not isinstance(bands, list):
                bands = [bands]
            self.removeVirtualBands(bands)
    
        def __contains__(self, item):
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
        def __iter__(self):
            """Iterates over the virtual bands (VRTRasterBand) in output order."""
            # the bands live in mBands; `mClasses` does not exist on VRTRaster
            return iter(self.mBands)
    
    
    
    class VRTRasterVectorLayer(QgsVectorLayer):
        """
        In-memory polygon layer showing the footprints of the source rasters of a VRTRaster.
        Features are added/removed when the VRTRaster emits sigSourceRasterAdded/sigSourceRasterRemoved.
        """

        def __init__(self, vrtRaster, crs=None):
            assert isinstance(vrtRaster, VRTRaster)
            if crs is None:
                crs = QgsCoordinateReferenceSystem('EPSG:4326')

            uri = 'polygon?crs={}'.format(crs.authid())

            super(VRTRasterVectorLayer, self).__init__(uri, 'VRTRaster', 'memory', False)
            self.mCrs = crs
            self.mVRTRaster = vrtRaster

            # initialize fields; call startEditing outside the assert so it also runs under `python -O`
            started = self.startEditing()
            assert started
            # standard field names, types, etc.
            fieldDefs = [('oid', QVariant.Int, 'integer'),
                         ('type', QVariant.String, 'string'),
                         ('name', QVariant.String, 'string'),
                         ('path', QVariant.String, 'string'),
                         ]
            for fieldDef in fieldDefs:
                field = QgsField(fieldDef[0], fieldDef[1], fieldDef[2])
                self.addAttribute(field)
            self.commitChanges()

            symbol = QgsFillSymbolV2.createSimple({'style': 'no', 'color': 'red', 'outline_color': 'black'})
            self.rendererV2().setSymbol(symbol)
            self.label().setFields(self.fields())
            self.label().setLabelField(3, 3)
            self.mVRTRaster.sigSourceRasterAdded.connect(self.onRasterInserted)
            self.mVRTRaster.sigSourceRasterRemoved.connect(self.onRasterRemoved)
            self.onRasterInserted(self.mVRTRaster.sourceRaster())

        def path2feature(self, path):
            """Returns the feature whose 'path' attribute equals `path`, or None."""
            for f in self.dataProvider().getFeatures():
                if str(f.attribute('path')) == str(path):
                    return f
            return None

        def path2fid(self, path):
            """Returns the id of the feature whose 'path' attribute equals `path`, or None."""
            feature = self.path2feature(path)
            return feature.id() if feature is not None else None

        def fid2path(self, fid):
            """Returns the 'path' attribute of the feature with id `fid`, or None."""
            # QgsFeature has no fid(); its id is returned by id()
            for f in self.dataProvider().getFeatures():
                if f.id() == fid:
                    return f.attribute('path')
            return None

        def onRasterInserted(self, listOfNewFiles):
            """
            Adds one footprint feature per new source raster, reprojected into the layer CRS.
            :param listOfNewFiles: list of source raster paths
            """
            assert isinstance(listOfNewFiles, list)
            if len(listOfNewFiles) == 0:
                return

            # features can only be added (and committed) while the layer is in edit mode
            self.startEditing()
            for f in listOfNewFiles:
                bounds = self.mVRTRaster.sourceRasterBounds()[f]
                assert isinstance(bounds, RasterBounds)
                oid = str(id(bounds))
                geometry = QgsPolygonV2(bounds.polygon)

                # reproject the footprint into the layer CRS
                trans = QgsCoordinateTransform(bounds.crs, self.crs())
                geometry.transform(trans)

                feature = QgsFeature(self.pendingFields())
                feature.setGeometry(QgsGeometry.fromWkt(geometry.asWkt()))
                feature.setAttribute('oid', oid)
                feature.setAttribute('type', 'source file')
                feature.setAttribute('name', str(os.path.basename(f)))
                feature.setAttribute('path', str(f))
                # the feature was built but never added to the layer before
                self.addFeature(feature)

            self.updateExtents()
            committed = self.commitChanges()
            assert committed
            self.dataChanged.emit()

        def onRasterRemoved(self, files):
            """Deletes the footprint features whose 'path' attribute is in `files`."""
            self.startEditing()
            self.selectAll()
            toRemove = [f.id() for f in self.selectedFeatures() if f.attribute('path') in files]
            self.setSelectedFeatures(toRemove)
            self.deleteSelectedFeatures()
            self.commitChanges()
    
    def createVirtualBandMosaic(bandFiles, pathVRT):
        """
        Mosaics `bandFiles` into a single VRT (gdal.BuildVRT with separate=False).
        The band count of the result must equal the band count of the first file.
        :param bandFiles: list of raster file paths
        :param pathVRT: path of the VRT to create
        :return: gdal.Dataset of the created VRT
        """
        refDS = gdal.Open(bandFiles[0])
        nb = refDS.RasterCount
        refDS = None

        vrtOptions = gdal.BuildVRTOptions(
            # here we can use the options known from http://www.gdal.org/gdalbuildvrt.html
            separate=False
        )
        vrtDS = gdal.BuildVRT(pathVRT, bandFiles, options=vrtOptions)
        vrtDS.FlushCache()

        assert vrtDS.RasterCount == nb
        return vrtDS
    
    def createVirtualBandStack(bandFiles, pathVRT):
        """
        Stacks `bandFiles` into a VRT with one band per input file (gdal.BuildVRT with separate=True).
        Each VRT band's description is set to its source path and its band statistics are computed.
        :param bandFiles: list of raster file paths
        :param pathVRT: path of the VRT to create
        :return: gdal.Dataset of the created VRT (returned like createVirtualBandMosaic does)
        """
        nb = len(bandFiles)

        vrtOptions = gdal.BuildVRTOptions(
            # here we can use the options known from http://www.gdal.org/gdalbuildvrt.html
            separate=True,
        )
        vrtDS = gdal.BuildVRT(pathVRT, bandFiles, options=vrtOptions)
        vrtDS.FlushCache()

        assert vrtDS.RasterCount == nb

        # describe each band with the file it comes from
        for i, path in enumerate(bandFiles):
            band = vrtDS.GetRasterBand(i + 1)
            band.SetDescription(path)
            band.ComputeBandStats()

        return vrtDS
    
    
    from timeseriesviewer.utils import loadUi
    
    class TreeNode(QObject):
    
        # (node, first, last): emitted before children first..last are inserted below `node`
        sigWillAddChildren = pyqtSignal(QObject, int, int)
        # (node, first, last): emitted after children first..last were inserted below `node`
        sigAddedChildren = pyqtSignal(QObject, int, int)
        # (node, first, last): emitted before children first..last are removed from `node`
        sigWillRemoveChildren = pyqtSignal(QObject, int, int)
        # (node, first, last): emitted after children first..last were removed from `node`
        sigRemovedChildren = pyqtSignal(QObject, int, int)
        # (node): presumably emitted when a node's data changes — TODO confirm emitters
        sigUpdated = pyqtSignal(QObject)
    
            super(TreeNode, self).__init__()
            self.mParent = parentNode
    
            self.mChildren = []
    
            self.mValues = []
            self.mIcon = None
            self.mToolTip = None
    
    
    
            if isinstance(parentNode, TreeNode):
                parentNode.appendChildNodes(self)
    
    
        def nodeIndex(self):
            """Returns the position of this node among its parent's children."""
            siblings = self.mParent.mChildren
            return siblings.index(self)
    
        def next(self):
            """
            Returns the next sibling node, or None if this node is its parent's last child.
            """
            siblings = self.mParent.mChildren
            i = self.nodeIndex()
            # was `self.mChildren.mChildren` (AttributeError); also the last child has no successor
            if i < len(siblings) - 1:
                return siblings[i + 1]
            return None
    
        def previous(self):
            """Returns the previous sibling node, or None if this node is its parent's first child."""
            position = self.nodeIndex()
            return self.mParent.mChildren[position - 1] if position > 0 else None
    
    
        def detach(self):
            """
            Detaches this TreeNode from its parent TreeNode. Does nothing if there is no parent.
            """
            if isinstance(self.mParent, TreeNode):
                self.mParent.mChildren.remove(self)
                # reset directly: setParentNode(None) failed its isinstance(TreeNode) assertion
                self.mParent = None
                # NOTE(review): no will-remove/removed signals are emitted here; observers are not notified
    
        def appendChildNodes(self, listOfChildNodes):
            """Appends a TreeNode, or a list of TreeNodes, after the last child."""
            endIndex = len(self.mChildren)
            self.insertChildNodes(endIndex, listOfChildNodes)
    
        def insertChildNodes(self, index, listOfChildNodes):
            """
            Inserts a TreeNode, or a list of TreeNodes, at position `index`.
            Emits sigWillAddChildren(self, first, last) before and sigAddedChildren(self, first, last) after
            the insertion, and forwards the new children's signals through this node.
            """
            assert index <= len(self.mChildren)
            if isinstance(listOfChildNodes, TreeNode):
                listOfChildNodes = [listOfChildNodes]
            assert isinstance(listOfChildNodes, list)
            if len(listOfChildNodes) == 0:
                # nothing to insert; do not announce an inverted range (last < first)
                return
            idxLast = index + len(listOfChildNodes) - 1
            self.sigWillAddChildren.emit(self, index, idxLast)
            for i, node in enumerate(listOfChildNodes):
                assert isinstance(node, TreeNode)
                node.mParent = self
                # forward the child's signals through this node
                node.sigWillAddChildren.connect(self.sigWillAddChildren)
                node.sigAddedChildren.connect(self.sigAddedChildren)
                node.sigWillRemoveChildren.connect(self.sigWillRemoveChildren)
                node.sigRemovedChildren.connect(self.sigRemovedChildren)
                node.sigUpdated.connect(self.sigUpdated)

                self.mChildren.insert(index + i, node)

            self.sigAddedChildren.emit(self, index, idxLast)
    
        def removeChildNode(self, node):
            """Removes the given child node (which must be a direct child)."""
            assert node in self.mChildren
            self.removeChildNodes(self.mChildren.index(node), 1)
    
        def removeChildNodes(self, row, count):
            """
            Removes `count` child nodes starting at position `row`.
            Emits sigWillRemoveChildren(self, row, last) before and sigRemovedChildren(self, row, last) after.
            :return: False if the range is invalid, True if the nodes were removed
            """
            if row < 0 or count <= 0:
                return False

            rowLast = row + count - 1
            if rowLast >= self.childCount():
                return False

            self.sigWillRemoveChildren.emit(self, row, rowLast)
            # slice directly: TreeNode has no childNodes() method
            del self.mChildren[row:rowLast + 1]
            self.sigRemovedChildren.emit(self, row, rowLast)
            return True
    
    
    
    
        def setToolTip(self, toolTip):
            """Sets the tooltip text of this node."""
            self.mToolTip = toolTip
        def toolTip(self):
            """Returns the tooltip text of this node."""
            tip = self.mToolTip
            return tip
    
    
            return self.mParent
    
    
        def setParentNode(self, treeNode):
            """
            Sets the parent reference of this node. Does not add the node to the parent's children.
            :param treeNode: TreeNode, or None to clear the parent (as detach() requests)
            """
            assert treeNode is None or isinstance(treeNode, TreeNode)
            self.mParent = treeNode
    
    
        def setIcon(self, icon):
            """Sets the icon of this node."""
            self.mIcon = icon
    
        def icon(self):
            """Returns the icon of this node (None if not set)."""
            nodeIcon = self.mIcon
            return nodeIcon
    
        def setName(self, name):
            """Sets the display name of this node."""
            self.mName = name
    
        def name(self):
            """Returns the display name of this node."""
            nodeName = self.mName
            return nodeName
    
        def contextMenu(self):
            """Returns a context menu for this node; the base TreeNode provides none (None)."""
            menu = None
            return menu
    
    
        def setValues(self, listOfValues):
            """Sets the node values; a single non-list value is wrapped into a list. A copy is stored."""
            values = listOfValues if isinstance(listOfValues, list) else [listOfValues]
            self.mValues = values[:]
        def values(self):
            """Return a shallow copy of the values stored by setValues()."""
            copied = self.mValues[:]
            return copied
    
        def childCount(self):
            """Return the number of direct child nodes."""
            children = self.mChildren
            return len(children)
    
    
        def findChildNodes(self, type, recursive=True):
            """Return the descendant nodes that are instances of ``type``.

            The traversal is depth-first pre-order: each matching child comes
            before the matches found below it.

            :param type: class (or tuple of classes) to match with isinstance().
            :param recursive: if False, only direct children are inspected.
            """
            found = []
            for child in self.mChildren:
                if isinstance(child, type):
                    found.append(child)
                if not recursive:
                    continue
                found += child.findChildNodes(type, recursive=True)
            return found
    
    
    class SourceRasterFileNode(TreeNode):
        """Tree node describing one raster file readable by GDAL.

        Creates a 'Path' node and a 'CRS' node (authority code and WKT of the
        dataset projection) with this node as parent, and appends a 'Bands' node
        holding one SourceRasterBandNode per raster band.
        """

        def __init__(self, parentNode, path):
            super(SourceRasterFileNode, self).__init__(parentNode)

            self.mPath = path
            self.setName(os.path.basename(path))

            srcNode = TreeNode(self, name='Path')
            srcNode.setValues(path)

            # populate meta info
            ds = gdal.Open(path)
            assert isinstance(ds, gdal.Dataset), 'Unable to open raster: {}'.format(path)

            # was used below without ever being created (NameError)
            crsNode = TreeNode(self, name='CRS')
            crsNode.setIcon(QIcon(':/timeseriesviewer/icons/CRS.png'))

            crs = osr.SpatialReference()
            crs.ImportFromWkt(ds.GetProjection())

            authInfo = '{}:{}'.format(crs.GetAttrValue('AUTHORITY', 0), crs.GetAttrValue('AUTHORITY', 1))
            crsNode.setValues([authInfo, crs.ExportToWkt()])

            # filled while detached, then attached to this node in one step below
            self.bandNode = TreeNode(None, name='Bands')

            for b in range(ds.RasterCount):
                band = ds.GetRasterBand(b + 1)  # GDAL band numbers are 1-based

                inputSource = VRTRasterInputSourceBand(path, b)
                inputSource.mBandName = band.GetDescription()
                if inputSource.mBandName in [None, '']:
                    # fall back to the 1-based band number as display name
                    inputSource.mBandName = '{}'.format(b + 1)
                inputSource.mNoData = band.GetNoDataValue()

                SourceRasterBandNode(self.bandNode, inputSource)

            # drop the last references so GDAL closes the file handle
            band = None
            ds = None

            self.bandNode.setParentNode(self)
            self.appendChildNodes(self.bandNode)

        def sourceBands(self):
            """Return the VRTRasterInputSourceBand objects of all band nodes."""
            return [n.mSrcBand for n in self.bandNode.mChildren if isinstance(n, SourceRasterBandNode)]
    
    
    class SourceRasterBandNode(TreeNode):
        """Leaf node representing one band of a source raster file.

        The node is named after the band name of the wrapped
        VRTRasterInputSourceBand; its tool-tip shows the 1-based band number
        and the file path.
        """

        def __init__(self, parentNode, vrtRasterInputSourceBand):
            assert isinstance(vrtRasterInputSourceBand, VRTRasterInputSourceBand)

            super(SourceRasterBandNode, self).__init__(parentNode)

            self.setIcon(QIcon(":/timeseriesviewer/icons/mIconRaster.png"))

            src = vrtRasterInputSourceBand
            self.mSrcBand = src
            self.setName(src.mBandName)
            bandNumber = src.mBandIndex + 1
            self.setToolTip('band {}:{}'.format(bandNumber, src.mPath))
    
    class VRTRasterNode(TreeNode):
        """Tree node whose children follow the bands of a VRTRaster.

        Band insertions and removals reported by the raster's signals are
        mirrored as VRTRasterBandNode children at the same position.
        """

        def __init__(self, parentNode, vrtRaster):
            assert isinstance(vrtRaster, VRTRaster)

            super(VRTRasterNode, self).__init__(parentNode)
            self.mVRTRaster = vrtRaster

            raster = self.mVRTRaster
            raster.sigBandInserted.connect(self.onBandInserted)
            raster.sigBandRemoved.connect(self.onBandRemoved)

        def onBandInserted(self, index, vrtRasterBand):
            """Add a VRTRasterBandNode for ``vrtRasterBand`` at row ``index``."""
            assert isinstance(vrtRasterBand, VRTRasterBand)
            position = vrtRasterBand.bandIndex()
            # the signal index and the band's own index must agree
            assert position == index
            self.insertChildNodes(position, [VRTRasterBandNode(None, vrtRasterBand)])

        def onBandRemoved(self, removedIdx):
            """Drop the child node at row ``removedIdx``."""
            self.removeChildNodes(removedIdx, 1)
    
    
    class VRTRasterBandNode(TreeNode):
        """Tree node mirroring one VRTRasterBand and its input sources.

        The node's name follows the band's ``sigNameChanged``; one
        VRTRasterInputSourceBandNode child is kept per entry of
        ``virtualBand.sources``.
        """

        def __init__(self, parentNode, virtualBand):

            assert isinstance(virtualBand, VRTRasterBand)

            super(VRTRasterBandNode, self).__init__(parentNode)

            self.mVirtualBand = virtualBand

            self.setName(virtualBand.name())

            self.setIcon(QIcon(":/timeseriesviewer/icons/mIconVirtualRaster.png"))

            # source nodes are attached directly to this node; kept as an attribute
            # so an intermediate grouping node could be reintroduced later
            self.nodeBands = self

            virtualBand.sigNameChanged.connect(self.setName)

            # the first signal argument (presumably the insert position) is ignored:
            # onSourceInserted() derives the position from virtualBand.sources
            virtualBand.sigSourceInserted.connect(lambda _, src: self.onSourceInserted(src))
            virtualBand.sigSourceRemoved.connect(self.onSourceRemoved)
            for src in self.mVirtualBand.sources:
                self.onSourceInserted(src)

        def onSourceInserted(self, inputSource):
            """Insert a node for ``inputSource`` at its position in the band's source list."""
            assert isinstance(inputSource, VRTRasterInputSourceBand)

            assert inputSource.virtualBand() == self.mVirtualBand
            i = self.mVirtualBand.sources.index(inputSource)

            node = VRTRasterInputSourceBandNode(None, inputSource)
            # pass a list, as VRTRasterNode.onBandInserted does
            self.nodeBands.insertChildNodes(i, [node])

        def onSourceRemoved(self, row, inputSource):
            """Remove the child node that shows ``inputSource``.

            :param row: expected row of the node (the source's former position).
            :param inputSource: the removed VRTRasterInputSourceBand.
            """
            assert isinstance(inputSource, VRTRasterInputSourceBand)

            def showsSource(n):
                return isinstance(n, VRTRasterInputSourceBandNode) and n.mSrc == inputSource

            children = self.nodeBands.childNodes()
            node = children[row] if 0 <= row < len(children) else None
            if not showsSource(node):
                # row and tree are out of sync: look the node up by its source
                # instead of removing an unrelated node (or raising IndexError)
                node = next((n for n in children if showsSource(n)), None)
            if node is not None:
                self.nodeBands.removeChildNode(node)
    
    
    
    
    class VRTRasterInputSourceBandNode(TreeNode):
        """Leaf node for one input-source band contributing to a virtual band.

        Displayed as '<1-based band number>:<file name>'.
        """

        def __init__(self, parentNode, vrtRasterInputSourceBand):
            assert isinstance(vrtRasterInputSourceBand, VRTRasterInputSourceBand)
            super(VRTRasterInputSourceBandNode, self).__init__(parentNode)

            self.setIcon(QIcon(":/timeseriesviewer/icons/mIconRaster.png"))

            src = vrtRasterInputSourceBand
            self.mSrc = src
            bandNumber = src.mBandIndex + 1
            fileName = os.path.basename(src.mPath)
            self.setName('{}:{}'.format(bandNumber, fileName))

        def sourceBand(self):
            """Return the VRTRasterInputSourceBand shown by this node."""
            return self.mSrc
    
    class TreeView(QTreeView):
        """QTreeView subclass that currently adds no behaviour of its own.

        All constructor arguments are forwarded unchanged to QTreeView.
        """

        def __init__(self, *args, **kwargs):
            super(TreeView, self).__init__(*args, **kwargs)
    
    class TreeModel(QAbstractItemModel):
        def __init__(self, parent=None, rootNode=None):
            """Create a two-column ('Node', 'Value') item model over a TreeNode tree.

            :param parent: Qt parent; if it is a QTreeView the model is connected
                to it via connectTreeView().
            :param rootNode: root TreeNode; a new empty TreeNode(None) is used if
                this is not a TreeNode.
            """
            super(TreeModel, self).__init__(parent)

            self.mColumnNames = ['Node', 'Value']

            if not isinstance(rootNode, TreeNode):
                rootNode = TreeNode(None)
            self.mRootNode = rootNode

            # structural changes reported by the tree are translated into model updates
            wiring = [
                (rootNode.sigWillAddChildren, self.nodeWillAddChildren),
                (rootNode.sigAddedChildren, self.nodeAddedChildren),
                (rootNode.sigWillRemoveChildren, self.nodeWillRemoveChildren),
                (rootNode.sigRemovedChildren, self.nodeRemovedChildren),
                (rootNode.sigUpdated, self.nodeUpdated),
            ]
            for signal, slot in wiring:
                signal.connect(slot)

            self.mTreeView = None
            if isinstance(parent, QTreeView):
                self.connectTreeView(parent)
    
        def nodeWillAddChildren(self, node, idx1, idxL):
            """Announce the insertion of rows ``idx1``..``idxL`` below ``node``."""
            parentIndex = self.node2idx(node)
            self.beginInsertRows(parentIndex, idx1, idxL)
    
    
        def nodeAddedChildren(self, node, idx1, idxL):
            self.endInsertRows()
            #for i in range(idx1, idxL+1):
            for n in node.childNodes():
                self.setColumnSpan(node)