Skip to content
Snippets Groups Projects
classificationscheme.py 58.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    
    """
    ***************************************************************************
        classificationscheme.py
    
        Methods and Objects to describe raster classifications
        ---------------------
        Date                 : Juli 2017
        Copyright            : (C) 2017 by Benjamin Jakimow
        Email                : benjamin.jakimow@geo.hu-berlin.de
    ***************************************************************************
    *                                                                         *
    *   This program is free software; you can redistribute it and/or modify  *
    *   it under the terms of the GNU General Public License as published by  *
    *   the Free Software Foundation; either version 2 of the License, or     *
    *   (at your option) any later version.                                   *
    *                                                                         *
    ***************************************************************************
    """
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    import os, json, pickle, warnings, csv, re, sys
    
    from qgis.core import *
    from qgis.gui import *
    from qgis.PyQt.QtCore import *
    from qgis.PyQt.QtGui import *
    from qgis.PyQt.QtWidgets import *
    import numpy as np
    from osgeo import gdal
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    from qps.utils import gdalDataset, nextColor, loadUIFormClass, findMapLayer, registeredMapLayers
    
    
    
    def loadClassificationUI(name):
        """
        Loads a Qt Designer form class from a *.ui file located next to this module.
        :param name: str, file name of the *.ui file relative to this module's folder
        :return: the result of qps.utils.loadUIFormClass for that file
        """
        uiFile = os.path.join(os.path.dirname(__file__), name)
        return loadUIFormClass(uiFile)

    # default colors of the 'Unclassified' class (label 0) and of the first real class
    DEFAULT_UNCLASSIFIEDCOLOR = QColor('black')
    DEFAULT_FIRST_COLOR = QColor('#a6cee3')

    # mime types used to exchange ClassificationSchemes via drag & drop / clipboard
    MIMEDATA_KEY = 'hub-classscheme'
    MIMEDATA_KEY_TEXT = 'text/plain'
    MIMEDATA_INTERNAL_IDs = 'classinfo_ids'
    
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
    def findMapLayersWithClassInfo()->list:
        """
        Returns the registered QgsMapLayers from which a ClassificationScheme can be derived:
        vector layers rendered by a QgsCategorizedSymbolRenderer and raster layers rendered
        by a QgsPalettedRasterRenderer.
        The candidate layers are taken from qps.utils.registeredMapLayers().
        :return: [list-of-QgsMapLayer]
        """
        results = []
        for lyr in registeredMapLayers():
            if isinstance(lyr, QgsVectorLayer):
                if isinstance(lyr.renderer(), QgsCategorizedSymbolRenderer):
                    results.append(lyr)
            elif isinstance(lyr, QgsRasterLayer):
                if isinstance(lyr.renderer(), QgsPalettedRasterRenderer):
                    results.append(lyr)
        return results
    
    
    
    
    def hasClassification(pathOrDataset):
        """
        Tests if a GDAL-readable raster data set contains categorical information
        (category names or a color table in any band) that can be used to retrieve
        a ClassificationScheme.
        :param pathOrDataset: str | gdal.Dataset | QgsRasterLayer
        :return: True | False
        """
        ds = None
        try:
            if isinstance(pathOrDataset, gdal.Dataset):
                ds = pathOrDataset
            elif isinstance(pathOrDataset, str):
                ds = gdal.Open(pathOrDataset)
            elif isinstance(pathOrDataset, QgsRasterLayer):
                # this branch used to test `ds` (always None at this point), so raster
                # layers were never opened
                ds = gdal.Open(pathOrDataset.source())
        except Exception:
            # best effort: a source GDAL cannot open simply has no classification
            ds = None

        if not isinstance(ds, gdal.Dataset):
            return False

        for b in range(ds.RasterCount):
            band = ds.GetRasterBand(b + 1)
            assert isinstance(band, gdal.Band)
            if band.GetCategoryNames() or band.GetColorTable():
                return True
        return False
    
    
    def getTextColorWithContrast(c:QColor)->QColor:
        """
        Returns a text color with good contrast to the background color c.
        :param c: QColor, background color
        :return: QColor, white for dark backgrounds, black for light backgrounds
        """
        assert isinstance(c, QColor)
        # lightnessF() is in [0, 1]. lightness() is in [0, 255] and was compared with 0.5,
        # which made every color except pure black count as "light".
        if c.lightnessF() < 0.5:
            return QColor('white')
        return QColor('black')
    
    
    
    class ClassInfo(QObject):
        """
        A single class of a ClassificationScheme, described by an integer label, a name and a color.
        The setters emit sigSettingsChanged.
        """
        sigSettingsChanged = pyqtSignal()

        def __init__(self, label=0, name=None, color=None, parent=None):
            """
            :param label: int, class label. Label 0 is used for 'Unclassified'.
            :param name: str, class name. Defaults to 'Unclassified' for label 0, else 'Class <label>'.
            :param color: QColor, class color. Defaults to DEFAULT_UNCLASSIFIEDCOLOR for label 0,
                          else DEFAULT_FIRST_COLOR.
            :param parent: QObject, optional Qt parent
            """
            super(ClassInfo, self).__init__(parent)

            if name is None:
                name = 'Unclassified' if label == 0 else 'Class {}'.format(label)

            if color is None:
                color = DEFAULT_UNCLASSIFIEDCOLOR if label == 0 else DEFAULT_FIRST_COLOR

            self.mName = name
            self.mLabel = label
            self.mColor = None
            # setColor() checks the type and stores a private copy, so the module-level default
            # colors can never be modified through the mColor of a ClassInfo
            self.setColor(color)

        def setLabel(self, label:int):
            """
            Sets the label value.
            :param label: int, must be >= 0
            """
            assert isinstance(label, int)
            assert label >= 0
            self.mLabel = label
            self.sigSettingsChanged.emit()

        def label(self)->int:
            """
            Returns the class label value.
            :return: int
            """
            return self.mLabel

        def color(self)->QColor:
            """
            Returns a copy of the class color.
            :return: QColor
            """
            return QColor(self.mColor)

        def name(self)->str:
            """
            Returns the class name.
            :return: str
            """
            return self.mName

        def setColor(self, color:QColor):
            """
            Sets the class color. A copy of color is stored.
            :param color: QColor
            """
            assert isinstance(color, QColor)
            self.mColor = QColor(color)
            self.sigSettingsChanged.emit()

        def setName(self, name:str):
            """
            Sets the class name.
            :param name: str
            """
            assert isinstance(name, str)
            self.mName = name
            self.sigSettingsChanged.emit()

        def pixmap(self, *args)->QPixmap:
            """
            Returns a QPixmap filled with the class color. Default size is 20x20px.
            :param args: QPixmap constructor arguments.
            :return: QPixmap
            """
            if len(args) == 0:
                args = (QSize(20, 20),)

            pm = QPixmap(*args)
            pm.fill(self.mColor)
            return pm

        def icon(self, *args)->QIcon:
            """
            Returns the class color as QIcon.
            :param args: QPixmap constructor arguments, see pixmap()
            :return: QIcon
            """
            return QIcon(self.pixmap(*args))

        def clone(self):
            """
            Creates a copy of this ClassInfo with the same label, name and color.
            :return: ClassInfo
            """
            # the label used to be dropped, so a clone never compared equal to a class with label != 0
            return ClassInfo(label=self.mLabel, name=self.mName, color=QColor(self.mColor))

        def __ne__(self, other):
            return not self.__eq__(other)

        def __eq__(self, other):
            # value equality on name, label and RGBA color
            if not isinstance(other, ClassInfo):
                return False
            return other.mName == self.mName and \
                   other.mLabel == self.mLabel and \
                   other.mColor.getRgb() == self.mColor.getRgb()

        def __repr__(self):
            return 'ClassInfo' + self.__str__()

        def __str__(self):
            return '{} "{}" ({})'.format(self.mLabel, self.mName, self.mColor.name())

        def json(self)->str:
            """
            Serializes this ClassInfo as JSON list [label, name, color name].
            Can be deserialized with ClassInfo.fromJSON().
            :return: str
            """
            return json.dumps([self.label(), self.name(), self.color().name()])

        @staticmethod
        def fromJSON(jsonString:str):
            """
            Creates a ClassInfo from a JSON string as returned by ClassInfo.json().
            Can be called on the class or on an instance.
            :param jsonString: str
            :return: ClassInfo, or None if jsonString cannot be parsed
            """
            try:
                label, name, color = json.loads(jsonString)
                return ClassInfo(label=label, name=name, color=QColor(color))
            except (TypeError, ValueError):
                return None
    
    
    class ClassificationScheme(QAbstractTableModel):
    
        # emitted with the list of ClassInfos removed by removeClasses() or clear()
        sigClassesRemoved = pyqtSignal(list)
        #sigClassRemoved = pyqtSignal(ClassInfo, int)
        #sigClassAdded = pyqtSignal(ClassInfo, int)
        # emitted with the list of ClassInfos inserted by insertClasses()
        sigClassesAdded = pyqtSignal(list)
        # emitted with the new name when setName() changes the name
        sigNameChanged = pyqtSignal(str)
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
        def __init__(self, name : str = None):
            """
            Creates an empty ClassificationScheme.
            :param name: str, name of the scheme. Defaults to 'Classification'.
            """
            super(ClassificationScheme, self).__init__()

            # the default name used to be assigned after self.mName was set and was never applied
            if name is None:
                name = 'Classification'

            self.mClasses = []
            self.mName = name
            self.mIsEditable = True

            # column titles; columnNames() orders them as label (0), name (1), color (2)
            self.mColColor = 'Color'
            self.mColName = 'Name'
            self.mColLabel = 'Label'
    
        def setIsEditable(self, b:bool):
            """
            Sets if class names and colors can be changed.
            Views are notified by dataChanged, because flags() depends on this setting.
            :param b: bool
            """
            b = bool(b)
            if b != self.mIsEditable:
                # used to assign True unconditionally, so editing could never be disabled
                self.mIsEditable = b
                if self.rowCount() > 0:
                    # dataChanged is a signal and must be emitted, not called
                    self.dataChanged.emit(self.createIndex(0, 0),
                                          self.createIndex(self.rowCount() - 1, self.columnCount() - 1))
    
        def isEditable(self)->bool:
            """
            Tells whether class names and colors may be edited.
            :return: bool, True if editing is allowed
            """
            editable = self.mIsEditable
            return editable
    
        def columnNames(self)->list:
            """
            Lists the header titles of the label, name and color column, in this order.
            :return: [list-of-str]
            """
            titles = [self.mColLabel, self.mColName, self.mColColor]
            return titles
    
        def dropMimeData(self, mimeData:QMimeData, action:Qt.DropAction, row:int, column:int, parent:QModelIndex):
            """
            Handles data dropped onto the model.
            Qt.MoveAction re-orders ClassInfos of this scheme, identified by the MIMEDATA_INTERNAL_IDs payload.
            Qt.CopyAction inserts the ClassInfos serialized under MIMEDATA_KEY.
            :param mimeData: QMimeData
            :param action: Qt.DropAction
            :param row: int, row to insert before; -1 if dropped onto an item or onto the empty view area
            :param column: int, not used
            :param parent: QModelIndex, the item dropped onto (may be invalid)
            :return: bool, True if the drop was handled
            """
            if row < 0:
                # dropped onto an item: use its row; dropped onto empty space: append
                # (row -1 used to be passed on and inserted before the last class)
                row = parent.row() if parent.isValid() else len(self.mClasses)
            row = min(max(row, 0), len(self.mClasses))

            if action == Qt.MoveAction:
                if MIMEDATA_INTERNAL_IDs in mimeData.formats():
                    # NOTE(review): pickle.loads() on drag & drop data. The payload is written by
                    # mimeData() of this model, but drops can come from other applications.
                    ids = set(pickle.loads(bytes(mimeData.data(MIMEDATA_INTERNAL_IDs))))
                    moved = [c for c in self.mClasses if id(c) in ids]
                    if len(moved) == 0:
                        return False
                    # moved classes located above the drop row shift the target position upwards
                    shift = len([c for c in self.mClasses[:row] if id(c) in ids])
                    remaining = [c for c in self.mClasses if id(c) not in ids]
                    target = row - shift
                    self.beginResetModel()
                    self.mClasses[:] = remaining[:target] + moved + remaining[target:]
                    self.endResetModel()
                    self._updateLabels()
                    return True

            elif action == Qt.CopyAction:
                if MIMEDATA_KEY in mimeData.formats():
                    cs = ClassificationScheme.fromQByteArray(mimeData.data(MIMEDATA_KEY))
                    if isinstance(cs, ClassificationScheme) and len(cs) > 0:
                        self.insertClasses(cs[:], row)
                        return True

            return False
    
        def mimeData(self, indexes)->QMimeData:
            """
            Returns the ClassInfos referred to by indexes as QMimeData:
              MIMEDATA_KEY: the classes serialized as ClassificationScheme (see qByteArray())
              MIMEDATA_INTERNAL_IDs: pickled Python ids of the ClassInfo instances, used by
                                     dropMimeData() to re-order rows
              text/plain: a human-readable dump (see toString())
            :param indexes: [list-of-QModelIndex], or None to export all classes
            :return: QMimeData
            """
            if indexes is None:
                rows = list(range(len(self)))
            else:
                # a view passes one index per selected cell, so a row can be referenced several times
                rows = sorted({idx.row() for idx in indexes})

            classes = [self[r] for r in rows]
            cs = ClassificationScheme()
            # insert copies: inserting the originals would make cs re-assign their labels
            cs.insertClasses([c.clone() for c in classes])
            mimeData = QMimeData()
            mimeData.setData(MIMEDATA_KEY, cs.qByteArray())
            mimeData.setData(MIMEDATA_INTERNAL_IDs, QByteArray(pickle.dumps([id(c) for c in classes])))
            mimeData.setText(cs.toString())
            return mimeData
    
        def mimeTypes(self)->list:
            """
            Lists the mime types this model exports and accepts.
            :return: [list-of-str]
            """
            supported = [MIMEDATA_KEY, MIMEDATA_INTERNAL_IDs, MIMEDATA_KEY_TEXT]
            return supported
    
    
        def rowCount(self, parent:QModelIndex=None):
            """
            Returns the number of rows, i.e. the number of ClassInfos.
            This is a flat table model: a valid parent index has no children.
            :param parent: QModelIndex
            :return: int
            """
            if isinstance(parent, QModelIndex) and parent.isValid():
                return 0
            return len(self.mClasses)
    
        def columnCount(self, parent: QModelIndex=None):
            """
            Returns the number of columns (label, name, color).
            This is a flat table model: a valid parent index has no children.
            :param parent: QModelIndex
            :return: int
            """
            if isinstance(parent, QModelIndex) and parent.isValid():
                return 0
            return len(self.columnNames())
    
    
        def index2ClassInfo(self, index)->ClassInfo:
            """
            Returns the ClassInfo of a row.
            :param index: QModelIndex or int row number
            :return: ClassInfo
            """
            row = index.row() if isinstance(index, QModelIndex) else index
            return self.mClasses[row]
    
        def classInfo2index(self, classInfo:ClassInfo)->QModelIndex:
            """
            Returns the model index (column 0) of the row that holds classInfo.
            Raises ValueError (from list.index) if classInfo is not part of this scheme.
            """
            return self.createIndex(self.mClasses.index(classInfo), 0)
    
    
        def data(self, index: QModelIndex, role: int = Qt.DisplayRole):
            """
            Returns the data of a cell for a role.
            Columns: 0 = label, 1 = name, 2 = color. Qt.UserRole returns the row's ClassInfo.
            :param index: QModelIndex
            :param role: Qt.ItemDataRole
            :return: the requested value or None
            """
            if not index.isValid():
                return None

            col = index.column()
            classInfo = self.index2ClassInfo(index.row())

            if role == Qt.DisplayRole:
                if col == 0:
                    return classInfo.label()
                if col == 1:
                    return classInfo.name()
                if col == 2:
                    return classInfo.color().name()

            if role == Qt.ForegroundRole:
                # the color column is filled with the class color, so its text needs a contrasting
                # color. This used to compare the column number with the title 'Color' (never true).
                if col == 2:
                    return QBrush(getTextColorWithContrast(classInfo.color()))

            # Qt.BackgroundRole replaces the obsolete alias Qt.BackgroundColorRole
            if role == Qt.BackgroundRole:
                if col == 2:
                    return QBrush(classInfo.color())

            if role == Qt.AccessibleTextRole:
                if col == 0:
                    return str(classInfo.label())
                if col == 1:
                    return classInfo.name()
                if col == 2:
                    return classInfo.color().name()

            if role == Qt.ToolTipRole:
                if col == 0:
                    return 'Class label "{}"'.format(classInfo.label())
                if col == 1:
                    return 'Class name "{}"'.format(classInfo.name())
                if col == 2:
                    return 'Class color "{}"'.format(classInfo.color().name())

            if role == Qt.EditRole:
                if col == 1:
                    return classInfo.name()
                if col == 2:
                    return classInfo.color()

            if role == Qt.UserRole:
                return classInfo

            return None
    
        def supportedDragActions(self):
            """
            Rows can only be moved by dragging.
            :return: Qt.DropActions
            """
            actions = Qt.MoveAction
            return actions
    
        def supportedDropActions(self):
            """
            Drops may move rows within the scheme or copy classes into it.
            :return: Qt.DropActions
            """
            actions = Qt.MoveAction | Qt.CopyAction
            return actions
    
        def setData(self, index: QModelIndex, value, role: int = Qt.EditRole):
            """
            Changes the name (column 1) or the color (column 2) of a ClassInfo.
            :param index: QModelIndex
            :param value: str for the name column, QColor for the color column
            :param role: Qt.ItemDataRole, only Qt.EditRole is handled
            :return: bool, True if the data was changed
            """
            if not index.isValid():
                return False

            col = index.column()
            classInfo = self.index2ClassInfo(index.row())
            changed = False
            if role == Qt.EditRole:
                if col == 1:
                    classInfo.setName(value)
                    changed = True
                elif col == 2:
                    classInfo.setColor(value)
                    changed = True
            if changed:
                self.dataChanged.emit(index, index, [role])
            # used to return False unconditionally, which tells views the edit failed
            return changed
    
        def flags(self, index: QModelIndex):
            """
            Returns the item flags. Items are always selectable and enabled; if the scheme is
            editable they can be dragged/dropped and the name column can be edited.
            """
            if not index.isValid():
                return Qt.NoItemFlags

            itemFlags = Qt.ItemIsSelectable | Qt.ItemIsEnabled
            if not self.mIsEditable:
                return itemFlags

            itemFlags |= Qt.ItemIsDragEnabled | Qt.ItemIsDropEnabled
            if index.column() == 1:
                itemFlags |= Qt.ItemIsEditable
            return itemFlags
    
    
        def headerData(self, section: int, orientation: Qt.Orientation, role: int = Qt.DisplayRole):
            """
            Returns the column names as horizontal header labels and delegates everything
            else to QAbstractTableModel.headerData().
            """
            if orientation == Qt.Horizontal and role == Qt.DisplayRole:
                return self.columnNames()[section]
            return super(ClassificationScheme, self).headerData(section, orientation, role)
    
    
        def setName(self, name:str='')->str:
            """
            Renames this ClassificationScheme and emits sigNameChanged if the name differs.
            :param name: str
            :return: str, the new name
            """
            isNewName = self.mName != name
            self.mName = name
            if isNewName:
                self.sigNameChanged.emit(self.mName)
            return self.mName
    
        def name(self)->str:
            """
            Returns the name of this ClassificationScheme.
            :return: str
            """
            currentName = self.mName
            return currentName
    
        def json(self)->str:
            """
            Serializes this ClassificationScheme as JSON object {'name': ..., 'classes': [[label, name, color], ...]}.
            Can be deserialized with ClassificationScheme.fromJson().
            :return: str, JSON string
            """
            classRecords = []
            for classInfo in self:
                classRecords.append((classInfo.label(), classInfo.name(), classInfo.color().name()))
            return json.dumps({'name': self.mName, 'classes': classRecords})
    
        def pickle(self)->bytes:
            """
            Serializes this ClassificationScheme as pickled JSON string (see json()).
            Can be deserialized with ClassificationScheme.fromPickle().
            :return: bytes
            """
            jsonText = self.json()
            return pickle.dumps(jsonText)
    
        def qByteArray(self)->QByteArray:
            """
            Serializes this ClassificationScheme as QByteArray that wraps the bytes of pickle().
            Can be deserialized with ClassificationScheme.fromQByteArray().
            :return: QByteArray
            """
            payload = self.pickle()
            return QByteArray(payload)
    
        @staticmethod
        def fromQByteArray(array:QByteArray):
            """
            Deserializes a ClassificationScheme from a QByteArray created by qByteArray().
            :param array: QByteArray
            :return: ClassificationScheme or None
            """
            payload = bytes(array)
            return ClassificationScheme.fromPickle(payload)
    
        @staticmethod
        def fromPickle(pkl:bytes):
            """
            Deserializes a ClassificationScheme from bytes created by pickle(), i.e. a pickled JSON string.
            The payload can come from drag & drop or the clipboard. An unrestricted pickle.loads()
            could execute arbitrary code, so unpickling any global (class or function) is refused.
            The data format is unchanged.
            :param pkl: bytes
            :return: ClassificationScheme or None (see fromJson())
            :raises pickle.UnpicklingError: if the payload is no valid pickle or references a global
            """
            import io

            class _NoGlobalsUnpickler(pickle.Unpickler):
                # a pickled str never needs find_class(); anything that does is rejected
                def find_class(self, module, name):
                    raise pickle.UnpicklingError('global "{}.{}" is forbidden'.format(module, name))

            jsonStr = _NoGlobalsUnpickler(io.BytesIO(pkl)).load()
            return ClassificationScheme.fromJson(jsonStr)
    
    
        @staticmethod
        def fromFile(p:str):
            """
            Reads a ClassificationScheme from a file.
            Only JSON files (extension .json, case-insensitive), e.g. as written by saveToJson(),
            are supported.
            :param p: str or os.PathLike, file path
            :return: ClassificationScheme, or None if the file does not exist, has an unsupported
                     extension or cannot be read/parsed (errors are printed to stderr)
            """
            if isinstance(p, os.PathLike):
                p = os.fspath(p)
            if not isinstance(p, str) or not os.path.isfile(p):
                return None
            if not p.lower().endswith('.json'):
                return None
            try:
                # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding
                with open(p, 'r', encoding='utf-8') as f:
                    jsonStr = f.read()
            except (OSError, UnicodeDecodeError) as ex:
                print(ex, file=sys.stderr)
                return None
            return ClassificationScheme.fromJson(jsonStr)
    
        @staticmethod
        def fromJson(jsonStr:str):
            """
            Creates a ClassificationScheme from a JSON string as returned by ClassificationScheme.json().
            :param jsonStr: str
            :return: ClassificationScheme, or None if jsonStr cannot be parsed (the error is printed to stderr)
            """
            try:
                data = json.loads(jsonStr)
                scheme = ClassificationScheme(name=data['name'])
                scheme.insertClasses([ClassInfo(label=label, name=name, color=QColor(colorName))
                                      for label, name, colorName in data['classes']])
                return scheme
            except Exception as ex:
                print(ex, file=sys.stderr)
                return None
    
    
        def rasterRenderer(self, band=0)->QgsPalettedRasterRenderer:
            """
            Returns the ClassificationScheme as QgsPalettedRasterRenderer without input interface.
            Each ClassInfo becomes a QgsPalettedRasterRenderer.Class(label, color, name).
            :param band: int, band number handed to the renderer
                         NOTE(review): QGIS raster band numbers usually start at 1 — confirm the default 0
            :return: QgsPalettedRasterRenderer
            """
            paletteClasses = [QgsPalettedRasterRenderer.Class(ci.label(), ci.color(), ci.name())
                              for ci in self]
            return QgsPalettedRasterRenderer(None, band, paletteClasses)
    
        @staticmethod
        def fromRasterRenderer(renderer:QgsRasterRenderer):
            """
            Extracts a ClassificationScheme from a QgsPalettedRasterRenderer.
            :param renderer: QgsRasterRenderer
            :return: ClassificationScheme, or None if renderer is not a QgsPalettedRasterRenderer
            """
            if not isinstance(renderer, QgsPalettedRasterRenderer):
                return None

            extracted = [ClassInfo(label=paletteClass.value,
                                   name=paletteClass.label,
                                   color=QColor(paletteClass.color))
                         for paletteClass in renderer.classes()]
            scheme = ClassificationScheme()
            scheme.insertClasses(extracted)
            return scheme
    
        def featureRenderer(self)->QgsCategorizedSymbolRenderer:
            """
            Returns the ClassificationScheme as QgsCategorizedSymbolRenderer: one category per
            ClassInfo (value = label, label = name) with a QgsMarkerSymbol in the class color.
            The renderer's attribute is the placeholder 'dummy'.
            :return: QgsCategorizedSymbolRenderer
            """
            renderer = QgsCategorizedSymbolRenderer('dummy', [])
            for classInfo in self:
                assert isinstance(classInfo, ClassInfo)
                markerSymbol = QgsMarkerSymbol()
                markerSymbol.setColor(QColor(classInfo.color()))
                category = QgsRendererCategory(classInfo.label(), markerSymbol, classInfo.name(), render=True)
                renderer.addCategory(category)
            return renderer
    
    
        @staticmethod
        def fromFeatureRenderer(renderer:QgsCategorizedSymbolRenderer):
            """
            Extracts a ClassificationScheme from a QgsCategorizedSymbolRenderer.
            Categories are ordered by their value; the class labels are then assigned by position.
            :param renderer: QgsCategorizedSymbolRenderer
            :return: ClassificationScheme, or None if renderer is not a QgsCategorizedSymbolRenderer
            """
            if not isinstance(renderer, QgsCategorizedSymbolRenderer):
                return None

            ordered = sorted(renderer.categories(), key=lambda category: category.value())
            extracted = []
            for category in ordered:
                assert isinstance(category, QgsRendererCategory)
                extracted.append(ClassInfo(name=category.label(),
                                           color=QColor(category.symbol().color())))
            scheme = ClassificationScheme()
            scheme.insertClasses(extracted)
            return scheme
    
    
        def clear(self):
            """
            Removes all ClassInfos and emits sigClassesRemoved with the removed ClassInfos.
            """
            removed = self.mClasses[:]
            if len(removed) > 0:
                # rows are removed: beginRemoveRows must pair with endRemoveRows
                # (this used to call beginRemoveColumns)
                self.beginRemoveRows(QModelIndex(), 0, len(removed) - 1)
                del self.mClasses[:]
                self.endRemoveRows()
            self.sigClassesRemoved.emit(removed)
    
    
        def clone(self):
            """
            Same as copy(): returns a copy of this ClassificationScheme.
            """
            duplicate = self.copy()
            return duplicate
    
        def copy(self):
            """
            Creates a copy of this ClassificationScheme, including its name.
            The ClassInfos are cloned, so changes to the copy do not affect this scheme.
            :return: ClassificationScheme
            """
            # the name used to be lost
            cs = ClassificationScheme(name=self.name())
            cs.insertClasses([c.clone() for c in self.mClasses], 0)
            return cs
    
        def __getitem__(self, slice):
            # an int returns a single ClassInfo, a slice returns a list of ClassInfos
            classes = self.mClasses
            return classes[slice]
    
        def __delitem__(self, slice):
            # removes the ClassInfo(s) selected by an int or a slice via removeClasses()
            self.removeClasses(self[slice])
    
        def __contains__(self, item):
            # membership uses ClassInfo.__eq__, i.e. value equality
            classes = self.mClasses
            return item in classes
    
        def __len__(self):
            # number of ClassInfos
            classes = self.mClasses
            return len(classes)
    
        def __iter__(self):
            # iterates the ClassInfos in row order
            classes = self.mClasses
            return iter(classes)
    
        def __ne__(self, other):
            # negation of __eq__
            isEqual = self.__eq__(other)
            return not isEqual
    
        def __eq__(self, other):
            # two schemes are equal if they contain equal ClassInfos in the same order (names are ignored)
            if not isinstance(other, ClassificationScheme):
                return False
            if len(self) != len(other):
                return False
            return all(mine == theirs for mine, theirs in zip(self, other))
    
        def __str__(self):
            # default repr followed directly by the number of classes
            return '{}{} classes'.format(self.__repr__(), len(self))
    
    
        def range(self):
            """
            Returns the (min, max) range of the class labels.
            Raises ValueError for a scheme without classes (min() of an empty list).
            """
            labels = self.classLabels()
            lowest = min(labels)
            highest = max(labels)
            return lowest, highest
    
        def classNames(self):
            """
            Returns the names of all classes in row order.
            :return: [list-of-class-names (str)]
            """
            return list(map(lambda classInfo: classInfo.name(), self.mClasses))
    
        def classColors(self):
            """
            Returns copies of all class colors in row order.
            :return: [list-of-class-colors (QColor)]
            """
            colors = []
            for classInfo in self.mClasses:
                colors.append(QColor(classInfo.color()))
            return colors
    
        def classLabels(self)->list:
            """
            Returns the class labels in row order; _updateLabels() keeps them at [0,...,n-1].
            :return: [list-of-int]
            """
            return list(map(lambda classInfo: classInfo.label(), self.mClasses))
    
        def classColorArray(self)->np.ndarray:
            """
            Returns the RGBA class colors as integer array, one row per class.
            :return: numpy.ndarray of shape (nClasses, 4); shape (0, 4) for an empty scheme
            """
            rgba = [c.color().getRgb() for c in self]
            # reshape keeps the documented 2D shape when there are no classes
            # (np.asarray([]) would have shape (0,))
            return np.asarray(rgba, dtype=int).reshape((len(rgba), 4))
    
        def gdalColorTable(self)->gdal.ColorTable:
            """
            Returns the class colors as GDAL color table; the entry index is the row index.
            :return: gdal.ColorTable
            """
            colorTable = gdal.ColorTable()
            for entry, classInfo in enumerate(self):
                assert isinstance(classInfo, ClassInfo)
                colorTable.SetColorEntry(entry, classInfo.mColor.getRgb())
            return colorTable
    
        def _updateLabels(self):
            """
            Assigns class labels according to the ClassInfo position (0, 1, ..., n-1) and notifies
            views about the changed label column.
            Labels are set directly, so ClassInfo.sigSettingsChanged is not emitted.
            """
            for i, c in enumerate(self.mClasses):
                c.mLabel = i
            if len(self.mClasses) > 0:
                # for an empty scheme createIndex(-1, 0) would be an invalid index
                self.dataChanged.emit(self.createIndex(0, 0),
                                      self.createIndex(self.rowCount() - 1, 0),
                                      [Qt.DisplayRole, Qt.ToolTipRole, Qt.AccessibleTextRole])
    
        def removeClasses(self, classes):
            """
            Removes a single ClassInfo or a list of ClassInfos, re-assigns the labels of the
            remaining classes and emits sigClassesRemoved.
            :param classes: ClassInfo or [list-of-ClassInfo-to-remove]
            :returns: [list-of-removed-ClassInfo], in descending row order
            """
            if isinstance(classes, ClassInfo):
                classes = [classes]
            assert isinstance(classes, list)

            rows = set()
            for c in classes:
                assert c in self.mClasses
                rows.add(self.mClasses.index(c))

            removedClasses = []
            # remove bottom-up so the remaining row numbers stay valid; the set avoids
            # removing a neighbouring class when a ClassInfo is listed twice
            for row in sorted(rows, reverse=True):
                self.beginRemoveRows(QModelIndex(), row, row)
                removedClasses.append(self.mClasses.pop(row))
                self.endRemoveRows()
            self._updateLabels()
            self.sigClassesRemoved.emit(removedClasses)
            return removedClasses
    
    
        def createClasses(self, n:int):
            """
            Appends n new classes with default names and colors.
            A class created at row 0 becomes the black 'Unclassified' class. All others are named
            'Class <row>' and get successive colors (nextColor), starting after the color of the
            current last class, or with DEFAULT_FIRST_COLOR for an empty scheme.
            :param n: int, number of classes to add (>= 0)
            """
            assert isinstance(n, int)
            assert n >= 0

            nextCol = nextColor(self[-1].color()) if len(self) > 0 else DEFAULT_FIRST_COLOR
            firstRow = len(self)
            newClasses = []
            for row in range(firstRow, firstRow + n):
                if row == 0:
                    newClasses.append(ClassInfo(name='Unclassified', color=QColor('black')))
                    continue
                color = QColor(nextCol)
                nextCol = nextColor(nextCol)
                newClasses.append(ClassInfo(name='Class {}'.format(row), color=color))
            self.insertClasses(newClasses)
    
        def addClasses(self, classes, index=None):
            """
            Deprecated, use insertClasses().
            """
            # stacklevel=2 attributes the warning to the caller instead of this wrapper
            warnings.warn('use insertClasses()', DeprecationWarning, stacklevel=2)
            self.insertClasses(classes, index=index)
    
        def insertClasses(self, classes, index=None):
            """
            Inserts a single ClassInfo or a list of ClassInfos as contiguous rows, re-assigns all
            labels according to the new row order and emits sigClassesAdded.
            :param classes: ClassInfo or [list-of-ClassInfo]
            :param index: int, row of the first new class. Defaults to len(ClassificationScheme).
                          Negative values insert at the beginning, values > len() append.
            """
            if isinstance(classes, ClassInfo):
                # used to be `[ClassInfo]`, i.e. the class object instead of the instance
                classes = [classes]

            assert isinstance(classes, list)
            if len(classes) == 0:
                return

            knownIds = {id(c) for c in self.mClasses}
            for c in classes:
                assert isinstance(c, ClassInfo)
                assert id(c) not in knownIds, 'You cannot add the same ClassInfo instance to a ClassificationScheme twice. Create a copy first.'

            if index is None:
                # default: add new classes to end of list
                index = len(self.mClasses)
            # clamp to [0, len] so beginInsertRows() receives a valid row range
            index = min(max(index, 0), len(self.mClasses))

            self.beginInsertRows(QModelIndex(), index, index + len(classes) - 1)
            # insert as one contiguous block; the former loop accumulated offsets
            # (index = index + i) and scattered more than two inserted classes
            self.mClasses[index:index] = classes
            self.endInsertRows()
            self._updateLabels()
            self.sigClassesAdded.emit(classes)
    
    
        #sigClassInfoChanged = pyqtSignal(ClassInfo)
        #def onClassInfoSettingChanged(self, *args):
        #    self.sigClassInfoChanged.emit(self.sender())
    
        def classIndexFromValue(self, value, matchSimilarity=False)->int:
            """
            Returns the row index of the ClassInfo that matches a value.
            1. Identity: numbers (Python or numpy scalars) are truncated to int and used as row
               index / class label if in range; strings are compared with the class names.
            2. If nothing matched and matchSimilarity is True, strings are compared with the
               class names ignoring case and leading/trailing whitespace.
            :param value: any
            :param matchSimilarity: bool
            :return: int, row index, or -1 if no class matches
            """
            classNames = self.classNames()
            i = -1

            # 1. match on identity
            if isinstance(value, (int, float, np.integer, np.floating)):
                candidate = int(value)
                # out-of-range numbers used to be returned unchecked: classFromValue() then raised
                # IndexError, or returned a class counted from the end for negative values
                if 0 <= candidate < len(classNames):
                    i = candidate
            elif isinstance(value, str):
                if value in classNames:
                    i = classNames.index(value)

            # 2. not found? match on similarity (the former branch repeated the exact match)
            if i == -1 and matchSimilarity and isinstance(value, str):
                key = value.strip().lower()
                for j, className in enumerate(classNames):
                    if isinstance(className, str) and className.strip().lower() == key:
                        i = j
                        break
            return i
    
        def classFromValue(self, value, matchSimilarity=False)->ClassInfo:
            """
            Returns the ClassInfo that matches value (see classIndexFromValue()), or None.
            """
            row = self.classIndexFromValue(value, matchSimilarity=matchSimilarity)
            if row == -1:
                return None
            return self[row]
    
        def addClass(self, c, index=None):
            """
            Deprecated, use insertClass().
            """
            warnings.warn('Use insert class', DeprecationWarning, stacklevel=2)
            # the deprecated wrapper used to only warn and never insert the class
            self.insertClass(c, index=index)
    
    
        def insertClass(self, c, index=None):
            """
            Inserts one ClassInfo (see insertClasses()).
            :param c: ClassInfo
            :param index: int, row to insert the ClassInfo at. Defaults to the end.
            """
            assert isinstance(c, ClassInfo)
            single = [c]
            self.insertClasses(single, index=index)
    
    
        def saveToRasterBand(self, band:gdal.Band):
            """
            Saves the ClassificationScheme to a gdal.Band.
            ClassInfo names are stored with gdal.Band.SetCategoryNames() and colors as gdal.ColorTable;
            the row index of each ClassInfo is used as color table entry.
            :param band: gdal.Band
            """
            assert isinstance(band, gdal.Band)
            ct = gdal.ColorTable()
            cat = []
            for i, classInfo in enumerate(self.mClasses):
                c = classInfo.mColor
                assert isinstance(c, QColor)
                cat.append(classInfo.mName)
                # SetColorEntry takes the entry index and ONE (r, g, b, a) tuple;
                # unpacking the tuple into separate arguments raised a TypeError
                ct.SetColorEntry(i, (c.red(), c.green(), c.blue(), c.alpha()))

            band.SetColorTable(ct)
            band.SetCategoryNames(cat)
    
    
        def saveToRaster(self, path, bandIndex=0):
            """
            Saves this ClassificationScheme to a raster image (see saveToRasterBand()).
            :param path: path (str) of raster image or gdal.Dataset instance
            :param bandIndex: 0-based index of the raster band that gets this ClassificationScheme.
                              Defaults to 0 = the first band
            """
            if isinstance(path, str):
                ds = gdal.Open(path)
            elif isinstance(path, gdal.Dataset):
                ds = path
            else:
                # used to leave ds unbound, raising a NameError instead of a clear assertion
                ds = None

            assert isinstance(ds, gdal.Dataset), 'Unable to open raster {}'.format(path)
            # the band must exist; this check used to be inverted and always failed for valid bands
            assert 0 <= bandIndex < ds.RasterCount
            band = ds.GetRasterBand(bandIndex + 1)
            self.saveToRasterBand(band)
            ds.FlushCache()

            ds = None
    
        def toString(self, sep=';')->str:
            """
            Dumps all ClassInfos as text: a title line, a header line and one
            'label<sep>name<sep>color' line per class.
            :param sep: value separator, ';' by default
            :return: str
            """
            rows = ['ClassificationScheme("{}")'.format(self.name()),
                    sep.join(['label', 'name', 'color'])]
            rows.extend(sep.join('{}'.format(v) for v in (ci.label(), ci.name(), ci.color().name()))
                        for ci in self.mClasses)
            return '\n'.join(rows)
    
        def saveToCsv(self, path:str, sep:str=';', mode:str = None)->str:
            """
            Saves the ClassificationScheme as UTF-8 encoded CSV table (see toString()).
            :param path: str, path of CSV file
            :param sep: separator (';' by default)
            :param mode: nothing is written unless mode is None
            :returns: the path of the written file, or None if nothing was written
            """
            if mode is not None:
                return None
            # explicit encoding: class names may contain non-ASCII characters and the
            # platform's default encoding varies
            with open(path, 'w', encoding='utf-8') as f:
                f.write(self.toString(sep=sep))
            return path
    
    
        def saveToJson(self, path:str, mode:str=None)->str:
            """
            Writes the ClassificationScheme as JSON file (see json()).
            :param path: str, path of JSON file
            :param mode: nothing is written unless mode is None
            :return: path of the written file, or None if nothing was written
            """
            if mode is not None:
                return None
            with open(path, 'w') as f:
                f.write(self.json())
            return path
    
    
        @staticmethod
        def create(n):
            """
            Creates a ClassificationScheme with n default classes (see createClasses());
            the first one is 'Unclassified' with label 0.
            :param n: number of classes including 'Unclassified'
            :return: ClassificationScheme
            """
            scheme = ClassificationScheme()
            scheme.createClasses(n)
            return scheme
    
        @staticmethod
        def fromMimeData(mimeData:QMimeData):
            """
            Reads a ClassificationScheme from QMimeData.
            Data stored under MIMEDATA_KEY (see mimeData()/qByteArray()) is preferred; otherwise
            the plain text is parsed as JSON string as returned by json().
            :param mimeData: QMimeData
            :return: ClassificationScheme or None
            """
            if not isinstance(mimeData, QMimeData):
                return None

            if MIMEDATA_KEY in mimeData.formats():
                try:
                    # NOTE(review): this unpickles data that may come from other applications
                    cs = ClassificationScheme.fromQByteArray(mimeData.data(MIMEDATA_KEY))
                except Exception:
                    # malformed payload: try the text representation instead of raising
                    cs = None
                if isinstance(cs, ClassificationScheme):
                    return cs

            if MIMEDATA_KEY_TEXT in mimeData.formats():
                # plain text is no pickle stream: passing it to pickle.loads() raised for ordinary
                # text and would execute code for crafted text
                cs = ClassificationScheme.fromJson(mimeData.text())
                if isinstance(cs, ClassificationScheme):
                    return cs

            return None
    
    
    Benjamin Jakimow's avatar
    Benjamin Jakimow committed
        @staticmethod
        def fromMapLayer(layer:QgsMapLayer):
            """
    
            :param layer:
            :return:
            """
            scheme = None
            if isinstance(layer, QgsRasterLayer):
                scheme = ClassificationScheme.fromRasterRenderer(layer.renderer())
                if not isinstance(scheme, ClassificationScheme):