Skip to content
Snippets Groups Projects
classificationscheme.py 58.2 KiB
Newer Older
# -*- coding: utf-8 -*-

"""
***************************************************************************
    classificationscheme.py

    Methods and Objects to describe raster classifications
    ---------------------
    Date                 : Juli 2017
    Copyright            : (C) 2017 by Benjamin Jakimow
    Email                : benjamin.jakimow@geo.hu-berlin.de
***************************************************************************
*                                                                         *
*   This program is free software; you can redistribute it and/or modify  *
*   it under the terms of the GNU General Public License as published by  *
*   the Free Software Foundation; either version 2 of the License, or     *
*   (at your option) any later version.                                   *
*                                                                         *
***************************************************************************
"""

Benjamin Jakimow's avatar
Benjamin Jakimow committed
import os, json, pickle, warnings, csv, re, sys
from qgis.core import *
from qgis.gui import *
from qgis.PyQt.QtCore import *
from qgis.PyQt.QtGui import *
from qgis.PyQt.QtWidgets import *
import numpy as np
from osgeo import gdal
Benjamin Jakimow's avatar
Benjamin Jakimow committed
from qps.utils import gdalDataset, nextColor, loadUIFormClass, findMapLayer, registeredMapLayers


def loadClassificationUI(name: str):
    """
    Loads a Qt Designer form class from a *.ui file located in the directory of this module.
    :param name: str, file name of the *.ui file, relative to this module's directory
    :return: the object returned by qps.utils.loadUIFormClass
    """
    return loadUIFormClass(os.path.join(os.path.dirname(__file__), name))


# default color of the 'Unclassified' class (label 0)
DEFAULT_UNCLASSIFIEDCOLOR = QColor('black')
# default color of classes with a label != 0 and start color of ClassificationScheme.createClasses()
DEFAULT_FIRST_COLOR = QColor('#a6cee3')

# MIME formats used by ClassificationScheme.mimeData() / dropMimeData() / fromMimeData()
MIMEDATA_KEY = 'hub-classscheme'
MIMEDATA_KEY_TEXT = 'text/plain'
MIMEDATA_INTERNAL_IDs = 'classinfo_ids'


Benjamin Jakimow's avatar
Benjamin Jakimow committed
def findMapLayersWithClassInfo() -> list:
    """
    Returns the registered QgsMapLayers from which a ClassificationScheme can be derived:
    vector layers rendered with a QgsCategorizedSymbolRenderer and raster layers rendered
    with a QgsPalettedRasterRenderer. Candidate layers come from qps.utils.registeredMapLayers().
    :return: [list-of-QgsMapLayer]
    """
    return [layer for layer in registeredMapLayers()
            if (isinstance(layer, QgsVectorLayer)
                and isinstance(layer.renderer(), QgsCategorizedSymbolRenderer))
            or (isinstance(layer, QgsRasterLayer)
                and isinstance(layer.renderer(), QgsPalettedRasterRenderer))]




def hasClassification(pathOrDataset) -> bool:
    """
    Tests if a GDAL-readable raster data set contains categorical information
    (category names or a color table in any band) that can be used to retrieve a ClassificationScheme.
    :param pathOrDataset: str | gdal.Dataset | QgsRasterLayer
    :return: True | False
    """
    ds = None
    try:
        if isinstance(pathOrDataset, gdal.Dataset):
            ds = pathOrDataset
        elif isinstance(pathOrDataset, str):
            ds = gdal.Open(pathOrDataset)
        elif isinstance(pathOrDataset, QgsRasterLayer):
            # the layer source is handed to GDAL; sources GDAL cannot open lead to False
            ds = gdal.Open(pathOrDataset.source())
    except Exception:
        # best effort: anything GDAL fails to open counts as "no classification"
        ds = None

    if not isinstance(ds, gdal.Dataset):
        return False

    # GDAL band numbers are 1-based
    for bandNumber in range(1, ds.RasterCount + 1):
        band = ds.GetRasterBand(bandNumber)
        assert isinstance(band, gdal.Band)
        if band.GetCategoryNames() or band.GetColorTable():
            return True
    return False


def getTextColorWithContrast(c: QColor) -> QColor:
    """
    Returns a text color with good contrast to the background color c:
    white for dark colors (HSL lightness below 50 %), black otherwise.
    :param c: QColor
    :return: QColor
    """
    assert isinstance(c, QColor)
    # lightnessF() is in [0.0, 1.0]; lightness() would be an int in [0, 255]
    if c.lightnessF() < 0.5:
        return QColor('white')
    return QColor('black')



class ClassInfo(QObject):
    """
    A single class of a classification, described by a label (int), a name (str) and a color (QColor).
    sigSettingsChanged is emitted whenever setLabel(), setName() or setColor() is called.
    """
    sigSettingsChanged = pyqtSignal()

    def __init__(self, label=0, name=None, color=None, parent=None):
        """
        :param label: int, class label. Label 0 gets the 'Unclassified' defaults
        :param name: str, class name. Defaults to 'Unclassified' (label 0) or 'Class <label>'
        :param color: QColor, class color. Defaults to DEFAULT_UNCLASSIFIEDCOLOR (label 0)
                      or DEFAULT_FIRST_COLOR (any other label)
        :param parent: QObject, optional parent
        """
        super(ClassInfo, self).__init__(parent)

        if name is None:
            name = 'Unclassified' if label == 0 else 'Class {}'.format(label)

        if color is None:
            color = DEFAULT_UNCLASSIFIEDCOLOR if label == 0 else DEFAULT_FIRST_COLOR

        self.mName = name
        self.mLabel = label
        # setColor() validates the color and stores a copy, so module-level defaults are never shared
        self.setColor(color)

    def setLabel(self, label: int):
        """
        Sets the label value.
        :param label: int, must be >= 0
        """
        assert isinstance(label, int)
        assert label >= 0
        self.mLabel = label
        self.sigSettingsChanged.emit()

    def label(self) -> int:
        """
        Returns the class label value.
        :return: int
        """
        return self.mLabel

    def color(self) -> QColor:
        """
        Returns a copy of the class color.
        :return: QColor
        """
        return QColor(self.mColor)

    def name(self) -> str:
        """
        Returns the class name.
        :return: str
        """
        return self.mName

    def setColor(self, color: QColor):
        """
        Sets the class color.
        :param color: QColor
        """
        assert isinstance(color, QColor)
        # store a copy: the caller's QColor object may be modified later
        self.mColor = QColor(color)
        self.sigSettingsChanged.emit()

    def setName(self, name: str):
        """
        Sets the class name.
        :param name: str
        """
        assert isinstance(name, str)
        self.mName = name
        self.sigSettingsChanged.emit()

    def pixmap(self, *args) -> QPixmap:
        """
        Returns a QPixmap filled with the class color. Default size is 20x20px.
        :param args: QPixmap constructor arguments
        :return: QPixmap
        """
        if len(args) == 0:
            args = (QSize(20, 20),)

        pm = QPixmap(*args)
        pm.fill(self.mColor)
        return pm

    def icon(self, *args) -> QIcon:
        """
        Returns the class color as QIcon.
        :param args: QPixmap constructor arguments, see pixmap()
        :return: QIcon
        """
        return QIcon(self.pixmap(*args))

    def clone(self):
        """
        Creates a copy of this ClassInfo with the same label, name and color.
        :return: ClassInfo
        """
        return ClassInfo(label=self.mLabel, name=self.mName, color=QColor(self.mColor))

    def __ne__(self, other):
        return not self.__eq__(other)

    # NOTE: defining __eq__ without __hash__ makes ClassInfo instances unhashable
    def __eq__(self, other):
        if not isinstance(other, ClassInfo):
            return False
        return other.mName == self.mName and \
               other.mLabel == self.mLabel and \
               other.mColor.getRgb() == self.mColor.getRgb()

    def __repr__(self):
        return 'ClassInfo' + self.__str__()

    def __str__(self):
        return '{} "{}" ({})'.format(self.mLabel, self.mName, self.mColor.name())

    def json(self) -> str:
        """
        Serializes this ClassInfo as JSON list [label, name, color name].
        :return: str
        """
        return json.dumps([self.label(), self.name(), self.color().name()])

    @staticmethod
    def fromJSON(jsonString: str):
        """
        Creates a ClassInfo from a JSON string as returned by ClassInfo.json().
        Can be called on the class as well as on an instance.
        :param jsonString: str, JSON list [label, name, color name]
        :return: ClassInfo or None, if jsonString cannot be parsed
        """
        try:
            label, name, color = json.loads(jsonString)
            return ClassInfo(label=label, name=name, color=QColor(color))
        except (ValueError, TypeError):
            # invalid JSON, wrong number of values or values of unsupported types
            return None


class ClassificationScheme(QAbstractTableModel):

    sigClassesRemoved = pyqtSignal(list)
    #sigClassRemoved = pyqtSignal(ClassInfo, int)
    #sigClassAdded = pyqtSignal(ClassInfo, int)
    sigClassesAdded = pyqtSignal(list)
    sigNameChanged = pyqtSignal(str)

Benjamin Jakimow's avatar
Benjamin Jakimow committed
    def __init__(self, name: str = None):
        """
        Creates an empty ClassificationScheme.
        :param name: str, name of the scheme. Defaults to 'Classification'
        """
        super(ClassificationScheme, self).__init__()

        # apply the default before the name is stored
        if name is None:
            name = 'Classification'

        self.mClasses = []
        self.mName = name
        self.mIsEditable = True

        # column headers, see columnNames()
        self.mColColor = 'Color'
        self.mColName = 'Name'
        self.mColLabel = 'Label'

    def setIsEditable(self, b: bool):
        """
        Sets if class names and colors can be changed (see flags()).
        :param b: bool
        """
        b = bool(b)
        if b != self.mIsEditable:
            self.mIsEditable = b
            n = len(self.mClasses)
            if n > 0:
                # notify views that all cells may have changed
                self.dataChanged.emit(self.createIndex(0, 0),
                                      self.createIndex(n - 1, self.columnCount() - 1),
                                      [])

    def isEditable(self) -> bool:
        """
        Tells whether class names and colors may be modified.
        :return: bool, the editable state
        """
        editable = self.mIsEditable
        return editable

    def columnNames(self) -> list:
        """
        Lists the column headers in column order: label, name, color.
        :return: [list-of-str]
        """
        headers = (self.mColLabel, self.mColName, self.mColColor)
        return list(headers)

    def dropMimeData(self, mimeData: QMimeData, action: Qt.DropAction, row: int, column: int, parent: QModelIndex):
        """
        Handles data dropped onto this model.
        Qt.MoveAction: reorders ClassInfos of this scheme that are identified by the ids stored
                       under MIMEDATA_INTERNAL_IDs (see mimeData()).
        Qt.CopyAction: inserts the ClassInfos serialized under MIMEDATA_KEY.
        :param mimeData: QMimeData
        :param action: Qt.DropAction
        :param row: int, row to insert before. -1 = use the parent row or append
        :param column: int, ignored
        :param parent: QModelIndex
        :return: bool, True if the model was changed
        """
        if row == -1:
            row = parent.row()
        if row < 0:
            # dropped onto the empty area below the last row: append
            row = len(self.mClasses)

        if action == Qt.MoveAction:
            if MIMEDATA_INTERNAL_IDs in mimeData.formats():
                # NOTE(review): pickle.loads() on drag & drop data - data from other applications
                # is not trusted; consider a non-executable encoding such as JSON
                ids = set(pickle.loads(bytes(mimeData.data(MIMEDATA_INTERNAL_IDs))))

                movedRows = [i for i, c in enumerate(self.mClasses) if id(c) in ids]
                if len(movedRows) == 0:
                    return False
                moved = [self.mClasses[i] for i in movedRows]
                # each moved row above the drop position shifts the target position up by one
                targetRow = row - sum(1 for i in movedRows if i < row)

                self.beginResetModel()
                remaining = [c for c in self.mClasses if id(c) not in ids]
                remaining[targetRow:targetRow] = moved
                self.mClasses[:] = remaining
                self.endResetModel()
                self._updateLabels()
                return True

        elif action == Qt.CopyAction:
            if MIMEDATA_KEY in mimeData.formats():
                cs = ClassificationScheme.fromQByteArray(mimeData.data(MIMEDATA_KEY))
                if isinstance(cs, ClassificationScheme) and len(cs) > 0:
                    self.insertClasses(cs[:], row)
                    return True

        return False

    def mimeData(self, indexes) -> QMimeData:
        """
        Returns the ClassInfos of the given indexes as QMimeData.
        Each row is exported once, in row order, even if several cells of a row are selected.
        :param indexes: [list-of-QModelIndex] or None for all classes
        :return: QMimeData with MIMEDATA_KEY, MIMEDATA_INTERNAL_IDs and plain text (toString())
        """
        if indexes is None:
            indexes = [self.createIndex(r, 0) for r in range(len(self))]

        # views pass one index per selected cell: reduce them to unique rows
        rows = sorted({idx.row() for idx in indexes if idx.isValid()})
        classes = [self[r] for r in rows]

        # serialize clones: inserting the originals into another scheme would renumber their labels
        cs = ClassificationScheme()
        cs.insertClasses([c.clone() for c in classes])

        mimeData = QMimeData()
        mimeData.setData(MIMEDATA_KEY, cs.qByteArray())
        # ids of the original instances let dropMimeData() find rows moved within this model
        mimeData.setData(MIMEDATA_INTERNAL_IDs, QByteArray(pickle.dumps([id(c) for c in classes])))
        mimeData.setText(cs.toString())
        return mimeData

    def mimeTypes(self) -> list:
        """
        Names the MIME formats this model can export and import.
        :return: [list-of-str]: scheme payload, internal ids, plain text
        """
        supported = [MIMEDATA_KEY, MIMEDATA_INTERNAL_IDs, MIMEDATA_KEY_TEXT]
        return supported


    def rowCount(self, parent: QModelIndex = None):
        """
        Returns the number of rows, i.e. of ClassInfos.
        As a table model, items have no children: a valid parent has 0 rows.
        :param parent: QModelIndex
        :return: int
        """
        if isinstance(parent, QModelIndex) and parent.isValid():
            return 0
        return len(self.mClasses)

    def columnCount(self, parent: QModelIndex = None):
        """
        Returns the number of columns (label, name, color).
        As a table model, items have no children: a valid parent has 0 columns.
        :param parent: QModelIndex
        :return: int
        """
        if isinstance(parent, QModelIndex) and parent.isValid():
            return 0
        return len(self.columnNames())


    def index2ClassInfo(self, index) -> ClassInfo:
        """
        Resolves a row number or a QModelIndex (its row is used) to the ClassInfo in that row.
        :param index: int | QModelIndex
        :return: ClassInfo
        """
        row = index.row() if isinstance(index, QModelIndex) else index
        return self.mClasses[row]

    def classInfo2index(self, classInfo: ClassInfo) -> QModelIndex:
        """
        Returns the column-0 QModelIndex of the first ClassInfo that equals classInfo.
        Raises ValueError if there is no such ClassInfo (list.index()).
        :param classInfo: ClassInfo
        :return: QModelIndex
        """
        return self.createIndex(self.mClasses.index(classInfo), 0)


    def data(self, index: QModelIndex, role: int = Qt.DisplayRole):
        """
        Returns the data of the ClassInfo shown in the row of index.
        Columns: 0 = label, 1 = name, 2 = color.
        :param index: QModelIndex
        :param role: Qt.ItemDataRole
        :return: the data for the role, or None
        """
        if not index.isValid():
            return None

        col = index.column()
        classInfo = self.index2ClassInfo(index.row())

        if role == Qt.DisplayRole:
            if col == 0:
                return classInfo.label()
            if col == 1:
                return classInfo.name()
            if col == 2:
                return classInfo.color().name()

        if role == Qt.ForegroundRole:
            # the color column is filled with the class color (BackgroundRole),
            # so its text needs a contrasting color
            if col == 2:
                return QBrush(getTextColorWithContrast(classInfo.color()))

        if role == Qt.BackgroundRole:
            if col == 2:
                return QBrush(classInfo.color())

        if role == Qt.AccessibleTextRole:
            if col == 0:
                return str(classInfo.label())
            if col == 1:
                return classInfo.name()
            if col == 2:
                return classInfo.color().name()

        if role == Qt.ToolTipRole:
            if col == 0:
                return 'Class label "{}"'.format(classInfo.label())
            if col == 1:
                return 'Class name "{}"'.format(classInfo.name())
            if col == 2:
                return 'Class color "{}"'.format(classInfo.color().name())

        if role == Qt.EditRole:
            if col == 1:
                return classInfo.name()
            if col == 2:
                return classInfo.color()

        if role == Qt.UserRole:
            return classInfo

        return None

    def supportedDragActions(self):
        """
        Rows can only be dragged in order to move them.
        :return: Qt.DropActions
        """
        dragActions = Qt.MoveAction
        return dragActions

    def supportedDropActions(self):
        """
        Drops may move rows within the model or copy serialized classes into it.
        :return: Qt.DropActions
        """
        dropActions = Qt.MoveAction | Qt.CopyAction
        return dropActions

    def setData(self, index: QModelIndex, value, role: int):
        """
        Changes the name (column 1) or color (column 2) of a ClassInfo.
        :param index: QModelIndex
        :param value: str for the name, QColor for the color
        :param role: only Qt.EditRole is handled
        :return: bool, True if the data was changed
        """
        if not index.isValid():
            return False

        col = index.column()
        classInfo = self.index2ClassInfo(index.row())
        changed = False
        if role == Qt.EditRole:
            if col == 1:
                classInfo.setName(value)
                changed = True
            elif col == 2:
                classInfo.setColor(value)
                changed = True
        if changed:
            # name / color are shown under several roles (display, tooltip, background, ...):
            # an empty role list signals that all roles may have changed
            self.dataChanged.emit(index, index, [])
        return changed

    def flags(self, index: QModelIndex):
        """
        Valid items are always selectable and enabled. If the scheme is editable, they
        additionally support drag & drop, and the name column (1) becomes editable.
        :param index: QModelIndex
        :return: Qt.ItemFlags
        """
        if not index.isValid():
            return Qt.NoItemFlags

        itemFlags = Qt.ItemIsSelectable | Qt.ItemIsEnabled
        if not self.mIsEditable:
            return itemFlags

        itemFlags = itemFlags | Qt.ItemIsDragEnabled | Qt.ItemIsDropEnabled
        if index.column() == 1:
            itemFlags = itemFlags | Qt.ItemIsEditable
        return itemFlags


    def headerData(self, section: int, orientation: Qt.Orientation, role: int = Qt.DisplayRole):
        """
        Horizontal header labels come from columnNames(); everything else is
        delegated to QAbstractTableModel.
        """
        if orientation == Qt.Horizontal and role == Qt.DisplayRole:
            headers = self.columnNames()
            return headers[section]
        return super(ClassificationScheme, self).headerData(section, orientation, role)


    def setName(self, name: str = '') -> str:
        """
        Renames this ClassificationScheme. sigNameChanged is emitted only if the name differs.
        :param name: str
        :return: str, the (new) name
        """
        hasChanged = name != self.mName
        self.mName = name
        if not hasChanged:
            return self.mName
        self.sigNameChanged.emit(self.mName)
        return self.mName

    def name(self) -> str:
        """
        Returns the name of this ClassificationScheme.
        :return: str
        """
        schemeName = self.mName
        return schemeName

    def json(self) -> str:
        """
        Serializes this ClassificationScheme as JSON object
        {"name": <name>, "classes": [[label, name, color name], ...]}.
        Use ClassificationScheme.fromJson() to deserialize it.
        :return: str, JSON string
        """
        classes = []
        for classInfo in self:
            classes.append((classInfo.label(), classInfo.name(), classInfo.color().name()))
        return json.dumps({'name': self.mName, 'classes': classes})

    def pickle(self) -> bytes:
        """
        Serializes this ClassificationScheme as bytes (the pickled json() string).
        Use ClassificationScheme.fromPickle() to deserialize it.
        :return: bytes
        """
        jsonStr = self.json()
        return pickle.dumps(jsonStr)

    def qByteArray(self) -> QByteArray:
        """
        Wraps the pickle() bytes of this ClassificationScheme into a QByteArray.
        Use ClassificationScheme.fromQByteArray() to deserialize it.
        :return: QByteArray
        """
        payload = self.pickle()
        return QByteArray(payload)

    @staticmethod
    def fromQByteArray(array: QByteArray):
        """
        Counterpart of qByteArray(): converts the QByteArray to bytes and calls fromPickle().
        :param array: QByteArray
        :return: ClassificationScheme or None
        """
        raw = bytes(array)
        return ClassificationScheme.fromPickle(raw)

    @staticmethod
    def fromPickle(pkl: bytes):
        """
        Deserializes a ClassificationScheme from bytes created with ClassificationScheme.pickle(),
        i.e. a pickled JSON string.
        These bytes may come from drag & drop or the clipboard, so unpickling is restricted:
        payloads that reference any global (class or function) are rejected. This blocks the
        arbitrary code execution an unrestricted pickle.loads() would allow.
        :param pkl: bytes
        :return: ClassificationScheme or None, if the JSON cannot be parsed (see fromJson())
        :raises pickle.UnpicklingError: if the payload references globals; malformed data may
                raise other unpickling errors, as before
        """
        import io

        class _NoGlobalsUnpickler(pickle.Unpickler):
            # a pickled str never needs find_class(), so every lookup is refused
            def find_class(self, module, name):
                raise pickle.UnpicklingError(
                    'forbidden global "{}.{}" in ClassificationScheme data'.format(module, name))

        jsonStr = _NoGlobalsUnpickler(io.BytesIO(pkl)).load()
        return ClassificationScheme.fromJson(jsonStr)


    @staticmethod
    def fromFile(p: str):
        """
        Reads a ClassificationScheme from a file. Only JSON files (*.json, e.g. written by
        saveToJson()) are supported.
        :param p: str | os.PathLike, file path
        :return: ClassificationScheme or None, if the file does not exist, has an unsupported
                 extension or cannot be read
        """
        try:
            p = os.fspath(p)
            if os.path.isfile(p) and p.lower().endswith('.json'):
                # JSON text is UTF-8; 'utf-8-sig' additionally accepts a leading BOM
                with open(p, 'r', encoding='utf-8-sig') as f:
                    return ClassificationScheme.fromJson(f.read())
        except Exception as ex:
            print(ex, file=sys.stderr)
        return None

    @staticmethod
    def fromJson(jsonStr: str):
        """
        Deserializes a ClassificationScheme from a JSON string as returned by json():
        {"name": <name>, "classes": [[label, name, color name], ...]}.
        A missing "name" falls back to the default scheme name.
        :param jsonStr: str
        :return: ClassificationScheme or None, if jsonStr cannot be parsed (the error is printed to stderr)
        """
        try:
            data = json.loads(jsonStr)
            cs = ClassificationScheme(name=data.get('name'))
            classes = [ClassInfo(label=label, name=name, color=QColor(colorName))
                       for label, name, colorName in data['classes']]
            cs.insertClasses(classes)
            return cs
        except Exception as ex:
            # best-effort deserialization boundary: report and signal failure with None
            print(ex, file=sys.stderr)
            return None


    def rasterRenderer(self, band=0) -> QgsPalettedRasterRenderer:
        """
        Converts the ClassificationScheme into a QgsPalettedRasterRenderer without input interface.
        Each ClassInfo becomes a palette class (label value, color, name).
        :param band: band number the renderer reads from.
                     NOTE(review): QGIS band numbers usually start at 1 - confirm the default 0
        :return: QgsPalettedRasterRenderer
        """
        paletteClasses = [QgsPalettedRasterRenderer.Class(classInfo.label(),
                                                          classInfo.color(),
                                                          classInfo.name())
                          for classInfo in self]
        return QgsPalettedRasterRenderer(None, band, paletteClasses)

    @staticmethod
    def fromRasterRenderer(renderer: QgsRasterRenderer):
        """
        Builds a ClassificationScheme from the classes of a QgsPalettedRasterRenderer.
        :param renderer: QgsRasterRenderer
        :return: ClassificationScheme, or None for renderers that are not paletted
        """
        if isinstance(renderer, QgsPalettedRasterRenderer):
            classInfos = [ClassInfo(label=paletteClass.value,
                                    name=paletteClass.label,
                                    color=QColor(paletteClass.color))
                          for paletteClass in renderer.classes()]
            scheme = ClassificationScheme()
            scheme.insertClasses(classInfos)
            return scheme
        return None

    def featureRenderer(self) -> QgsCategorizedSymbolRenderer:
        """
        Converts the ClassificationScheme into a QgsCategorizedSymbolRenderer with one
        marker-symbol category per ClassInfo (value = label, legend text = name).
        The classification attribute is set to the placeholder 'dummy'.
        :return: QgsCategorizedSymbolRenderer
        """
        renderer = QgsCategorizedSymbolRenderer('dummy', [])
        for classInfo in self:
            assert isinstance(classInfo, ClassInfo)
            markerSymbol = QgsMarkerSymbol()
            markerSymbol.setColor(QColor(classInfo.color()))
            category = QgsRendererCategory(classInfo.label(), markerSymbol, classInfo.name(), render=True)
            renderer.addCategory(category)
        return renderer


    @staticmethod
    def fromFeatureRenderer(renderer: QgsCategorizedSymbolRenderer):
        """
        Extracts a ClassificationScheme from a QgsCategorizedSymbolRenderer.
        Categories are ordered by their value if the values are mutually comparable,
        otherwise the renderer's category order is kept. Labels are assigned by position.
        :param renderer: QgsCategorizedSymbolRenderer
        :return: ClassificationScheme or None, if renderer is not a QgsCategorizedSymbolRenderer
        """
        if not isinstance(renderer, QgsCategorizedSymbolRenderer):
            return None

        categories = list(renderer.categories())
        try:
            categories = sorted(categories, key=lambda cat: cat.value())
        except TypeError:
            # values of different types (e.g. numbers and strings) cannot be ordered in Python 3
            pass

        classes = []
        for cat in categories:
            assert isinstance(cat, QgsRendererCategory)
            classes.append(ClassInfo(name=cat.label(), color=QColor(cat.symbol().color())))
        cs = ClassificationScheme()
        cs.insertClasses(classes)
        return cs


    def clear(self):
        """
        Removes all ClassInfos and emits sigClassesRemoved with the removed ClassInfos.
        Does nothing if the scheme is already empty.
        """
        n = len(self.mClasses)
        if n == 0:
            return
        # rows are removed, so begin/endRemoveRows must be used as a pair
        self.beginRemoveRows(QModelIndex(), 0, n - 1)
        removed = self.mClasses[:]
        del self.mClasses[:]
        self.endRemoveRows()
        self.sigClassesRemoved.emit(removed)


    def clone(self):
        """
        Alias of copy().
        :return: ClassificationScheme
        """
        duplicate = self.copy()
        return duplicate

    def copy(self):
        """
        Creates a copy of this ClassificationScheme: same name and editable state,
        with clones of all ClassInfos.
        :return: ClassificationScheme
        """
        cs = ClassificationScheme(name=self.mName)
        cs.mIsEditable = self.mIsEditable
        cs.insertClasses([c.clone() for c in self.mClasses], 0)
        return cs

    def __getitem__(self, slice):
        """Returns the ClassInfo at an int position, or a list of ClassInfos for a slice."""
        classes = self.mClasses
        return classes[slice]

    def __delitem__(self, slice):
        """Removes the ClassInfo(s) addressed by an int position or a slice, see removeClasses()."""
        toRemove = self[slice]
        self.removeClasses(toRemove)

    def __contains__(self, item):
        """True if item is (or equals) one of the ClassInfos of this scheme."""
        found = item in self.mClasses
        return found

    def __len__(self):
        """Number of ClassInfos."""
        classes = self.mClasses
        return len(classes)

    def __iter__(self):
        """Iterates over the ClassInfos in row order."""
        classes = self.mClasses
        return iter(classes)

    def __ne__(self, other):
        """Negation of __eq__()."""
        isEqual = self.__eq__(other)
        return not isEqual

    def __eq__(self, other):
        """
        Two ClassificationSchemes are equal if they contain equal ClassInfos in the same order.
        The scheme names are not compared.
        """
        if not isinstance(other, ClassificationScheme):
            return False
        if len(self) != len(other):
            return False
        return all(self[row] == other[row] for row in range(len(self)))

    def __str__(self):
        """Default object representation followed by the number of classes."""
        # separate the repr from the class count, e.g. '<... object at 0x...> 3 classes'
        return '{} {} classes'.format(self.__repr__(), len(self))


    def range(self):
        """
        Returns the smallest and largest class label as tuple (min, max).
        Raises ValueError for an empty scheme (min()/max() of an empty list).
        """
        labels = self.classLabels()
        lowest = min(labels)
        highest = max(labels)
        return lowest, highest

    def classNames(self):
        """
        Collects the names of all ClassInfos in row order.
        :return: [list-of-str]
        """
        names = []
        for classInfo in self.mClasses:
            names.append(classInfo.name())
        return names

    def classColors(self):
        """
        Collects copies of all class colors in row order.
        :return: [list-of-QColor]
        """
        colors = []
        for classInfo in self.mClasses:
            colors.append(QColor(classInfo.color()))
        return colors

    def classLabels(self) -> list:
        """
        Collects the labels of all ClassInfos in row order. Labels are re-assigned
        to positions [0,...,n-1] whenever classes are inserted or removed.
        :return: [list-of-int]
        """
        labels = []
        for classInfo in self.mClasses:
            labels.append(classInfo.label())
        return labels

    def classColorArray(self) -> np.ndarray:
        """
        Returns the RGBA class colors as integer array.
        :return: numpy.ndarray of shape (nClasses, 4); (0, 4) for an empty scheme
        """
        rgba = [c.color().getRgb() for c in self]
        # reshape keeps the documented 2D shape even if there are no classes
        return np.asarray(rgba, dtype=int).reshape((-1, 4))

    def gdalColorTable(self) -> gdal.ColorTable:
        """
        Converts the class colors into a GDAL color table; entry i holds the RGBA color of row i.
        :return: gdal.ColorTable
        """
        colorTable = gdal.ColorTable()
        for entry, classInfo in enumerate(self.mClasses):
            assert isinstance(classInfo, ClassInfo)
            colorTable.SetColorEntry(entry, classInfo.mColor.getRgb())
        return colorTable

    def _updateLabels(self):
        """
        Assigns class labels according to the ClassInfo position (0, 1, ..., n-1)
        and notifies views that the label column changed.
        """
        for i, c in enumerate(self.mClasses):
            # set the attribute directly, i.e. without emitting ClassInfo.sigSettingsChanged
            c.mLabel = i
        n = len(self.mClasses)
        if n > 0:
            self.dataChanged.emit(self.createIndex(0, 0),
                                  self.createIndex(n - 1, 0),
                                  [Qt.DisplayRole, Qt.ToolTipRole, Qt.AccessibleTextRole])

    def removeClasses(self, classes):
        """
        Removes a single ClassInfo or a list of ClassInfos and emits sigClassesRemoved.
        :param classes: ClassInfo or [list-of-ClassInfo-to-remove]; each must be part of this scheme
        :returns: [list-of-removed-ClassInfo], in descending row order
        """
        if isinstance(classes, ClassInfo):
            classes = [classes]
        if isinstance(classes, tuple):
            classes = list(classes)
        assert isinstance(classes, list)

        # a set: the same ClassInfo listed twice must not remove a second row
        rows = set()
        for c in classes:
            assert c in self.mClasses
            rows.add(self.mClasses.index(c))

        removedClasses = []
        for row in sorted(rows, reverse=True):
            self.beginRemoveRows(QModelIndex(), row, row)
            removedClasses.append(self.mClasses.pop(row))
            self.endRemoveRows()
        self._updateLabels()
        self.sigClassesRemoved.emit(removedClasses)
        return removedClasses

    def createClasses(self, n: int):
        """
        Appends n default classes, e.g. to populate an empty ClassificationScheme.
        The class at position 0 becomes a black 'Unclassified' class; all others are
        named 'Class <position>' and get successive colors from qps.utils.nextColor().
        :param n: int, number of classes to add (>= 0)
        """
        assert isinstance(n, int)
        assert n >= 0

        nextCol = nextColor(self[-1].color()) if len(self) > 0 else DEFAULT_FIRST_COLOR
        start = len(self)
        newClasses = []
        for position in range(start, start + n):
            if position == 0:
                newClasses.append(ClassInfo(name='Unclassified', color=QColor('black')))
                continue
            color = QColor(nextCol)
            nextCol = nextColor(nextCol)
            newClasses.append(ClassInfo(name='Class {}'.format(position), color=color))
        self.insertClasses(newClasses)

    def addClasses(self, classes, index=None):
        """
        Deprecated: forwards to insertClasses() after emitting a DeprecationWarning.
        """
        warnings.warn('use insertClasses()', DeprecationWarning)
        self.insertClasses(classes, index=index)

    def insertClasses(self, classes, index=None):
        """
        Inserts a single ClassInfo or a list of ClassInfos and emits sigClassesAdded.
        Labels of all classes are re-assigned by position afterwards.
        :param classes: ClassInfo or [list-of-ClassInfo]
        :param index: int, position of the first new class. Defaults to len(ClassificationScheme).
                      Negative values insert at the beginning, values beyond the end append.
        """
        if isinstance(classes, ClassInfo):
            classes = [classes]

        assert isinstance(classes, list)
        if len(classes) == 0:
            return

        # ids of instances already in the scheme (or earlier in `classes`): one instance must not be added twice
        knownIds = {id(c) for c in self.mClasses}
        for c in classes:
            assert isinstance(c, ClassInfo)
            assert id(c) not in knownIds, 'You cannot add the same ClassInfo instance to a ClassificationScheme twice. Create a copy first.'
            knownIds.add(id(c))

        if index is None:
            # default: add new classes to end of list
            index = len(self.mClasses)
        # clamp into [0, len] so that beginInsertRows() gets a valid row range
        index = min(max(index, 0), len(self.mClasses))

        self.beginInsertRows(QModelIndex(), index, index + len(classes) - 1)
        # insert the new classes as contiguous block starting at `index`
        self.mClasses[index:index] = classes
        self.endInsertRows()
        self._updateLabels()
        self.sigClassesAdded.emit(classes)


    #sigClassInfoChanged = pyqtSignal(ClassInfo)
    #def onClassInfoSettingChanged(self, *args):
    #    self.sigClassInfoChanged.emit(self.sender())

    def classIndexFromValue(self, value, matchSimilarity=False) -> int:
        """
        Returns the index of the ClassInfo that matches value best.
        1. Identity: numbers (incl. numpy scalars) are truncated to int and used as class
           index / label if they are within range; strings are matched against class names.
        2. Similarity (only if matchSimilarity is True): strings are compared with class names
           case-insensitively, ignoring surrounding whitespace.
        :param value: any
        :param matchSimilarity: bool
        :return: int, the class index or -1 if no class matches
        """
        classNames = self.classNames()
        i = -1

        # 1. match on identity
        if isinstance(value, (int, float, np.integer, np.floating)):
            try:
                candidate = int(value)
            except (ValueError, OverflowError):
                # NaN / infinity
                candidate = -1
            if 0 <= candidate < len(self.mClasses):
                i = candidate

        elif isinstance(value, str):
            if value in classNames:
                i = classNames.index(value)

        # 2. not found? match on similarity
        if i == -1 and matchSimilarity:
            if isinstance(value, str):
                key = value.strip().casefold()
                for j, className in enumerate(classNames):
                    if isinstance(className, str) and className.strip().casefold() == key:
                        i = j
                        break
        return i

    def classFromValue(self, value, matchSimilarity=False) -> ClassInfo:
        """
        Looks up the ClassInfo matching value, see classIndexFromValue().
        :return: ClassInfo or None if nothing matches
        """
        position = self.classIndexFromValue(value, matchSimilarity=matchSimilarity)
        return None if position == -1 else self[position]

    def addClass(self, c, index=None):
        """
        Deprecated, use insertClass().
        Adds the ClassInfo c at position index (defaults to the end).
        :param c: ClassInfo
        :param index: int, insert position
        """
        warnings.warn('Use insert class', DeprecationWarning)
        self.insertClass(c, index=index)


    def insertClass(self, c, index=None):
        """
        Inserts one ClassInfo; a convenience wrapper around insertClasses().
        :param c: ClassInfo
        :param index: int, insert position. Defaults to the end.
        """
        assert isinstance(c, ClassInfo)
        single = [c]
        self.insertClasses(single, index=index)


    def saveToRasterBand(self, band: gdal.Band):
        """
        Writes the ClassificationScheme into a gdal.Band: class names become the band's
        category names, class colors its color table (entry i = row i).
        :param band: gdal.Band
        """
        assert isinstance(band, gdal.Band)
        colorTable = gdal.ColorTable()
        categoryNames = []
        for entry, classInfo in enumerate(self.mClasses):
            color = classInfo.mColor
            categoryNames.append(classInfo.mName)
            assert isinstance(color, QColor)
            colorTable.SetColorEntry(entry, (color.red(), color.green(), color.blue(), color.alpha()))

        band.SetColorTable(colorTable)
        band.SetCategoryNames(categoryNames)


    def saveToRaster(self, path, bandIndex=0):
        """
        Saves this ClassificationScheme (class names and colors) to a band of a raster image.
        :param path: path (str or os.PathLike) of a raster image, or a gdal.Dataset instance.
                     Paths are opened in update mode.
        :param bandIndex: 0-based index of the raster band to write to.
                          Defaults to 0 = the first band
        :raises TypeError: if path is neither a path nor a gdal.Dataset
        :raises OSError: if the raster cannot be opened in update mode
        :raises ValueError: if bandIndex does not refer to an existing band
        """
        if isinstance(path, gdal.Dataset):
            ds = path
        elif isinstance(path, (str, os.PathLike)):
            # color tables and category names can only be written to datasets opened for update
            ds = gdal.Open(os.fspath(path), gdal.GA_Update)
            if not isinstance(ds, gdal.Dataset):
                raise OSError('Unable to open "{}" in update mode'.format(path))
        else:
            raise TypeError('path must be a str, os.PathLike or gdal.Dataset, not {}'.format(type(path)))

        if not 0 <= bandIndex < ds.RasterCount:
            raise ValueError('bandIndex {} out of range [0, {})'.format(bandIndex, ds.RasterCount))

        band = ds.GetRasterBand(bandIndex + 1)
        self.saveToRasterBand(band)
        ds.FlushCache()

        # drop the references so that a dataset opened here gets closed
        band = None
        ds = None

    def toString(self, sep=';') -> str:
        """
        Dumps the scheme as text: a 'ClassificationScheme("<name>")' line, a header line
        'label<sep>name<sep>color' and one line per ClassInfo.
        :param sep: value separator, ';' by default
        :return: str
        """
        rows = ['ClassificationScheme("{}")'.format(self.name()),
                sep.join(['label', 'name', 'color'])]
        for classInfo in self.mClasses:
            values = (classInfo.label(), classInfo.name(), classInfo.color().name())
            rows.append(sep.join('{}'.format(v) for v in values))
        return '\n'.join(rows)

    def saveToCsv(self, path: str, sep: str = ';', mode: str = None) -> str:
        """
        Saves the ClassificationScheme as text table in the format of toString().
        :param path: str, path of the CSV file
        :param sep: value separator (';' by default)
        :param mode: only None is handled; for any other value nothing is written
        :returns: the path of the written file, or None if nothing was written
        """
        if mode is not None:
            return None
        # class names may contain non-ASCII characters: do not depend on the locale encoding
        with open(path, 'w', encoding='utf-8') as f:
            f.write(self.toString(sep=sep))
        return path


    def saveToJson(self, path: str, mode: str = None) -> str:
        """
        Saves the ClassificationScheme as JSON file (see json()), readable with fromFile().
        :param path: str, path of the JSON file
        :param mode: only None is handled; for any other value nothing is written
        :return: path of the written file, or None if nothing was written
        """
        if mode is not None:
            return None
        with open(path, 'w', encoding='utf-8') as f:
            f.write(self.json())
        return path


    @staticmethod
    def create(n):
        """
        Factory for a new ClassificationScheme with n default classes; the first one is
        'Unclassified' with label 0 (see createClasses()).
        :param n: number of classes including 'Unclassified'
        :return: ClassificationScheme
        """
        scheme = ClassificationScheme()
        scheme.createClasses(n)
        return scheme

    @staticmethod
    def fromMimeData(mimeData: QMimeData):
        """
        Reads a ClassificationScheme from QMimeData.
        The MIMEDATA_KEY payload (see mimeData()) is preferred. Otherwise, plain text in the
        format written by toString() is parsed.
        :param mimeData: QMimeData
        :return: ClassificationScheme or None
        """
        if not isinstance(mimeData, QMimeData):
            return None

        if MIMEDATA_KEY in mimeData.formats():
            # NOTE(review): this payload is unpickled by fromPickle(); data from other
            # applications is untrusted, so that unpickling must stay restricted
            try:
                cs = ClassificationScheme.fromQByteArray(mimeData.data(MIMEDATA_KEY))
            except Exception as ex:
                # malformed or foreign payload: fall back to the text representation
                print(ex, file=sys.stderr)
                cs = None
            if isinstance(cs, ClassificationScheme):
                return cs

        if MIMEDATA_KEY_TEXT in mimeData.formats():
            # plain text comes from toString(), it is not a pickle - never unpickle clipboard text
            cs = ClassificationScheme._fromToStringText(mimeData.text())
            if isinstance(cs, ClassificationScheme):
                return cs

        return None

    @staticmethod
    def _fromToStringText(text: str, sep: str = ';'):
        """
        Parses text as written by ClassificationScheme.toString(sep).
        :param text: str
        :param sep: value separator used by toString()
        :return: ClassificationScheme or None, if text does not follow the toString() format
        """
        if not isinstance(text, str):
            return None
        lines = [line for line in text.splitlines() if line.strip() != '']
        if len(lines) < 2:
            return None
        match = re.match(r'^ClassificationScheme\("(.*)"\)$', lines[0].strip())
        if match is None or lines[1].strip() != sep.join(['label', 'name', 'color']):
            return None

        classes = []
        for line in lines[2:]:
            values = line.split(sep)
            if len(values) < 3:
                return None
            # names may contain the separator: the label is the first, the color the last value
            label = values[0].strip()
            name = sep.join(values[1:-1])
            color = QColor(values[-1].strip())
            if not label.isdecimal() or not color.isValid():
                return None
            classes.append(ClassInfo(label=int(label), name=name, color=color))

        cs = ClassificationScheme(name=match.group(1))
        cs.insertClasses(classes)
        return cs

Benjamin Jakimow's avatar
Benjamin Jakimow committed
    @staticmethod
    def fromMapLayer(layer:QgsMapLayer):
        """

        :param layer:
        :return:
        """
        scheme = None
        if isinstance(layer, QgsRasterLayer):
            scheme = ClassificationScheme.fromRasterRenderer(layer.renderer())
            if not isinstance(scheme, ClassificationScheme):