Skip to content
Snippets Groups Projects
virtualrasters.py 33.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • """
    /***************************************************************************
    
                                  Virtual Raster Builder
                                  ----------------------
    
            begin                : 2015-08-20
            git sha              : $Format:%H$
            copyright            : (C) 2017 by HU-Berlin
            email                : benjamin.jakimow@geo.hu-berlin.de
     ***************************************************************************/
    
    /***************************************************************************
     *                                                                         *
     *   This program is free software; you can redistribute it and/or modify  *
     *   it under the terms of the GNU General Public License as published by  *
     *   the Free Software Foundation; either version 3 of the License, or     *
     *   (at your option) any later version.                                   *
     *                                                                         *
     ***************************************************************************/
    """
    
    import os, sys, re, pickle, tempfile, uuid
    from xml.etree import ElementTree
    
    from osgeo import gdal, osr, ogr, gdalconst as gc
    
    from qgis.core import *
    
    from PyQt5.QtCore import *
    from PyQt5.QtGui import *
    
    from PyQt5.QtWidgets import *
    
    from timeseriesviewer.models import Option, OptionListModel
    
    # Lookup: GDAL data type -> size of ONE pixel value in bytes.
    # Complex types consist of a real and an imaginary part and therefore take
    # twice the size of their component type (e.g. CInt16 = 2 x Int16 = 4 bytes).
    LUT_GDT_SIZE = {gdal.GDT_Byte: 1,
                    gdal.GDT_UInt16: 2,
                    gdal.GDT_Int16: 2,
                    gdal.GDT_UInt32: 4,
                    gdal.GDT_Int32: 4,
                    gdal.GDT_Float32: 4,
                    gdal.GDT_Float64: 8,
                    gdal.GDT_CInt16: 4,
                    gdal.GDT_CInt32: 8,
                    gdal.GDT_CFloat32: 8,
                    gdal.GDT_CFloat64: 16}
    
    # Lookup: GDAL data type -> GDAL data type name (as used in VRT XML / gdal.GetDataTypeName).
    # Complex types keep their 'C' prefix so they are not confused with the real-valued types.
    LUT_GDT_NAME = {gdal.GDT_Byte: 'Byte',
                    gdal.GDT_UInt16: 'UInt16',
                    gdal.GDT_Int16: 'Int16',
                    gdal.GDT_UInt32: 'UInt32',
                    gdal.GDT_Int32: 'Int32',
                    gdal.GDT_Float32: 'Float32',
                    gdal.GDT_Float64: 'Float64',
                    gdal.GDT_CInt16: 'CInt16',
                    gdal.GDT_CInt32: 'CInt32',
                    gdal.GDT_CFloat32: 'CFloat32',
                    gdal.GDT_CFloat64: 'CFloat64'}
    
    
    # Descriptions of the GDAL resampling algorithms, keyed by the name of the
    # gdal.GRA_* constant without its 'GRA_' prefix (used as option tooltips).
    GRA_tooltips = dict(
        NearestNeighbour='nearest neighbour resampling (default, fastest algorithm, worst interpolation quality).',
        Bilinear='bilinear resampling.',
        Lanczos='lanczos windowed sinc resampling.',
        Average='average resampling, computes the average of all non-NODATA contributing pixels.',
        Cubic='cubic resampling.',
        CubicSpline='cubic spline resampling.',
        Mode='mode resampling, selects the value which appears most often of all the sampled points',
        Max='maximum resampling, selects the maximum value from all non-NODATA contributing pixels',
        Min='minimum resampling, selects the minimum value from all non-NODATA contributing pixels.',
        Med='median resampling, selects the median value of all non-NODATA contributing pixels.',
        Q1='first quartile resampling, selects the first quartile value of all non-NODATA contributing pixels. ',
        Q3='third quartile resampling, selects the third quartile value of all non-NODATA contributing pixels',
    )
    
    
    # Model of all resampling algorithms offered by the installed GDAL bindings.
    # Every gdal.GRA_* constant becomes one Option whose name is the constant name
    # without the 'GRA_' prefix and whose tooltip is taken from GRA_tooltips (if any).
    RESAMPLE_ALGS = OptionListModel()

    for GRAkey, GRA in [(k, v) for k, v in list(gdal.__dict__.items()) if k.startswith('GRA_')]:
        GRA_Name = GRAkey[len('GRA_'):]
        option = Option(GRA, GRA_Name, tooltip=GRA_tooltips.get(GRA_Name))
        RESAMPLE_ALGS.addOption(option)
    
    
    # thanks to https://gis.stackexchange.com/questions/75533/how-to-apply-band-settings-using-gdal-python-bindings
    def read_vsimem(fn):
        """
        Reads the complete content of a GDAL virtual file system path (e.g. /vsimem/...).
        :param fn: vsimem path (str)
        :return: result of gdal.VSIFReadL(1, vsileng, vsifile), i.e. the file content
        :raises IOError: if the path cannot be opened
        """
        vsifile = gdal.VSIFOpenL(fn, 'r')
        if vsifile is None:
            raise IOError('Unable to open "{}"'.format(fn))
        try:
            # seek to the end to learn the file size, then rewind and read everything
            gdal.VSIFSeekL(vsifile, 0, 2)
            vsileng = gdal.VSIFTellL(vsifile)
            gdal.VSIFSeekL(vsifile, 0, 0)
            return gdal.VSIFReadL(1, vsileng, vsifile)
        finally:
            # always release the VSI handle (it was leaked before)
            gdal.VSIFCloseL(vsifile)
    
    def write_vsimem(fn: str, data: str):
        """
        Writes data to a GDAL virtual file system path (e.g. /vsimem/...).
        :param fn: vsimem path (str)
        :param data: content to write; str is encoded as UTF-8, bytes are written as-is
        :return: result of gdal.VSIFCloseL(vsifile)
        :raises IOError: if the path cannot be opened for writing
        """
        if isinstance(data, str):
            # VSIFWriteL expects a byte count. len() of a str counts characters,
            # which would truncate the written content for non-ASCII text.
            data = data.encode('utf-8')

        vsifile = gdal.VSIFOpenL(fn, 'w')
        if vsifile is None:
            raise IOError('Unable to open "{}" for writing'.format(fn))
        try:
            gdal.VSIFWriteL(data, 1, len(data), vsifile)
        finally:
            # close the handle even if writing fails
            result = gdal.VSIFCloseL(vsifile)
        return result
    
    
    def px2geo(px, gt):
        """
        Converts a pixel position into a geo-coordinate using a GDAL geotransform.
        See http://www.gdal.org/gdal_datamodel.html
        :param px: pixel position providing x() (column) and y() (row)
        :param gt: GDAL geotransform (indexable, 6 values)
        :return: QgsPoint of the geo-coordinate
        """
        col = px.x()
        row = px.y()
        geoX = gt[0] + col * gt[1] + row * gt[2]
        geoY = gt[3] + col * gt[4] + row * gt[5]
        return QgsPoint(geoX, geoY)
    
    def describeRawFile(pathRaw, pathVrt, xsize, ysize,
                        bands=1,
                        eType = gdal.GDT_Byte,
                        interleave='bsq',
                        byteOrder='LSB',
                        headerOffset=0):
        """
        Creates a VRT to describe a raw binary file
        :param pathRaw: path of raw image
        :param pathVrt: path of destination VRT
        :param xsize: number of image samples / columns
        :param ysize: number of image lines
        :param bands: number of image bands
        :param eType: the GDAL data type
        :param interleave: can be 'bsq' (default),'bil' or 'bip'
        :param byteOrder: 'LSB' (default) or 'MSB'
        :param headerOffset: header offset in bytes, default = 0
        :return: gdal.Dataset of created VRT
        """
        assert xsize > 0
        assert ysize > 0
        assert bands > 0
        assert eType > 0

        assert eType in LUT_GDT_SIZE.keys(), 'dataType "{}" is not a valid gdal datatype'.format(eType)
        interleave = interleave.lower()

        assert interleave in ['bsq','bil','bip']
        assert byteOrder in ['LSB', 'MSB']

        drvVRT = gdal.GetDriverByName('VRT')
        assert isinstance(drvVRT, gdal.Driver)
        dsVRT = drvVRT.Create(pathVrt, xsize, ysize, bands=0, eType=eType)
        assert isinstance(dsVRT, gdal.Dataset)

        # Reference the raw file relative to the VRT if it lies inside the VRT directory.
        # GDAL only resolves a relative SourceFilename against the VRT location when the
        # RelativeToVRT option is set; otherwise it would be resolved against the CWD.
        vrtDir = os.path.dirname(pathVrt)
        if len(vrtDir) > 0 and pathRaw.startswith(os.path.join(vrtDir, '')):
            relativeToVRT = 1
            srcFilename = os.path.relpath(pathRaw, vrtDir)
        else:
            relativeToVRT = 0
            srcFilename = pathRaw

        typeSize = LUT_GDT_SIZE[eType]

        for b in range(bands):
            if interleave == 'bsq':
                # band sequential: band b starts after b complete xsize*ysize images
                imageOffset = headerOffset + b * xsize * ysize * typeSize
                pixelOffset = typeSize
                lineOffset = xsize * typeSize
            elif interleave == 'bil':
                # band interleaved by line: each line holds xsize values of every band
                imageOffset = headerOffset + b * xsize * typeSize
                pixelOffset = typeSize
                lineOffset = xsize * bands * typeSize
            else:
                # 'bip', band interleaved by pixel: each pixel holds one value of every band
                imageOffset = headerOffset + b * typeSize
                pixelOffset = bands * typeSize
                lineOffset = xsize * bands * typeSize

            options = ['subClass=VRTRawRasterBand',
                       'SourceFilename={}'.format(srcFilename),
                       'RelativeToVRT={}'.format(relativeToVRT),
                       'ImageOffset={}'.format(imageOffset),
                       'PixelOffset={}'.format(pixelOffset),
                       'LineOffset={}'.format(lineOffset),
                       'ByteOrder={}'.format(byteOrder)]

            assert dsVRT.AddBand(eType, options=options) == 0
            assert isinstance(dsVRT.GetRasterBand(b + 1), gdal.Band)

        # write the VRT definition so it is available on disk / in /vsimem/
        dsVRT.FlushCache()
        return dsVRT
    
        @staticmethod
        def fromGDALDataSet(pathOrDataSet):
            """
            Returns the VRTRasterInputSourceBands from a raster data source,
            one for each raster band (band indices are 0-based).
            :param pathOrDataSet: str | gdal.Dataset
            :return: [list-of-VRTRasterInputSourceBand], empty if the source cannot be opened
            """
            if isinstance(pathOrDataSet, str):
                pathOrDataSet = gdal.Open(pathOrDataSet)

            if not isinstance(pathOrDataSet, gdal.Dataset):
                return []

            # GetFileList() can be None or empty for datasets without a file
            # representation; fall back to the dataset description then.
            fileList = pathOrDataSet.GetFileList()
            path = fileList[0] if fileList else pathOrDataSet.GetDescription()

            return [VRTRasterInputSourceBand(path, b) for b in range(pathOrDataSet.RasterCount)]
    
    
    
    
        def __init__(self, path:str, bandIndex:int, bandName:str=''):
            # source raster location, 0-based band index within it and an optional band name
            self.mPath, self.mBandIndex, self.mBandName = path, bandIndex, bandName
            # no-data value, unset by default
            self.mNoData = None
            # the VRTRasterBand this source has been inserted into (None until inserted)
            self.mVirtualBand = None
    
    
        def isEqual(self, other):
            """
            Returns True if `other` is a VRTRasterInputSourceBand that refers to the same
            source path and band index as this one, otherwise False.
            """
            if not isinstance(other, VRTRasterInputSourceBand):
                return False
            samePath = self.mPath == other.mPath
            return samePath and self.mBandIndex == other.mBandIndex
    
        def __reduce_ex__(self, protocol):
            # pickle support: rebuild via the constructor, then restore the remaining
            # attributes from __getstate__ (which leaves out the virtual band link)
            constructorArgs = (self.mPath, self.mBandIndex, self.mBandName)
            return self.__class__, constructorArgs, self.__getstate__()
    
        def __getstate__(self):
            # copy of the instance attributes without the back-reference to the owning
            # VRTRasterBand, which is a QObject and is not pickled along
            state = dict(self.__dict__)
            del state['mVirtualBand']
            return state
    
        def __setstate__(self, state):
            # restore the pickled attributes onto this instance
            vars(self).update(state)
    
    
        def virtualBand(self):
            """Returns the VRTRasterBand this source has been inserted into, or None."""
            band = self.mVirtualBand
            return band
    
    class VRTRasterBand(QObject):
        sigNameChanged = pyqtSignal(str)
        sigSourceInserted = pyqtSignal(int, VRTRasterInputSourceBand)
        sigSourceRemoved = pyqtSignal(int, VRTRasterInputSourceBand)
    
        def __init__(self, name='', parent=None):
            """
            A virtual band of a VRTRaster, composed of VRTRasterInputSourceBands.
            :param name: str, band name
            :param parent: optional parent QObject
            """
            # QObject.__init__ has to run before any signal of this object is used
            super(VRTRasterBand, self).__init__(parent)
            self.mName = ''
            self.mSources = []  # list of VRTRasterInputSourceBand
            self.mVRT = None  # the VRTRaster this band has been inserted into
            self.setName(name)
    
        def __len__(self):
            """Returns the number of input sources of this virtual band."""
            sources = self.mSources
            return len(sources)
    
        def setName(self, name):
            """
            Sets the band name and emits sigNameChanged if it differs from the previous name.
            :param name: str
            """
            assert isinstance(name, str)

            previous, self.mName = self.mName, name
            if previous != name:
                self.sigNameChanged.emit(name)
    
        def name(self):
            """Returns the band name (str)."""
            bandName = self.mName
            return bandName
    
        def addSource(self, virtualBandInputSource):
            """Appends a VRTRasterInputSourceBand at the end of this band's source list."""
            assert isinstance(virtualBandInputSource, VRTRasterInputSourceBand)
            endIndex = len(self.mSources)
            self.insertSource(endIndex, virtualBandInputSource)
    
    
        def insertSource(self, index, virtualBandInputSource):
            """
            Inserts a VRTRasterInputSourceBand at position `index` and emits sigSourceInserted.
            If `index` is larger than the number of sources nothing is inserted.
            :param index: int, insert position
            :param virtualBandInputSource: VRTRasterInputSourceBand
            """
            assert isinstance(virtualBandInputSource, VRTRasterInputSourceBand)

            if index > len(self.mSources):
                # out of range: keep the best-effort behaviour and ignore the request
                return

            # link the source to this band only if it really becomes one of its sources
            virtualBandInputSource.mVirtualBand = self
            self.mSources.insert(index, virtualBandInputSource)
            self.sigSourceInserted.emit(index, virtualBandInputSource)
    
        def bandIndex(self):
            """
            Returns the index of this band within its VRTRaster,
            or None if the band has not been inserted into a VRTRaster.
            """
            vrt = getattr(self, 'mVRT', None)
            if isinstance(vrt, VRTRaster):
                return vrt.mBands.index(self)
            return None

        def removeSource(self, vrtRasterInputSourceBand):
            """
            Removes a VRTRasterInputSourceBand and emits sigSourceRemoved.
            :param vrtRasterInputSourceBand: band index| VRTRasterInputSourceBand
            :return: The VRTRasterInputSourceBand that was removed, or None if it is not a source of this band
            """
            if not isinstance(vrtRasterInputSourceBand, VRTRasterInputSourceBand):
                vrtRasterInputSourceBand = self.mSources[vrtRasterInputSourceBand]

            if vrtRasterInputSourceBand not in self.mSources:
                return None

            i = self.mSources.index(vrtRasterInputSourceBand)
            del self.mSources[i]
            # the source no longer belongs to this band (insertSource sets this link)
            vrtRasterInputSourceBand.mVirtualBand = None
            self.sigSourceRemoved.emit(i, vrtRasterInputSourceBand)
            return vrtRasterInputSourceBand

        def sourceFiles(self):
            """
            :return: sorted list of the unique file-paths of all source files
            """
            return sorted({inputSource.mPath for inputSource in self.mSources})
    
        def __repr__(self):
            # one header line with the band name, then one line per input source
            infos = ['VirtualBand name="{}"'.format(self.mName)]

            for i, info in enumerate(self.mSources):
                assert isinstance(info, VRTRasterInputSourceBand)
                infos.append('\t{} SourceFileName {} SourceBand {}'.format(i + 1, info.mPath, info.mBandIndex))

            return '\n'.join(infos)
    
        sigSourceBandInserted = pyqtSignal(VRTRasterBand, VRTRasterInputSourceBand)
        sigSourceBandRemoved = pyqtSignal(VRTRasterBand, VRTRasterInputSourceBand)
        sigSourceRasterAdded = pyqtSignal(list)
        sigSourceRasterRemoved = pyqtSignal(list)
        sigBandInserted = pyqtSignal(int, VRTRasterBand)
        sigBandRemoved = pyqtSignal(int, VRTRasterBand)
        sigCrsChanged = pyqtSignal(QgsCoordinateReferenceSystem)
    
        sigResamplingAlgChanged = pyqtSignal([str],[int])
    
        def __init__(self, parent=None):
            super(VRTRaster, self).__init__(parent)
            self.mBands = []  # list of VRTRasterBand
            self.mCrs = None  # output CRS (QgsCoordinateReferenceSystem) or None

            # output extent: None = derived implicitly/automatically (see setExtent)
            self.mExtent = None
            # output resolution: 'average' is also the value setResolution(None) selects
            self.mResolution = 'average'

            self.mResamplingAlg = gdal.GRA_NearestNeighbour

            self.mMetadata = dict()
            # source raster path -> RasterBounds, maintained by updateSourceRasterBounds()
            self.mSourceRasterBounds = dict()

            # keep the source raster bounds in sync with any band / source change
            self.sigSourceBandRemoved.connect(self.updateSourceRasterBounds)
            self.sigSourceBandInserted.connect(self.updateSourceRasterBounds)
            self.sigBandRemoved.connect(self.updateSourceRasterBounds)
            self.sigBandInserted.connect(self.updateSourceRasterBounds)
    
    
    
        def setResamplingAlg(self, value):
            """
            Sets the resampling algorithm and emits sigResamplingAlgChanged if it changed.
            :param value:
                - Any gdal.GRA_* constant, like gdal.GRA_NearestNeighbour
                - the name of a gdal.GRA_* constant without the 'GRA_' prefix as listed in
                  RESAMPLE_ALGS.optionNames(), e.g. 'NearestNeighbour', 'Bilinear', 'Cubic'
                  (compared case-insensitively)
                - None (will set the default value gdal.GRA_NearestNeighbour)
            :raises Exception: if the value is unknown
            """
            last = self.mResamplingAlg

            possibleNames = RESAMPLE_ALGS.optionNames()
            possibleValues = RESAMPLE_ALGS.optionValues()
            lowerNames = [str(n).lower() for n in possibleNames]

            if value is None:
                self.mResamplingAlg = gdal.GRA_NearestNeighbour
            elif isinstance(value, str) and value.lower() in lowerNames:
                self.mResamplingAlg = possibleValues[lowerNames.index(value.lower())]
            elif value in possibleValues:
                self.mResamplingAlg = value
            else:
                raise Exception('Unknown value "{}"'.format(value))

            if last != self.mResamplingAlg:
                name = possibleNames[possibleValues.index(self.mResamplingAlg)]
                self.sigResamplingAlgChanged[str].emit(name)
                self.sigResamplingAlgChanged[int].emit(self.mResamplingAlg)
    
    
    
        def resamplingAlg(self, asString=False):
            """
            Returns the resampling algorithm.
            :param asString: Set True to return the algorithm name (e.g. 'NearestNeighbour')
                             instead of the gdal.GRA_* constant.
            :return: gdal.GRA_* constant or its name.
            """
            if not asString:
                return self.mResamplingAlg

            i = RESAMPLE_ALGS.optionValues().index(self.mResamplingAlg)
            return RESAMPLE_ALGS.optionNames()[i]
    
    
        def setExtent(self, rectangle, crs=None):
            """
            Sets the output extent and emits sigExtentChanged if it changed.
            :param rectangle: QgsRectangle with positive width and height,
                              or None to use implicit/automatic values
            :param crs: optional QgsCoordinateReferenceSystem of `rectangle`. If given and the
                        VRT already has a CRS, the rectangle is transformed into the VRT CRS.
            """
            previous = self.mExtent

            if rectangle is not None:
                bothCrsKnown = isinstance(crs, QgsCoordinateReferenceSystem) and \
                               isinstance(self.mCrs, QgsCoordinateReferenceSystem)
                if bothCrsKnown:
                    transform = QgsCoordinateTransform()
                    transform.setSourceCrs(crs)
                    transform.setDestinationCrs(self.mCrs)
                    rectangle = transform.transform(rectangle)

                assert isinstance(rectangle, QgsRectangle)
                assert rectangle.width() > 0
                assert rectangle.height() > 0

            # rectangle is None here when implicit/automatic values are requested
            self.mExtent = rectangle

            if previous != self.mExtent:
                self.sigExtentChanged.emit()
    
        def extent(self):
            """Returns the explicit output extent (QgsRectangle) or None if it is derived automatically."""
            currentExtent = self.mExtent
            return currentExtent
    
        def setResolution(self, xy):
            """
            Set the VRT resolution and emit sigResolutionChanged if it changed.
            :param xy: explicit value given as QSizeF(x,y) object or
                       implicit as 'highest','lowest','average'.
                       None selects 'average'; values of other types are ignored.
            """
            previous = self.mResolution

            if xy is None:
                self.mResolution = 'average'
            elif isinstance(xy, QSizeF):
                assert xy.width() > 0
                assert xy.height() > 0
                self.mResolution = QSizeF(xy)
            elif isinstance(xy, str):
                assert xy in ['average','highest','lowest']
                self.mResolution = xy

            if previous != self.mResolution:
                self.sigResolutionChanged.emit()
    
        def resolution(self):
            """
            Returns the internal resolution descriptor, which can be
            an explicit QSizeF(x,y) or one of following strings: 'average','highest','lowest'
            """
            return self.mResolution

        def setCrs(self, crs):
            """
            Sets the output Coordinate Reference System (CRS) and emits sigCrsChanged if it changed.
            An explicit extent is transformed from the previous CRS into the new one.
            :param crs: osr.SpatialReference or QgsCoordinateReferenceSystem;
                        values of any other type (e.g. None) are ignored
            """
            if isinstance(crs, osr.SpatialReference):
                authName = crs.GetAttrValue('AUTHORITY', 0)
                authCode = crs.GetAttrValue('AUTHORITY', 1)
                if authName is not None and authCode is not None:
                    crs = QgsCoordinateReferenceSystem('{}:{}'.format(authName, authCode))
                else:
                    # no authority code available: use the WKT definition instead
                    crs = QgsCoordinateReferenceSystem(crs.ExportToWkt())

            if not isinstance(crs, QgsCoordinateReferenceSystem):
                return
            if not crs != self.mCrs:
                return

            extent = self.extent()
            if isinstance(extent, QgsRectangle) and isinstance(self.mCrs, QgsCoordinateReferenceSystem):
                # transform the explicit extent from the previous CRS (source) to the new CRS (destination)
                trans = QgsCoordinateTransform()
                trans.setSourceCrs(self.mCrs)
                trans.setDestinationCrs(crs)
                self.setExtent(trans.transform(extent))

            self.mCrs = crs
            self.sigCrsChanged.emit(self.mCrs)
    
        def crs(self):
            """Returns the output CRS (QgsCoordinateReferenceSystem) or None if it is not set."""
            currentCrs = self.mCrs
            return currentCrs
    
    
        def addVirtualBand(self, virtualBand):
            """
            Adds a virtual band at the end of the band list.
            :param virtualBand: the VirtualBand (VRTRasterBand) to be added
            :return: VirtualBand
            """
            return self.insertVirtualBand(len(self.mBands), virtualBand)
    
        def insertSourceBand(self, virtualBandIndex, pathSource, sourceBandIndex):
            """
            Inserts a source band into the VRT stack.
            Missing virtual bands up to virtualBandIndex are created.
            :param virtualBandIndex: target virtual band index (0-based)
            :param pathSource: path of source file
            :param sourceBandIndex: source file band index (0-based)
            :return: the VRTRasterBand the source band was added to
            :raises ValueError: if virtualBandIndex is negative
            """
            if virtualBandIndex < 0:
                raise ValueError('virtualBandIndex must be >= 0, got {}'.format(virtualBandIndex))

            while virtualBandIndex > len(self.mBands) - 1:
                self.insertVirtualBand(len(self.mBands), VRTRasterBand())

            vBand = self.mBands[virtualBandIndex]
            vBand.addSource(VRTRasterInputSourceBand(pathSource, sourceBandIndex))
            return vBand

        def insertVirtualBand(self, index, virtualBand):
            """
            Inserts a VirtualBand and emits sigBandInserted.
            :param index: insert position (must not exceed the number of bands)
            :param virtualBand: the VirtualBand (VRTRasterBand) to be inserted
            :return: the VirtualBand
            """
            assert isinstance(virtualBand, VRTRasterBand)
            assert index <= len(self.mBands)
            if len(virtualBand.name()) == 0:
                virtualBand.setName('Band {}'.format(index+1))

            virtualBand.mVRT = self

            # forward source changes of the band as VRT-level signals
            virtualBand.sigSourceInserted.connect(
                lambda _, sourceBand: self.sigSourceBandInserted.emit(virtualBand, sourceBand))
            virtualBand.sigSourceRemoved.connect(
                lambda _, sourceBand: self.sigSourceBandRemoved.emit(virtualBand, sourceBand))

            self.mBands.insert(index, virtualBand)
            self.sigBandInserted.emit(index, virtualBand)

            return virtualBand
    
        def removeVirtualBands(self, bandsOrIndices):
            """
            Removes virtual bands and emits sigBandRemoved(index, band) for each of them.
            :param bandsOrIndices: [list-of-VRTRasterBand | list-of-int]
            """
            assert isinstance(bandsOrIndices, list)

            pairs = []
            for item in bandsOrIndices:
                band = item if isinstance(item, VRTRasterBand) else self.mBands[item]
                pairs.append((self.mBands.index(band), band))

            # remove from the highest index downwards; indices were taken before removal
            pairs.sort(key=lambda pair: pair[0], reverse=True)
            for index, band in pairs:
                self.mBands.remove(band)
                self.sigBandRemoved.emit(index, band)
    
        def removeInputSource(self, path):
            """
            Removes all input sources that refer to the raster file `path` from all virtual bands.
            :param path: str, path of a source raster as returned by sourceRaster()
            """
            assert path in self.sourceRaster()
            for vBand in self.mBands:
                assert isinstance(vBand, VRTRasterBand)
                # collect first: removeSource() modifies vBand.mSources
                matches = [src for src in vBand.mSources if src.mPath == path]
                for inputSource in matches:
                    vBand.removeSource(inputSource)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    
        def removeVirtualBand(self, bandOrIndex):
    
        def addFilesAsMosaic(self, files):
            """
            Shortcut to mosaic all input files. All bands will maintain their band position in the virtual file,
            i.e. band n of every file becomes a source of virtual band n. Missing virtual bands are created.
            :param files: [list-of-file-paths]
            :return: self
            """
            for file in files:
                ds = gdal.Open(file)
                assert isinstance(ds, gdal.Dataset), 'Can not open {}'.format(file)
                nb = ds.RasterCount
                ds = None  # only the band count is needed, close the dataset
                for b in range(nb):
                    if b + 1 > len(self.mBands):
                        # add new virtual band
                        self.addVirtualBand(VRTRasterBand())
                    vBand = self.mBands[b]
                    vBand.addSource(VRTRasterInputSourceBand(file, b))
            return self
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            """
            Shortcut to stack all input files, i.e. each band of an input file will be a new virtual band.
            Bands in the virtual file will be ordered as file1-band1, file1-band n, file2-band1, file2-band,...
            :param files: [list-of-file-paths]
            :return: self
            """
            for file in files:
                ds = gdal.Open(file)
    
                assert isinstance(ds, gdal.Dataset), 'Can not open {}'.format(file)
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
                nb = ds.RasterCount
                ds = None
                for b in range(nb):
                    #each new band is a new virtual band
    
                    vBand = self.addVirtualBand(VRTRasterBand())
                    assert isinstance(vBand, VRTRasterBand)
                    vBand.addSource(VRTRasterInputSourceBand(file, b))
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
            return self
    
            for vBand in self.mBands:
                assert isinstance(vBand, VRTRasterBand)
    
                files.update(set(vBand.sourceFiles()))
            return sorted(list(files))
    
    
        def sourceRasterBounds(self):
            """Returns the dictionary source raster path -> RasterBounds."""
            boundsLookup = self.mSourceRasterBounds
            return boundsLookup
    
    
        def updateSourceRasterBounds(self):
            """
            Synchronizes mSourceRasterBounds with the current source rasters:
            creates RasterBounds for new source files and drops entries of files no longer used.
            Sets the CRS from the first source if no CRS is defined yet, calls setCrs(None) if
            there are no sources, and emits sigSourceRasterRemoved / sigSourceRasterAdded.
            """
            srcFiles = self.sourceRaster()
            srcFileSet = set(srcFiles)
            toRemove = [f for f in self.mSourceRasterBounds if f not in srcFileSet]
            toAdd = [f for f in srcFiles if f not in self.mSourceRasterBounds]

            for f in toRemove:
                del self.mSourceRasterBounds[f]
            for f in toAdd:
                self.mSourceRasterBounds[f] = RasterBounds(f)

            if srcFiles and self.crs() is None:
                self.setCrs(self.mSourceRasterBounds[srcFiles[0]].crs)
            elif not srcFiles:
                self.setCrs(None)

            if toRemove:
                self.sigSourceRasterRemoved.emit(toRemove)
            if toAdd:
                self.sigSourceRasterAdded.emit(toAdd)
    
    
        def loadVRT(self, pathVRT, bandIndex = None):
            """
            Loads the VRT definition in pathVRT and inserts its bands into this VRT.
            :param pathVRT: str, path of an existing VRT file; None or '' is ignored
            :param bandIndex: position at which the first loaded band is inserted,
                              defaults to the end of the band list
            """
            if pathVRT in [None,'']:
                return

            if bandIndex is None:
                bandIndex = len(self.mBands)

            ds = gdal.Open(pathVRT)
            assert isinstance(ds, gdal.Dataset)
            assert ds.GetDriver().GetDescription() == 'VRT'
            vrtDir = os.path.dirname(pathVRT)

            for b in range(ds.RasterCount):
                srcBand = ds.GetRasterBand(b+1)
                name = srcBand.GetDescription()
                if isinstance(name, bytes):
                    # Python 3 bindings already return str; decode only raw bytes
                    name = name.decode('utf-8')
                vrtBand = VRTRasterBand(name=name)

                for xml in srcBand.GetMetadata('vrt_sources').values():
                    tree = ElementTree.fromstring(xml)

                    srcFileNode = tree.find('SourceFilename')
                    srcPath = srcFileNode.text
                    relFlag = srcFileNode.get('relativeToVRT', srcFileNode.get('relativetoVRT', '0'))
                    if relFlag == '1':
                        # relative paths refer to the VRT location, not the working directory
                        srcPath = os.path.join(vrtDir, srcPath)

                    # <SourceBand> is 1-based, VRTRasterInputSourceBand band indices are 0-based
                    srcBandIndex = int(tree.find('SourceBand').text) - 1
                    vrtBand.addSource(VRTRasterInputSourceBand(srcPath, srcBandIndex))

                self.insertVirtualBand(bandIndex, vrtBand)
                bandIndex += 1

            ds = None  # close the dataset
    
    
    
    
    
        def saveVRT(self, pathVRT, warpedImageFolder = '.warpedimage'):
            """
            Save the VRT to path.
            If source images need to be warped to the final CRS warped VRT image will be created in a folder <directory>/<basename>+<warpedImageFolder>/
    
            :param pathVRT: str, path of final VRT.
            :param warpedImageFolder: basename of folder that is created
            :return:
            """
            """
            :param pathVRT: 
            :return:
            """
            assert len(self) >= 1, 'VRT needs to define at least 1 band'
    
            assert os.path.splitext(pathVRT)[-1].lower() == '.vrt'
    
            srcLookup = dict()
    
            inMemory = pathVRT.startswith('/vsimem/')
    
            if inMemory:
                dirWarped = '/vsimem/'
            else:
                dirWarped = os.path.join(os.path.splitext(pathVRT)[0] + '.WarpedImages')
    
            drvVRT = gdal.GetDriverByName('VRT')
    
            for i, pathSrc in enumerate(self.sourceRaster()):
                dsSrc = gdal.Open(pathSrc)
                assert isinstance(dsSrc, gdal.Dataset)
                band = dsSrc.GetRasterBand(1)
    
                if noData and srcNodata is None:
                    srcNodata = noData
    
    
                crs = QgsCoordinateReferenceSystem(dsSrc.GetProjection())
    
                if crs == self.mCrs:
                    srcLookup[pathSrc] = pathSrc
                else:
    
                    #do a CRS transformation using VRTs
    
                    warpedFileName = 'warped.{}.vrt'.format(os.path.basename(pathSrc))
                    if inMemory:
                        warpedFileName = dirWarped + warpedFileName
                    else:
                        os.makedirs(dirWarped, exist_ok=True)
                        warpedFileName = os.path.join(dirWarped, warpedFileName)
    
    
                    wops = gdal.WarpOptions(format='VRT',
                                            dstSRS=self.mCrs.toWkt())
    
                    tmp = gdal.Warp(warpedFileName, dsSrc, options=wops)
    
                    vrtXML = read_vsimem(warpedFileName)
                    xml = ElementTree.fromstring(vrtXML)
                    #print(vrtXML.decode('utf-8'))
    
                    if False:
                        dsTmp = gdal.Open(warpedFileName)
                        assert isinstance(dsTmp, gdal.Dataset)
                        drvVRT.Delete(warpedFileName)
                        dsTmp = gdal.Open(warpedFileName)
                        assert not isinstance(dsTmp, gdal.Dataset)
    
                    srcLookup[pathSrc] = warpedFileName
    
            srcFiles = [srcLookup[src] for src in self.sourceRaster()]
    
            #these need to be set
            ns = nl = gt = crs = eType = None
    
            extent = self.extent()
    
            srs = None
            if isinstance(self.crs(), QgsCoordinateReferenceSystem):
                srs = self.crs().toWkt()
    
            if len(srcFiles) > 0:
                # 1. build a temporary VRT that describes the spatial shifts of all input sources
                kwds = {}
                if res is None:
                    res = 'average'
                if isinstance(res, QSizeF):
                    kwds['resolution'] = 'user'
                    kwds['xRes'] = res.width()
                    kwds['yRes'] = res.height()
                else:
                    assert res in ['highest','lowest','average']
                    kwds['resolution'] = res
    
                if isinstance(extent, QgsRectangle):
                    kwds['outputBounds'] = (extent.xMinimum(), extent.yMinimum(), extent.xMaximum(), extent.yMaximum())
    
                if srs is not None:
                    kwds['outputSRS'] = srs
    
    
    
                pathInMEMVRT = '/vsimem/{}.vrt'.format(uuid.uuid4())
                vro = gdal.BuildVRTOptions(separate=True, **kwds)
                dsVRTDst = gdal.BuildVRT(pathInMEMVRT, srcFiles, options=vro)
    
                assert isinstance(dsVRTDst, gdal.Dataset)
    
                ns, nl = dsVRTDst.RasterXSize, dsVRTDst.RasterYSize
                gt = dsVRTDst.GetGeoTransform()
                crs = dsVRTDst.GetProjectionRef()
                eType = dsVRTDst.GetRasterBand(1).DataType
                SOURCE_TEMPLATES = dict()
                for i, srcFile in enumerate(srcFiles):
                    vrt_sources = dsVRTDst.GetRasterBand(i+1).GetMetadata(str('vrt_sources'))
                    assert len(vrt_sources) == 1
                    srcXML = vrt_sources['source_0']
                    assert os.path.basename(srcFile)+'</SourceFilename>' in srcXML
                    assert '<SourceBand>1</SourceBand>' in srcXML
                    SOURCE_TEMPLATES[srcFile] = srcXML
    
                drvVRT.Delete(pathInMEMVRT)
    
            else:
                # special case: no source files defined
                ns = nl = 1 #this is the minimum size
                if isinstance(extent, QgsRectangle):
                    x0 = extent.xMinimum()
                    y1 = extent.yMaximum()
                else:
                    x0 = 0
                    y1 = 0
    
                if isinstance(res, QSizeF):
                    resx = res.width()
                    resy = res.height()
                else:
                    resx = 1
                    resy = 1
    
                gt = (x0, resx, 0, y1, 0, -resy)
                eType = gdal.GDT_Float32
    
            drvVRT = gdal.GetDriverByName('VRT')
    
            assert isinstance(drvVRT, gdal.Driver)
    
            dsVRTDst = drvVRT.Create(pathVRT, ns, nl,0, eType=eType)
    
            #2.1. set general properties
            assert isinstance(dsVRTDst, gdal.Dataset)
    
    
            if srs is not None:
                dsVRTDst.SetProjection(srs)
    
            dsVRTDst.SetGeoTransform(gt)
    
            #2.2. add virtual bands
    
            for i, vBand in enumerate(self.mBands):
                assert isinstance(vBand, VRTRasterBand)
    
                assert dsVRTDst.AddBand(eType, options=['subClass=VRTSourcedRasterBand']) == 0
                vrtBandDst = dsVRTDst.GetRasterBand(i+1)
                assert isinstance(vrtBandDst, gdal.Band)
    
                vrtBandDst.SetDescription(vBand.name())
    
                md = {}
                #add all input sources for this virtual band
    
                for iSrc, sourceInfo in enumerate(vBand.mSources):
    
                    assert isinstance(sourceInfo, VRTRasterInputSourceBand)
                    bandIndex = sourceInfo.mBandIndex
                    xml = SOURCE_TEMPLATES[srcLookup[sourceInfo.mPath]]
    
                    xml = re.sub('<SourceBand>1</SourceBand>', '<SourceBand>{}</SourceBand>'.format(bandIndex+1), xml)
    
                    md['source_{}'.format(iSrc)] = xml
    
                vrtBandDst.SetMetadata(md,'vrt_sources')
    
    
    
            dsVRTDst = None
    
            #check if we get what we like to get
            dsCheck = gdal.Open(pathVRT)
    
            assert isinstance(dsCheck, gdal.Dataset)
    
    
        def __repr__(self):
            """
            Summarize this virtual raster: a header line with the number of virtual
            bands and source files, followed by one line per virtual band.
            """
            header = 'VirtualRasterBuilder: {} bands, {} source files'.format(
                len(self.mBands), len(self.sourceRaster()))
            bandLines = [str(band) for band in self.mBands]
            return '\n'.join([header] + bandLines)
    
        def __len__(self):
            """Return the number of virtual bands."""
            return len(self.mBands)
    
        def __getitem__(self, slice):
            """
            Return the virtual band(s) selected by `slice` (an index or a slice object)
            from the list of virtual bands.
            """
            return self.mBands[slice]
    
        def __delitem__(self, slice):
            """Remove the virtual band(s) selected by `slice` via removeVirtualBands()."""
            selected = self[slice]
            self.removeVirtualBands(selected)
    
        def __contains__(self, item):
            """Return True if `item` is one of the virtual bands."""
            return item in self.mBands
    
        def __iter__(self):
            """
            Iterate over the virtual bands.

            Uses self.mBands, like __repr__ and __delitem__; the previous reference to
            self.mClasses did not match the band list the other container methods use.
            """
            return iter(self.mBands)
    
    
    def createVirtualBandMosaic(bandFiles, pathVRT):
        """
        Mosaic the raster files in `bandFiles` into one VRT written to `pathVRT`.

        The sources are combined with gdal.BuildVRT(separate=False), so they keep their
        band structure. The first file is used as reference: the resulting VRT is
        expected to have as many bands as that file.

        :param bandFiles: sequence of raster paths; must not be empty
        :param pathVRT: path of the VRT file to create
        :return: the created VRT as gdal.Dataset
        :raises ValueError: if `bandFiles` is empty
        :raises IOError: if the reference file cannot be opened or the VRT cannot be built
        """
        if len(bandFiles) == 0:
            raise ValueError('bandFiles is empty')

        refPath = bandFiles[0]
        refDS = gdal.Open(refPath)
        if refDS is None:
            raise IOError('Unable to open {}'.format(refPath))
        nb = refDS.RasterCount
        refDS = None  # only the band count is needed; release the dataset handle

        vrtOptions = gdal.BuildVRTOptions(
            # here we can use the options known from https://gdal.org/programs/gdalbuildvrt.html
            separate=False
        )
        vrtDS = gdal.BuildVRT(pathVRT, bandFiles, options=vrtOptions)
        if vrtDS is None:
            raise IOError('Unable to build VRT {}'.format(pathVRT))
        vrtDS.FlushCache()

        assert vrtDS.RasterCount == nb
        return vrtDS
    
    def createVirtualBandStack(bandFiles, pathVRT):
        """
        Stack the raster files in `bandFiles` into one VRT written to `pathVRT`.

        The sources are combined with gdal.BuildVRT(separate=True), so each input file
        contributes one band, in the given order. Each VRT band's description is set
        to the path of the file it comes from.

        :param bandFiles: sequence of raster paths; must not be empty
        :param pathVRT: path of the VRT file to create
        :return: the created VRT as gdal.Dataset
        :raises ValueError: if `bandFiles` is empty
        :raises IOError: if the VRT cannot be built
        """
        if len(bandFiles) == 0:
            raise ValueError('bandFiles is empty')
        nb = len(bandFiles)

        vrtOptions = gdal.BuildVRTOptions(
            # here we can use the options known from https://gdal.org/programs/gdalbuildvrt.html
            separate=True,
        )
        vrtDS = gdal.BuildVRT(pathVRT, bandFiles, options=vrtOptions)
        if vrtDS is None:
            raise IOError('Unable to build VRT {}'.format(pathVRT))

        # separate=True yields one band per input file
        assert vrtDS.RasterCount == nb

        # label each band with the file it was taken from
        for bandNumber, bandFile in enumerate(bandFiles, start=1):
            vrtDS.GetRasterBand(bandNumber).SetDescription(bandFile)

        # write the band descriptions to pathVRT now, not only when the dataset is released
        vrtDS.FlushCache()
        return vrtDS
    
    
    class RasterBounds(object):
        """
        Footprint and CRS of a raster image readable by GDAL.

        Attributes (all None until fromImage() has been called):
            path    -- path of the image the bounds were read from
            polygon -- QgsPolygon through the image corners
            curve   -- QgsCircularString derived from the corner ring (see NOTE in fromImage)
            crs     -- QgsCoordinateReferenceSystem created from the image projection
        """

        def __init__(self, path):
            """
            :param path: raster path to read the bounds from, or None for empty bounds
            """
            self.path = None
            self.polygon = None
            self.curve = None
            self.crs = None

            if path is not None:
                self.fromImage(path)

        def fromImage(self, path):
            """
            Read footprint and CRS of the raster at `path` into this object.

            The corner pixel coordinates (0,0), (ns,0), (ns,nl) and (0,nl) are converted
            to geo-coordinates with the image geotransform via px2geo().

            :param path: path of a raster dataset that gdal.Open can read
            """
            self.path = path
            ds = gdal.Open(path)
            assert isinstance(ds, gdal.Dataset)
            gt = ds.GetGeoTransform()
            bounds = [px2geo(QPoint(0, 0), gt),
                      px2geo(QPoint(ds.RasterXSize, 0), gt),
                      px2geo(QPoint(ds.RasterXSize, ds.RasterYSize), gt),
                      px2geo(QPoint(0, ds.RasterYSize), gt)]
            crs = QgsCoordinateReferenceSystem(ds.GetProjection())
            ring = ogr.Geometry(ogr.wkbLinearRing)
            for p in bounds:
                assert isinstance(p, QgsPoint)
                ring.AddPoint(p.x(), p.y())

            # NOTE(review): a wkbLinearRing is not a geometry container, so OGR presumably
            # rejects AddGeometry() here and `curve` stays an empty ring; self.curve then
            # likely remains an empty QgsCircularString -- confirm and decide what the
            # curve is meant to represent.
            curve = ogr.Geometry(ogr.wkbLinearRing)
            curve.AddGeometry(ring)

            self.curve = QgsCircularString()
            self.curve.fromWkt(curve.ExportToWkt())

            polygon = ogr.Geometry(ogr.wkbPolygon)
            polygon.AddGeometry(ring)

            self.polygon = QgsPolygon()
            self.polygon.fromWkt(polygon.ExportToWkt())
            # the OGR ring holds only the four corner points, so close it on the QGIS side
            self.polygon.exteriorRing().close()
            assert self.polygon.exteriorRing().isClosed()

            self.crs = crs

        def __repr__(self):
            """Return the footprint polygon as WKT, or a placeholder if no image was read."""
            if self.polygon is None:
                # created with path=None: polygon was never set, avoid None.asWkt()
                return '{}(path=None)'.format(self.__class__.__name__)
            return self.polygon.asWkt()