Commit 73a079a6 authored by Guillaume Poulin's avatar Guillaume Poulin
Browse files

Improve ScatterPlotItem.py

Add optimization for PySide, Plot only visible symbole, cache rectTarg
parent bd43a750
...@@ -3,7 +3,7 @@ from pyqtgraph.Point import Point ...@@ -3,7 +3,7 @@ from pyqtgraph.Point import Point
import pyqtgraph.functions as fn import pyqtgraph.functions as fn
from .GraphicsItem import GraphicsItem from .GraphicsItem import GraphicsItem
from .GraphicsObject import GraphicsObject from .GraphicsObject import GraphicsObject
from itertools import starmap from itertools import starmap, repeat
try: try:
from itertools import imap from itertools import imap
except ImportError: except ImportError:
...@@ -102,7 +102,6 @@ class SymbolAtlas(object): ...@@ -102,7 +102,6 @@ class SymbolAtlas(object):
self.symbolPen = weakref.WeakValueDictionary() self.symbolPen = weakref.WeakValueDictionary()
self.symbolBrush = weakref.WeakValueDictionary() self.symbolBrush = weakref.WeakValueDictionary()
self.symbolRectSrc = weakref.WeakValueDictionary() self.symbolRectSrc = weakref.WeakValueDictionary()
self.symbolRectTarg = weakref.WeakValueDictionary()
self.atlasData = None # numpy array of atlas image self.atlasData = None # numpy array of atlas image
self.atlas = None # atlas as QPixmap self.atlas = None # atlas as QPixmap
...@@ -114,33 +113,25 @@ class SymbolAtlas(object): ...@@ -114,33 +113,25 @@ class SymbolAtlas(object):
Given a list of spot records, return an object representing the coordinates of that symbol within the atlas Given a list of spot records, return an object representing the coordinates of that symbol within the atlas
""" """
rectSrc = np.empty(len(opts), dtype=object) rectSrc = np.empty(len(opts), dtype=object)
rectTarg = np.empty(len(opts), dtype=object)
keyi = None keyi = None
rectSrci = None rectSrci = None
rectTargi = None
for i, rec in enumerate(opts): for i, rec in enumerate(opts):
key = (rec[3], rec[2], id(rec[4]), id(rec[5])) key = (rec[3], rec[2], id(rec[4]), id(rec[5]))
if key == keyi: if key == keyi:
rectSrc[i] = rectSrci rectSrc[i] = rectSrci
rectTarg[i] = rectTargi
else: else:
try: try:
rectSrc[i] = self.symbolRectSrc[key] rectSrc[i] = self.symbolRectSrc[key]
rectTarg[i] = self.symbolRectTarg[key]
except KeyError: except KeyError:
newRectSrc = QtCore.QRectF() newRectSrc = QtCore.QRectF()
newRectTarg = QtCore.QRectF()
self.symbolPen[key] = rec['pen'] self.symbolPen[key] = rec['pen']
self.symbolBrush[key] = rec['brush'] self.symbolBrush[key] = rec['brush']
self.symbolRectSrc[key] = newRectSrc self.symbolRectSrc[key] = newRectSrc
self.symbolRectTarg[key] = newRectTarg
self.atlasValid = False self.atlasValid = False
rectSrc[i] = self.symbolRectSrc[key] rectSrc[i] = self.symbolRectSrc[key]
rectTarg[i] = self.symbolRectTarg[key]
keyi = key keyi = key
rectSrci = self.symbolRectSrc[key] rectSrci = self.symbolRectSrc[key]
rectTargi = self.symbolRectTarg[key] return rectSrc
return rectSrc, rectTarg
def buildAtlas(self): def buildAtlas(self):
# get rendered array for all symbols, keep track of avg/max width # get rendered array for all symbols, keep track of avg/max width
...@@ -195,7 +186,6 @@ class SymbolAtlas(object): ...@@ -195,7 +186,6 @@ class SymbolAtlas(object):
self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte) self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte)
for key in symbols: for key in symbols:
y, x, h, w = self.symbolRectSrc[key].getRect() y, x, h, w = self.symbolRectSrc[key].getRect()
self.symbolRectTarg[key].setRect(-h/2, -w/2, h, w)
self.atlasData[x:x+w, y:y+h] = rendered[key] self.atlasData[x:x+w, y:y+h] = rendered[key]
self.atlas = None self.atlas = None
self.atlasValid = True self.atlasValid = True
...@@ -247,7 +237,7 @@ class ScatterPlotItem(GraphicsObject): ...@@ -247,7 +237,7 @@ class ScatterPlotItem(GraphicsObject):
self.target = None self.target = None
self.fragmentAtlas = SymbolAtlas() self.fragmentAtlas = SymbolAtlas()
self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('rectSrc', object), ('rectTarg', object)]) self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('rectSrc', object), ('rectTarg', object), ('width', float)])
self.bounds = [None, None] ## caches data bounds self.bounds = [None, None] ## caches data bounds
self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots self._maxSpotWidth = 0 ## maximum size of the scale-variant portion of all spots
self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots self._maxSpotPxWidth = 0 ## maximum size of the scale-invariant portion of all spots
...@@ -561,15 +551,17 @@ class ScatterPlotItem(GraphicsObject): ...@@ -561,15 +551,17 @@ class ScatterPlotItem(GraphicsObject):
if np.any(mask): if np.any(mask):
invalidate = True invalidate = True
opts = self.getSpotOpts(dataSet[mask]) opts = self.getSpotOpts(dataSet[mask])
rectSrc, rectTarg = self.fragmentAtlas.getSymbolCoords(opts) rectSrc = self.fragmentAtlas.getSymbolCoords(opts)
dataSet['rectSrc'][mask] = rectSrc dataSet['rectSrc'][mask] = rectSrc
dataSet['rectTarg'][mask] = rectTarg
#for rec in dataSet: #for rec in dataSet:
#if rec['fragCoords'] is None: #if rec['fragCoords'] is None:
#invalidate = True #invalidate = True
#rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec)) #rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec))
self.fragmentAtlas.getAtlas() self.fragmentAtlas.getAtlas()
dataSet['width'] = np.array(list(imap(QtCore.QRectF.width, dataSet['rectSrc'])))/2
dataSet['rectTarg'] = list(imap(QtCore.QRectF, repeat(0), repeat(0), dataSet['width']*2, dataSet['width']*2))
self._maxSpotPxWidth=self.fragmentAtlas.max_width self._maxSpotPxWidth=self.fragmentAtlas.max_width
else: else:
self._maxSpotWidth = 0 self._maxSpotWidth = 0
...@@ -699,9 +691,15 @@ class ScatterPlotItem(GraphicsObject): ...@@ -699,9 +691,15 @@ class ScatterPlotItem(GraphicsObject):
tr = self.deviceTransform() tr = self.deviceTransform()
if tr is None: if tr is None:
return return
pts = np.empty((2,len(self.data['x']))) mask = np.logical_and(
pts[0] = self.data['x'] np.logical_and(self.data['x'] - self.data['width'] > range[0][0],
pts[1] = self.data['y'] self.data['x'] + self.data['width'] < range[0][1]),
np.logical_and(self.data['y'] - self.data['width'] > range[1][0],
self.data['y'] + self.data['width'] < range[1][1])) ## remove out of view points
data = self.data[mask]
pts = np.empty((2,len(data['x'])))
pts[0] = data['x']
pts[1] = data['y']
pts = fn.transformCoordinates(tr, pts) pts = fn.transformCoordinates(tr, pts)
self.fragments = [] self.fragments = []
pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault. pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault.
...@@ -746,18 +744,39 @@ class ScatterPlotItem(GraphicsObject): ...@@ -746,18 +744,39 @@ class ScatterPlotItem(GraphicsObject):
p.resetTransform() p.resetTransform()
if not USE_PYSIDE and self.opts['useCache'] and self._exportOpts is False: if self.opts['useCache'] and self._exportOpts is False:
tr = self.deviceTransform() tr = self.deviceTransform()
if tr is None: if tr is None:
return return
pts = np.empty((2,len(self.data['x']))) w = np.empty((2,len(self.data['width'])))
pts[0] = self.data['x'] w[0] = self.data['width']
pts[1] = self.data['y'] w[1] = self.data['width']
q, intv = tr.inverted()
if intv:
w = fn.transformCoordinates(q, w)
w=np.abs(w)
range = self.getViewBox().viewRange()
mask = np.logical_and(
np.logical_and(self.data['x'] + w[0,:] > range[0][0],
self.data['x'] - w[0,:] < range[0][1]),
np.logical_and(self.data['y'] + w[0,:] > range[1][0],
self.data['y'] - w[0,:] < range[1][1])) ## remove out of view points
data = self.data[mask]
else:
data = self.data
pts = np.empty((2,len(data['x'])))
pts[0] = data['x']
pts[1] = data['y']
pts = fn.transformCoordinates(tr, pts) pts = fn.transformCoordinates(tr, pts)
pts -= data['width']
pts = np.clip(pts, -2**30, 2**30) pts = np.clip(pts, -2**30, 2**30)
if self.target == None: if self.target == None:
self.target = list(imap(QtCore.QRectF.translated, self.data['rectTarg'], pts[0,:], pts[1,:])) list(imap(QtCore.QRectF.moveTo, data['rectTarg'], pts[0,:], pts[1,:]))
p.drawPixmapFragments(self.target, self.data['rectSrc'].tolist(), atlas) self.target=data['rectTarg']
if USE_PYSIDE:
list(imap(p.drawPixmap, self.target, repeat(atlas), data['rectSrc']))
else:
p.drawPixmapFragments(self.target.tolist(), data['rectSrc'].tolist(), atlas)
#p.drawPixmapFragments(self.fragments, atlas) #p.drawPixmapFragments(self.fragments, atlas)
else: else:
if self.fragments is None: if self.fragments is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment