From 26005a91e26bd76a07ce9eec0f5c94d9c79c6d41 Mon Sep 17 00:00:00 2001
From: "benjamin.jakimow" <benjamin.jakimow@geo.hu-berlin.de>
Date: Tue, 6 Aug 2019 18:03:30 +0200
Subject: [PATCH] EOTSV reads wavelength information from SPOT DIMAP format and
 RapidEye images

Signed-off-by: benjamin.jakimow <benjamin.jakimow@geo.hu-berlin.de>
---
 eotimeseriesviewer/timeseries.py | 269 +++++++++++++++++++------------
 make/createtestdata.py           |   2 +-
 tests/test_timeseries.py         |  93 +++++++----
 3 files changed, 226 insertions(+), 138 deletions(-)

diff --git a/eotimeseriesviewer/timeseries.py b/eotimeseriesviewer/timeseries.py
index 53a667aa..f318a53b 100644
--- a/eotimeseriesviewer/timeseries.py
+++ b/eotimeseriesviewer/timeseries.py
@@ -23,7 +23,7 @@
 import sys, re, collections, traceback, time, json, urllib, types, enum, typing, pickle, json, uuid
 from xml.etree import ElementTree
 
-
+from qgis.PyQt.QtXml import QDomDocument
 import bisect
 
 from qgis import *
@@ -35,14 +35,14 @@ from qgis.PyQt.QtCore import *
 
 
 LUT_WAVELENGTH_UNITS = {}
-for siUnit in ['nm', r'μm', 'mm', 'cm', 'dm']:
+for siUnit in [r'nm', r'μm', r'mm', r'cm', r'dm']:
     LUT_WAVELENGTH_UNITS[siUnit] = siUnit
-LUT_WAVELENGTH_UNITS['nanometers'] = 'nm'
-LUT_WAVELENGTH_UNITS['micrometers'] = 'μm'
-LUT_WAVELENGTH_UNITS['um'] = 'μm'
-LUT_WAVELENGTH_UNITS['millimeters'] = 'mm'
-LUT_WAVELENGTH_UNITS['centimeters'] = 'cm'
-LUT_WAVELENGTH_UNITS['decimeters'] = 'dm'
+LUT_WAVELENGTH_UNITS[r'nanometers'] = r'nm'
+LUT_WAVELENGTH_UNITS[r'micrometers'] = r'μm'
+LUT_WAVELENGTH_UNITS[r'um'] = r'μm'
+LUT_WAVELENGTH_UNITS[r'millimeters'] = r'mm'
+LUT_WAVELENGTH_UNITS[r'centimeters'] = r'cm'
+LUT_WAVELENGTH_UNITS[r'decimeters'] = r'dm'
 
 
 from osgeo import gdal
@@ -1909,125 +1909,180 @@ def getSpatialPropertiesFromDataset(ds):
 
     return nb, nl, ns, crs, px_x, px_y
 
-
-def extractWavelengths(ds):
+def extractWavelengthsFromGDALMetaData(ds:gdal.Dataset)->(list, str):
     """
-    Returns the wavelength and wavelength units
+    Reads the wavelength info from standard metadata strings
     :param ds: gdal.Dataset
-    :return: ([list-of-wavelength floats], stt of wavelength unit)
+    :return: (list, str)
     """
-    wl = None
-    wlu = None
 
-    # see http://www.harrisgeospatial.com/docs/ENVIHeaderFiles.html for supported wavelength units
-    regWLkey = re.compile('.*wavelength[_ ]*$', re.I)
-    regWLUkey = re.compile('.*wavelength[_ ]*units?$', re.I)
+    regWLkey = re.compile('^(center )?wavelength[_ ]*$', re.I)
+    regWLUkey = re.compile('^wavelength[_ ]*units?$', re.I)
     regNumeric = re.compile(r"([-+]?\d*\.\d+|[-+]?\d+)", re.I)
-    regWLU = re.compile('((micro|nano|centi)meters)|(um|nm|mm|cm|m|GHz|MHz)', re.I)
 
-    if isinstance(ds, QgsRasterLayer):
-        lyr = ds
-        md = [l.split('=') for l in str(lyr.metadata()).splitlines() if 'wavelength' in l.lower()]
-        #see http://www.harrisgeospatial.com/docs/ENVIHeaderFiles.html for supported wavelength units
-        for kv in md:
-            key, value = kv
-            key = key.lower()
-            if key == 'center wavelength':
-                tmp = re.findall(r'\d*\.\d+|\d+', value) #find floats
-                if len(tmp) == 0:
-                    tmp = re.findall(r'\d+', value) #find integers
-                if len(tmp) == lyr.bandCount():
-                    wl = [float(w) for w in tmp]
-
-            if key == 'wavelength units':
-                match = regWLU.search(value)
-                if match:
-                    wlu = match.group()
-
-
-                if wlu in LUT_WAVELENGTH_UNITS.keys():
-                    wlu = LUT_WAVELENGTH_UNITS[wlu]
+    def findKey(d:dict, regex)->str:
+        for key in d.keys():
+            if regex.search(key):
+                return key
 
-    elif isinstance(ds, gdal.Dataset):
-        domains = ds.GetMetadataDomainList()
-        from qgis.PyQt.QtXml import QDomDocument
-        # DIMAP XML metadata?
-        xmlData = None
-        if 'xml:dimap' in domains:
-            md = ds.GetMetadata_Dict('xml:dimap')
-            xml = '<?xml version'
-            if xml in md.keys():
-                xmlData = xml + md[xml]
-                lines = xmlData.splitlines()
-                xmlData = '\n'.join(lines[2:])
+    # 1. try band level
+    wlu = []
+    wl = []
+    for b in range(ds.RasterCount):
+        band = ds.GetRasterBand(b + 1)
+        assert isinstance(band, gdal.Band)
+        md = band.GetMetadata_Dict()
 
-        else:
-            for path in ds.GetFileList():
-                if re.search(r'\.xml$', path, re.I) and not re.search(r'\.aux.xml$', path, re.I):
-                    with open(path, encoding='utf-8') as f:
-                        xmlData = f.read()
+        keyWLU = findKey(md, regWLUkey)
+        keyWL = findKey(md, regWLkey)
 
+        if isinstance(keyWL, str) and isinstance(keyWLU, str):
+            wl.append(float(md[keyWL]))
+            wlu.append(LUT_WAVELENGTH_UNITS[md[keyWLU].lower()])
 
-            s = ""
+    if len(wlu) == len(wl) and len(wl) == ds.RasterCount:
+        return wl, wlu[0]
 
+    # 2. try data set level
+    for domain in ds.GetMetadataDomainList():
+        md = ds.GetMetadata_Dict(domain)
 
-        if xmlData is not None:
-            dom = QDomDocument()
-            dom.setContent(xmlData)
-
-            # try DIMAP XML
-            nodes = dom.elementsByTagName('Band_Spectral_Range')
-            if nodes.count() > 0:
-                candidates = []
-                for element in [nodes.item(i).toElement() for i in range(nodes.count())]:
-                    _band = element.firstChildElement('BAND_ID').text()
-                    _wlu = element.firstChildElement('MEASURE_UNIT').text()
-                    wlMin = float(element.firstChildElement('MIN').text())
-                    wlMax = float(element.firstChildElement('MAX').text())
-                    _wl = 0.5*wlMin+wlMax
-                    candidates.append((_band, _wl, _wlu))
-
-                if len(candidates) == ds.RasterCount:
-                    candidates = sorted(candidates, key=lambda t:t[0])
-
-                    wlu = candidates[0][2]
-                    wlu = LUT_WAVELENGTH_UNITS[wlu]
-                    wl = [c[1] for c in candidates]
-                    return wl, wlu
+        keyWLU = findKey(md, regWLUkey)
+        keyWL = findKey(md, regWLkey)
+
+        if isinstance(keyWL, str) and isinstance(keyWLU, str):
 
-            nodes = dom.elementsByTagName('re:bandSpecificMetadata')
-
-            # test for RapidEye XML
-            # see http://schemas.rapideye.de/products/re/4.0/RapidEye_ProductMetadata_GeocorrectedLevel.xsd
-            # wavelength and units not given in the XML
-            # -> use values from https://www.satimagingcorp.com/satellite-sensors/other-satellite-sensors/rapideye/
-            if nodes.count() == ds.RasterCount and ds.RasterCount == 5:
-                wlu = r'nm'
-                wl = [0.5 * (440 + 510),
-                      0.5 * (520 + 590),
-                      0.5 * (630 + 685),
-                      0.5 * (760 + 850),
-                      0.5 * (760 - 850)
-                ]
+
+            wlu = LUT_WAVELENGTH_UNITS[md[keyWLU].lower()]
+            matches = regNumeric.findall(md[keyWL])
+            wl = [float(n) for n in matches]
+
+
+
+            if len(wl) == ds.RasterCount:
                 return wl, wlu
 
-        for domain in domains:
-            md = ds.GetMetadata_Dict(domain)
-            for key, value in md.items():
-                if wl is None and regWLkey.search(key):
-                    numbers = regNumeric.findall(value)
-                    if len(numbers) == ds.RasterCount:
-                        wl = [float(n) for n in numbers]
+    return None, None
+
+
 
-                if wlu is None and regWLUkey.search(key):
-                    match = regWLU.search(value)
-                    if match:
-                        wlu = match.group().lower()
+def extractWavelengthsFromRapidEyeXML(ds:gdal.Dataset, dom:QDomDocument)->(list, str):
+    nodes = dom.elementsByTagName('re:bandSpecificMetadata')
+    # see http://schemas.rapideye.de/products/re/4.0/RapidEye_ProductMetadata_GeocorrectedLevel.xsd
+    # wavelength and units not given in the XML
+    # -> use values from https://www.satimagingcorp.com/satellite-sensors/other-satellite-sensors/rapideye/
+    if nodes.count() == ds.RasterCount and ds.RasterCount == 5:
+        wlu = r'nm'
+        wl = [0.5 * (440 + 510),
+              0.5 * (520 + 590),
+              0.5 * (630 + 685),
+              0.5 * (760 + 850),
+              0.5 * (760 - 850)
+              ]
+        return wl, wlu
+    return None, None
 
+
+def extractWavelengthsFromDIMAPXML(ds:gdal.Dataset, dom:QDomDocument)->(list, str):
+    """
+    :param dom: QDomDocument | gdal.Dataset
+    :return: (list of wavelengths, str wavelength unit)
+    """
+    # DIMAP XML metadata?
+    assert isinstance(dom, QDomDocument)
+    nodes = dom.elementsByTagName('Band_Spectral_Range')
+    if nodes.count() > 0:
+        candidates = []
+        for element in [nodes.item(i).toElement() for i in range(nodes.count())]:
+            _band = element.firstChildElement('BAND_ID').text()
+            _wlu = element.firstChildElement('MEASURE_UNIT').text()
+            wlMin = float(element.firstChildElement('MIN').text())
+            wlMax = float(element.firstChildElement('MAX').text())
+            _wl = 0.5 * wlMin + wlMax
+            candidates.append((_band, _wl, _wlu))
+
+        if len(candidates) == ds.RasterCount:
+            candidates = sorted(candidates, key=lambda t: t[0])
+
+            wlu = candidates[0][2]
+            wlu = LUT_WAVELENGTH_UNITS[wlu]
+            wl = [c[1] for c in candidates]
+            return wl, wlu
+    return None, None
+
+def extractWavelengths(ds):
+    """
+    Returns the wavelength and wavelength units
+    :param ds: gdal.Dataset
+    :return: (float [list-of-wavelengths], str with wavelength unit)
+    """
+
+    if isinstance(ds, QgsRasterLayer):
+
+        if ds.dataProvider().name() == 'gdal':
+            uri = ds.source()
+            return extractWavelengths(gdal.Open(uri))
+        else:
+
+            md = [l.split('=') for l in str(ds.metadata()).splitlines() if 'wavelength' in l.lower()]
+
+            wl = wlu = None
+            for kv in md:
+                key, value = kv
+                key = key.lower()
+                value = value.strip()
+
+                if key == 'wavelength':
+                    tmp = re.findall(r'\d*\.\d+|\d+', value) #find floats
+                    if len(tmp) == 0:
+                        tmp = re.findall(r'\d+', value) #find integers
+                    if len(tmp) == ds.bandCount():
+                        wl = [float(w) for w in tmp]
+
+                if key == 'wavelength units':
+                    wlu = value
                     if wlu in LUT_WAVELENGTH_UNITS.keys():
                         wlu = LUT_WAVELENGTH_UNITS[wlu]
 
-    return wl, wlu
+                if isinstance(wl, list) and isinstance(wlu, str):
+                    return wl, wlu
+
+    elif isinstance(ds, gdal.Dataset):
+
+        def testWavelLengthInfo(wl, wlu)->bool:
+            return isinstance(wl, list) and len(wl) == ds.RasterCount and isinstance(wlu, str) and wlu in LUT_WAVELENGTH_UNITS.keys()
+
+        # try band-specific metadata
+        wl, wlu = extractWavelengthsFromGDALMetaData(ds)
+        if testWavelLengthInfo(wl, wlu):
+            return wl, wlu
+
+        # try internal locations with XML info
+        # SPOT DIMAP
+        if 'xml:dimap' in ds.GetMetadataDomainList():
+            md = ds.GetMetadata_Dict('xml:dimap')
+            for key in md.keys():
+                dom = QDomDocument()
+                dom.setContent(key + '=' + md[key])
+                wl, wlu = extractWavelengthsFromDIMAPXML(ds, dom)
+                if testWavelLengthInfo(wl, wlu):
+                    return wl, wlu
+
+        # try separate XML files
+        xmlReaders = [extractWavelengthsFromDIMAPXML, extractWavelengthsFromRapidEyeXML]
+        for path in ds.GetFileList():
+            if re.search(r'\.xml$', path, re.I) and not re.search(r'\.aux.xml$', path, re.I):
+                dom = QDomDocument()
+                with open(path, encoding='utf-8') as f:
+                    dom.setContent(f.read())
+
+                if dom.hasChildNodes():
+                    for xmlReader in xmlReaders:
+                        wl, wlu = xmlReader(ds, dom)
+                        if testWavelLengthInfo(wl, wlu):
+                            return wl, wlu
+
+    return None, None
 
 
 
diff --git a/make/createtestdata.py b/make/createtestdata.py
index a6a5f681..e398f204 100644
--- a/make/createtestdata.py
+++ b/make/createtestdata.py
@@ -136,7 +136,7 @@ def groupLandsat(dirIn, dirOut, pattern='L*_sr_band*.img'):
             raise NotImplementedError()
 
         #https://www.harrisgeospatial.com/docs/ENVIHeaderFiles.html
-        dsVRT.SetMetadataItem('wavelength units','Micrometers', 'ENVI')
+        dsVRT.SetMetadataItem('wavelength units', 'Micrometers', 'ENVI')
         dsVRT.SetMetadataItem('wavelength', '{{{}}}'.format(','.join([str(w) for w in cwl])), 'ENVI')
         dsVRT.SetMetadataItem('sensor type', 'Landsat-8 OLI', 'ENVI')
         from eotimeseriesviewer.dateparser import datetime64FromYYYYDOY
diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py
index aea61f7b..f40590f3 100644
--- a/tests/test_timeseries.py
+++ b/tests/test_timeseries.py
@@ -13,7 +13,7 @@ from eotimeseriesviewer.tests import initQgisApplication
 
 QAPP = initQgisApplication()
 
-SHOW_GUI = True and os.environ.get('CI') is None
+SHOW_GUI = False and os.environ.get('CI') is None
 
 class TestInit(unittest.TestCase):
 
@@ -87,6 +87,7 @@ class TestInit(unittest.TestCase):
 
             c2 = sensorIDtoProperties(sid)
             self.assertListEqual(list(conf), list(c2))
+        s = ""
 
     def test_TimeSeriesDate(self):
 
@@ -320,48 +321,80 @@ class TestInit(unittest.TestCase):
 
     def test_pleiades(self):
 
-        p = r'Y:\Pleiades\GFIO_Gp13_Novo_SO16018091-4-01_DS_PHR1A_201703031416139_FR1_PX_W056S07_0906_01636\TPP1600581943\IMG_PHR1A_PMS_001\DIM_PHR1A_PMS_201703031416139_ORT_2224693101-001.XML'
-        #p = r'Y:\Pleiades\GFIO_Gp13_Novo_SO16018091-4-01_DS_PHR1A_201703031416139_FR1_PX_W056S07_0906_01636\TPP1600581943\IMG_PHR1A_PMS_001\IMG_PHR1A_PMS_201703031416139_ORT_2224693101-001_R1C1.JP2'
-        if not os.path.isfile(p):
-            return
+        paths = [r'Y:\Pleiades\GFIO_Gp13_Novo_SO16018091-4-01_DS_PHR1A_201703031416139_FR1_PX_W056S07_0906_01636\TPP1600581943\IMG_PHR1A_PMS_001\DIM_PHR1A_PMS_201703031416139_ORT_2224693101-001.XML'
+                ,r'Y:\Pleiades\GFIO_Gp13_Novo_SO16018091-4-01_DS_PHR1A_201703031416139_FR1_PX_W056S07_0906_01636\TPP1600581943\IMG_PHR1A_PMS_001\IMG_PHR1A_PMS_201703031416139_ORT_2224693101-001_R1C1.JP2'
+                    ]
+        for p in paths:
+            if not os.path.isfile(p):
+                continue
 
-        ds = gdal.Open(p)
-        self.assertIsInstance(ds, gdal.Dataset)
-        band = ds.GetRasterBand(1)
-        self.assertIsInstance(band, gdal.Band)
+            ds = gdal.Open(p)
+            self.assertIsInstance(ds, gdal.Dataset)
+            band = ds.GetRasterBand(1)
+            self.assertIsInstance(band, gdal.Band)
 
 
-        tss = TimeSeriesSource(ds)
-        self.assertIsInstance(tss, TimeSeriesSource)
-        self.assertEqual(tss.mWLU, r'μm')
-        self.assertListEqual(tss.mWL, [0.775, 0.867, 1.017, 1.315])
+            tss = TimeSeriesSource(ds)
+            self.assertIsInstance(tss, TimeSeriesSource)
+            self.assertEqual(tss.mWLU, r'μm')
+            self.assertListEqual(tss.mWL, [0.775, 0.867, 1.017, 1.315])
+
+        s = ""
 
     def test_rapideye(self):
+        from example.Images import re_2014_06_25
+        paths = [r'Y:\RapidEye\3A\2135821_2014-06-25_RE2_3A_328202\2135821_2014-06-25_RE2_3A_328202.tif']
+
+        for p in paths:
+            if not os.path.isfile(p):
+                continue
 
-        p = r'Y:\RapidEye\3A\2135821_2014-06-25_RE2_3A_328202\2135821_2014-06-25_RE2_3A_328202.tif'
+            ds = gdal.Open(p)
+            self.assertIsInstance(ds, gdal.Dataset)
+            band = ds.GetRasterBand(1)
+            self.assertIsInstance(band, gdal.Band)
+
+
+            tss = TimeSeriesSource(ds)
+            self.assertIsInstance(tss, TimeSeriesSource)
+
+            # see https://www.satimagingcorp.com/satellite-sensors/other-satellite-sensors/rapideye/
+            wlu = r'nm'
+            wl = [0.5 * (440 + 510),
+                  0.5 * (520 + 590),
+                  0.5 * (630 + 685),
+                  0.5 * (760 + 850),
+                  0.5 * (760 - 850)
+                  ]
+            self.assertEqual(tss.mWLU, wlu)
+            self.assertListEqual(tss.mWL, wl)
+
+    def test_sentinel2(self):
+
+        p = r'Q:\Processing_BJ\01_Data\Sentinel\T21LXL\S2A_MSIL1C_20161221T141042_N0204_R110_T21LXL_20161221T141040.SAFE\MTD_MSIL1C.xml'
 
         if not os.path.isfile(p):
             return
 
-        ds = gdal.Open(p)
-        self.assertIsInstance(ds, gdal.Dataset)
-        band = ds.GetRasterBand(1)
-        self.assertIsInstance(band, gdal.Band)
+        dsC = gdal.Open(p)
+        self.assertIsInstance(dsC, gdal.Dataset)
+        for item in dsC.GetSubDatasets():
+            path = item[0]
+            ds = gdal.Open(path)
+            gt =  ds.GetGeoTransform()
+            self.assertIsInstance(ds, gdal.Dataset)
 
+            band = ds.GetRasterBand(1)
+            self.assertIsInstance(band, gdal.Band)
 
-        tss = TimeSeriesSource(ds)
-        self.assertIsInstance(tss, TimeSeriesSource)
+            wlu = ds.GetRasterBand(1).GetMetadata_Dict()['WAVELENGTH_UNIT']
+            wl = [float(ds.GetRasterBand(b+1).GetMetadata_Dict()['WAVELENGTH']) for b in range(ds.RasterCount)]
+
+            tss = TimeSeriesSource(ds)
+            self.assertIsInstance(tss, TimeSeriesSource)
 
-        # see https://www.satimagingcorp.com/satellite-sensors/other-satellite-sensors/rapideye/
-        wlu = r'nm'
-        wl = [0.5 * (440 + 510),
-              0.5 * (520 + 590),
-              0.5 * (630 + 685),
-              0.5 * (760 + 850),
-              0.5 * (760 - 850)
-              ]
-        self.assertEqual(tss.mWLU, wlu)
-        self.assertListEqual(tss.mWL, wl)
+            self.assertEqual(tss.mWLU, wlu)
+            self.assertEqual(tss.mWL, wl)
 
 
     def test_sensors(self):
-- 
GitLab