ScatterPlotItem.py 37 KB
Newer Older
1 2 3
from ..Qt import QtGui, QtCore, USE_PYSIDE
from ..Point import Point
from .. import functions as fn
4 5
from .GraphicsItem import GraphicsItem
from .GraphicsObject import GraphicsObject
Guillaume Poulin's avatar
Guillaume Poulin committed
6
from itertools import starmap, repeat
Guillaume Poulin's avatar
Guillaume Poulin committed
7 8 9 10
try:
    from itertools import imap
except ImportError:
    imap = map
11
import numpy as np
Luke Campagnola's avatar
Luke Campagnola committed
12
import weakref
13 14 15 16
from .. import getConfigOption
from .. import debug as debug
from ..pgcollections import OrderedDict
from .. import debug
17 18

__all__ = ['ScatterPlotItem', 'SpotItem']
19 20 21


## Build all symbol paths
Luke Campagnola's avatar
Luke Campagnola committed
22
# One QPainterPath per built-in symbol name. The insertion order matters:
# drawSymbol() resolves integer symbols by position in this dict.
Symbols = OrderedDict([(name, QtGui.QPainterPath()) for name in ['o', 's', 't', 'd', '+', 'x']])
# Circle and square are drawn on the unit box centered at the origin.
Symbols['o'].addEllipse(QtCore.QRectF(-0.5, -0.5, 1, 1))
Symbols['s'].addRect(QtCore.QRectF(-0.5, -0.5, 1, 1))
# Vertex lists for the symbols that are built as closed polygons.
coords = {
    't': [(-0.5, -0.5), (0, 0.5), (0.5, -0.5)],
    'd': [(0., -0.5), (-0.4, 0.), (0, 0.5), (0.4, 0)],
    '+': [
        (-0.5, -0.05), (-0.5, 0.05), (-0.05, 0.05), (-0.05, 0.5),
        (0.05, 0.5), (0.05, 0.05), (0.5, 0.05), (0.5, -0.05), 
        (0.05, -0.05), (0.05, -0.5), (-0.05, -0.5), (-0.05, -0.05)
    ],
}
# Trace each polygon: start at its first vertex, connect the rest, then close.
for k, c in coords.items():
    Symbols[k].moveTo(*c[0])
    for x,y in c[1:]:
        Symbols[k].lineTo(x, y)
    Symbols[k].closeSubpath()
# The 'x' symbol is the '+' path mapped through a 45 degree rotation.
tr = QtGui.QTransform()
tr.rotate(45)
Symbols['x'] = tr.map(Symbols['+'])
42

43 44
    
def drawSymbol(painter, symbol, size, pen, brush):
    """
    Draw one symbol with *painter*, scaled by *size*.

    *symbol* may be:

    * None -- nothing is drawn and the painter is left untouched
    * a string -- looked up as a key in the module-level Symbols dict
    * a scalar number -- used as a position in Symbols, wrapping around
      when it exceeds the number of entries
    * anything else -- passed directly to painter.drawPath()

    The painter's scale, pen and brush are modified and not restored.
    """
    if symbol is None:
        return
    painter.scale(size, size)
    painter.setPen(pen)
    painter.setBrush(brush)
    # `basestring` only exists as a builtin on Python 2; fall back to `str`
    # so string symbols do not raise NameError on Python 3.
    try:
        stringTypes = basestring
    except NameError:
        stringTypes = str
    if isinstance(symbol, stringTypes):
        symbol = Symbols[symbol]
    elif np.isscalar(symbol):
        symbol = list(Symbols.values())[symbol % len(Symbols)]
    painter.drawPath(symbol)
55

56 57 58 59 60 61 62 63 64
    
def renderSymbol(symbol, size, pen, brush, device=None):
    """
    Paint a single symbol, centered, onto a paint device and return the device.

    *symbol* is anything accepted by drawSymbol() (for example a
    QPainterPath or one of the keys of the Symbols dict).
    When *device* is None a new transparent ARGB32 QImage is created, sized
    to the symbol plus the pen width (at least one pixel). Otherwise the
    symbol is painted into the given device (see the QPainter documentation
    for what can serve as a device).
    """
    outline = max(np.ceil(pen.widthF()), 1)
    if device is None:
        side = int(size+outline)
        device = QtGui.QImage(side, side, QtGui.QImage.Format_ARGB32)
        device.fill(0)

    painter = QtGui.QPainter(device)
    try:
        painter.setRenderHint(painter.Antialiasing)
        # move the origin to the middle of the device so the symbol is centered
        painter.translate(device.width()*0.5, device.height()*0.5)
        drawSymbol(painter, symbol, size, pen, brush)
    finally:
        # always release the painter, even if drawing raised
        painter.end()
    return device
Luke Campagnola's avatar
Luke Campagnola committed
78

79 80 81 82 83
def makeSymbolPixmap(size, pen, brush, symbol):
    """Deprecated. Render *symbol* with renderSymbol() and wrap the result in a QPixmap."""
    return QtGui.QPixmap(renderSymbol(symbol, size, pen, brush))
    
84
class SymbolAtlas(object):
    """
    Used to efficiently construct a single QPixmap containing all rendered symbols
    for a ScatterPlotItem. This is required for fragment rendering.
    
    Use example:
        atlas = SymbolAtlas()
        # opts: sequence of spot records (see getSymbolCoords)
        rects = atlas.getSymbolCoords(opts)
        pm = atlas.getAtlas()
        
    """
    def __init__(self):
        # symbol key : QRectF(...) coordinates where symbol can be found in atlas.
        # note that the rect will always be the same object as
        # long as the symbol is in the atlas, but the coordinates may
        # change if the atlas is rebuilt.
        # weak value; if all external refs to this rect disappear,
        # the symbol will be forgotten.
        self.symbolMap = weakref.WeakValueDictionary()
        
        self.atlasData = None # numpy array of atlas image
        self.atlas = None     # atlas as QPixmap
        self.atlasValid = False
        self.max_width=0
        
    def getSymbolCoords(self, opts):
        """
        Given a list of spot records, return an object array holding, for each
        record, the rect that locates its symbol within the atlas.

        Each record is read positionally as size (index 2), symbol (3),
        pen (4) and brush (5), and by name as 'pen' and 'brush'. Symbols not
        yet known are registered with an empty rect and the atlas is marked
        invalid; coordinates are filled in by the next buildAtlas().
        """
        sourceRect = np.empty(len(opts), dtype=object)
        keyi = None         # key of the previous record
        sourceRecti = None  # rect of the previous record
        for i, rec in enumerate(opts):
            key = (rec[3], rec[2], id(rec[4]), id(rec[5]))   # TODO: use string indexes?
            if key == keyi:
                # same symbol as the previous record; skip the dict lookup
                sourceRect[i] = sourceRecti
                continue
            try:
                rect = self.symbolMap[key]
            except KeyError:
                rect = QtCore.QRectF()
                # the rect keeps pen and brush alive, so the id()s in the key
                # stay valid for as long as the rect itself exists
                rect.pen = rec['pen']
                rect.brush = rec['brush']
                self.symbolMap[key] = rect
                self.atlasValid = False
            sourceRect[i] = rect
            # remember the last key on hits as well as misses so that runs of
            # identical symbols take the fast path above
            keyi = key
            sourceRecti = rect
        return sourceRect
        
    def buildAtlas(self):
        """
        Render every symbol that has no coordinates yet, lay all symbols out
        in rows (tallest first) and store the result in self.atlasData.
        The rects in self.symbolMap are updated in place.
        """
        # get rendered array for all symbols, keep track of avg/max width
        rendered = {}
        avgWidth = 0.0
        maxWidth = 0
        images = []
        for key, sourceRect in self.symbolMap.items():
            if sourceRect.width() == 0:
                img = renderSymbol(key[0], key[1], sourceRect.pen, sourceRect.brush)
                images.append(img)  ## we only need this to prevent the images being garbage collected immediately
                arr = fn.imageToArray(img, copy=False, transpose=False)
            else:
                # already in the current atlas: re-use its pixels.
                # getRect() returns floats; numpy needs integer slice bounds.
                y, x, h, w = [int(v) for v in sourceRect.getRect()]
                arr = self.atlasData[x:x+w, y:y+h]
            rendered[key] = arr
            w = arr.shape[0]
            avgWidth += w
            maxWidth = max(maxWidth, w)
            
        nSymbols = len(rendered)
        if nSymbols > 0:
            avgWidth /= nSymbols
            # aim for a roughly square atlas; truncate to int because the value
            # is used as an array dimension. All symbol widths are integers, so
            # truncation does not change which symbols fit in a row.
            width = int(max(maxWidth, avgWidth * (nSymbols**0.5)))
        else:
            avgWidth = 0
            width = 0
        
        # sort symbols by height
        symbols = sorted(rendered.keys(), key=lambda x: rendered[x].shape[1], reverse=True)
        
        self.atlasRows = []

        x = width   # forces a new row for the first symbol
        y = 0
        rowheight = 0
        for key in symbols:
            arr = rendered[key]
            w,h = arr.shape[:2]
            if x+w > width:
                y += rowheight
                x = 0
                rowheight = h
                self.atlasRows.append([y, rowheight, 0])
            self.symbolMap[key].setRect(y, x, h, w)
            x += w
            self.atlasRows[-1][2] = x
        height = y + rowheight

        self.atlasData = np.zeros((width, height, 4), dtype=np.ubyte)
        for key in symbols:
            y, x, h, w = [int(v) for v in self.symbolMap[key].getRect()]
            self.atlasData[x:x+w, y:y+h] = rendered[key]
        self.atlas = None
        self.atlasValid = True
        self.max_width = maxWidth
    
    def getAtlas(self):
        """
        Return the atlas as a QPixmap, rebuilding the atlas data first if it
        has been invalidated. An empty atlas yields a 0x0 QPixmap.
        """
        if not self.atlasValid:
            self.buildAtlas()
        if self.atlas is None:
            if len(self.atlasData) == 0:
                return QtGui.QPixmap(0,0)
            img = fn.makeQImage(self.atlasData, copy=False, transpose=False)
            self.atlas = QtGui.QPixmap(img)
        return self.atlas
        
    
    
    
204
class ScatterPlotItem(GraphicsObject):
205 206 207 208
    """
    Displays a set of x/y points. Instances of this class are created
    automatically as part of PlotDataItem; these rarely need to be instantiated
    directly.
209
    
210 211 212 213 214 215 216 217 218 219 220 221
    The size, shape, pen, and fill brush may be set for each point individually 
    or for all points. 
    
    
    ========================  ===============================================
    **Signals:**
    sigPlotChanged(self)      Emitted when the data being plotted has changed
    sigClicked(self, points)  Emitted when the curve is clicked. Sends a list
                              of all the points under the mouse pointer.
    ========================  ===============================================
    
    """
222 223 224
    #sigPointClicked = QtCore.Signal(object, object)
    sigClicked = QtCore.Signal(object, object)  ## self, points
    sigPlotChanged = QtCore.Signal(object)
225 226 227 228
    def __init__(self, *args, **kargs):
        """
        Create the item with default pen, brush, symbol and size, then
        forward all arguments to setData().
        """
        profiler = debug.Profiler()
        GraphicsObject.__init__(self)

        self.picture = None   # QPicture used for rendering when pxmode==False
        self.fragmentAtlas = SymbolAtlas()

        ## one record per spot
        spotFields = [
            ('x', float), ('y', float), ('size', float), ('symbol', object),
            ('pen', object), ('brush', object), ('data', object), ('item', object),
            ('sourceRect', object), ('targetRect', object), ('width', float),
        ]
        self.data = np.empty(0, dtype=spotFields)

        self.bounds = [None, None]  ## caches data bounds
        self._maxSpotWidth = 0      ## maximum size of the scale-variant portion of all spots
        self._maxSpotPxWidth = 0    ## maximum size of the scale-invariant portion of all spots

        self.opts = {
            'pxMode': True,
            'useCache': True,  ## If useCache is False, symbols are re-drawn on every paint.
            'antialias': getConfigOption('antialias'),
            'name': None,
        }

        ## install defaults without triggering a spot update for each one
        defaultPen = fn.mkPen(getConfigOption('foreground'))
        defaultBrush = fn.mkBrush(100,100,150)
        self.setPen(defaultPen, update=False)
        self.setBrush(defaultBrush, update=False)
        self.setSymbol('o', update=False)
        self.setSize(7, update=False)
        profiler()

        self.setData(*args, **kargs)
        profiler('setData')

254 255
        #self.setCacheMode(self.DeviceCoordinateCache)
        
256
    def setData(self, *args, **kargs):
        """
        Replace all spots: the item is cleared and the arguments are passed
        on to addPoints().

        **Ordered Arguments:**
        
        * If there is only one unnamed argument, it will be interpreted like the 'spots' argument.
        * If there are two unnamed arguments, they will be interpreted as sequences of x and y values.
        
        ====================== ===============================================================================================
        **Keyword Arguments:**
        *spots*                Optional list of dicts. Each dict specifies parameters for a single spot:
                               {'pos': (x,y), 'size', 'pen', 'brush', 'symbol'}. This is just an alternate method
                               of passing in data for the corresponding arguments.
        *x*,*y*                1D arrays of x,y values.
        *pos*                  2D structure of x,y pairs (such as Nx2 array or list of tuples)
        *pxMode*               If True, spots are always the same size regardless of scaling, and size is given in px.
                               Otherwise, size is in scene coordinates and the spots scale with the view.
                               Default is True
        *symbol*               can be one (or a list) of:
                               * 'o'  circle (default)
                               * 's'  square
                               * 't'  triangle
                               * 'd'  diamond
                               * '+'  plus
                               * 'x'  plus rotated by 45 degrees
                               * any QPainterPath to specify custom symbol shapes. To properly obey the position and size,
                               custom symbols should be centered at (0,0) and width and height of 1.0. Note that it is also
                               possible to 'install' custom shapes by setting ScatterPlotItem.Symbols[key] = shape.
        *pen*                  The pen (or list of pens) to use for drawing spot outlines.
        *brush*                The brush (or list of brushes) to use for filling spots.
        *size*                 The size (or list of sizes) of spots. If *pxMode* is True, this value is in pixels. Otherwise,
                               it is in the item's local coordinate system.
        *data*                 a list of python objects used to uniquely identify each spot.
        *identical*            *Deprecated*. This functionality is handled automatically now.
        *antialias*            Whether to draw symbols with antialiasing. Note that if pxMode is True, symbols are 
                               always rendered with antialiasing (since the rendered symbols can be cached, this 
                               incurs very little performance cost)
        *name*                 The name of this item. Names are used for automatically
                               generating LegendItem entries and by some exporters.
        ====================== ===============================================================================================
        """
        ## Hold on to the previous records until the new ones are registered;
        ## this causes cached pixmaps to be preserved in the meantime.
        previous = self.data
        self.clear()
        self.addPoints(*args, **kargs)
298

Luke Campagnola's avatar
Luke Campagnola committed
299 300 301 302 303
    def addPoints(self, *args, **kargs):
        """
        Append spots to the scatter plot.
        Accepts the same arguments as setData(); existing spots are kept.
        Emits sigPlotChanged when done.
        """
        ## positional arguments are shorthand for spots=... or x=..., y=...
        nArgs = len(args)
        if nArgs > 2:
            raise Exception('Only accepts up to two non-keyword arguments.')
        if nArgs == 1:
            kargs['spots'] = args[0]
        elif nArgs == 2:
            kargs['x'] = args[0]
            kargs['y'] = args[1]

        ## split a 'pos' argument into separate 'x' and 'y' sequences
        if 'pos' in kargs:
            pos = kargs['pos']
            if isinstance(pos, np.ndarray):
                kargs['x'] = pos[:,0]
                kargs['y'] = pos[:,1]
            else:
                xs = []
                ys = []
                for p in pos:
                    if isinstance(p, QtCore.QPointF):
                        px, py = p.x(), p.y()
                    else:
                        px, py = p[0], p[1]
                    xs.append(px)
                    ys.append(py)
                kargs['x'] = xs
                kargs['y'] = ys

        ## count the incoming spots
        if 'spots' in kargs:
            numPts = len(kargs['spots'])
        elif kargs.get('y') is not None:
            numPts = len(kargs['y'])
        else:
            kargs['x'] = []
            kargs['y'] = []
            numPts = 0

        ## grow the record array; the new records are the tail of the new array
        ## note that np.empty initializes object fields to None and string fields to ''
        oldData = self.data
        nOld = len(oldData)
        self.data = np.empty(nOld+numPts, dtype=oldData.dtype)
        self.data[:nOld] = oldData
        newData = self.data[nOld:]
        newData['size'] = -1  ## indicates to use default size

        if 'spots' in kargs:
            for i, spot in enumerate(kargs['spots']):
                rec = newData[i]
                for k in spot:
                    if k == 'pos':
                        pos = spot[k]
                        if isinstance(pos, QtCore.QPointF):
                            rec['x'] = pos.x()
                            rec['y'] = pos.y()
                        else:
                            rec['x'] = pos[0]
                            rec['y'] = pos[1]
                    elif k == 'pen':
                        rec[k] = fn.mkPen(spot[k])
                    elif k == 'brush':
                        rec[k] = fn.mkBrush(spot[k])
                    elif k in ('x', 'y', 'size', 'symbol', 'data'):
                        rec[k] = spot[k]
                    else:
                        raise Exception("Unknown spot parameter: %s" % k)
        elif 'y' in kargs:
            newData['x'] = kargs['x']
            newData['y'] = kargs['y']

        if 'pxMode' in kargs:
            self.setPxMode(kargs['pxMode'])
        if 'antialias' in kargs:
            self.opts['antialias'] = kargs['antialias']

        ## per-spot or default styles given as keyword arguments; applied to the
        ## new records only, with the spot update deferred to the end
        mask = kargs.get('mask', None)
        for k in ('pen', 'brush', 'symbol', 'size'):
            if k not in kargs:
                continue
            setter = getattr(self, 'set' + k.capitalize())
            setter(kargs[k], update=False, dataSet=newData, mask=mask)

        if 'data' in kargs:
            self.setPointData(kargs['data'], dataSet=newData)

        self.prepareGeometryChange()
        self.informViewBoundsChanged()
        self.bounds = [None, None]
        self.invalidate()
        self.updateSpots(newData)
        self.sigPlotChanged.emit(self)
400
        
401 402 403 404 405
    def invalidate(self):
        """Clear any cached drawing state and request a repaint."""
        self.picture = None
        self.update()
        
Luke Campagnola's avatar
Luke Campagnola committed
406
    def getData(self):
        """Return the 'x' and 'y' fields of the spot records as a pair."""
        records = self.data
        return records['x'], records['y']
408 409
        
    def setPoints(self, *args, **kargs):
        """Deprecated; use setData(). All arguments are forwarded unchanged."""
        result = self.setData(*args, **kargs)
        return result
        
413 414 415 416 417 418
    def implements(self, interface=None):
        """
        With no argument, return the list of interface names this item
        supports; otherwise return whether *interface* is one of them.
        """
        supported = ['plotData']
        return supported if interface is None else interface in supported
    
419 420 421
    def name(self):
        """Return the 'name' option of this item, or None if it is not set."""
        return self.opts.get('name')
    
422
    def setPen(self, *args, **kargs):
        """Set the pen(s) used to draw the outline around each spot.

        A single list or array argument assigns one pen per spot (optionally
        filtered through the *mask* keyword first). Any other arguments are
        passed to pg.mkPen and become the default pen for all spots which do
        not have a pen explicitly set.

        Keyword *update* (default True) controls whether the spots are
        refreshed afterwards; keyword *dataSet* restricts the change to a
        subset of the spot records (default: all records).
        """
        update = kargs.pop('update', True)
        dataSet = kargs.pop('dataSet', self.data)

        perSpot = len(args) == 1 and isinstance(args[0], (np.ndarray, list))
        if perSpot:
            pens = args[0]
            mask = kargs.get('mask')
            if mask is not None:
                pens = pens[mask]
            if len(pens) != len(dataSet):
                raise Exception("Number of pens does not match number of points (%d != %d)" % (len(pens), len(dataSet)))
            dataSet['pen'] = pens
        else:
            self.opts['pen'] = fn.mkPen(*args, **kargs)

        ## forget atlas coordinates so the affected symbols are looked up again
        dataSet['sourceRect'] = None
        if update:
            self.updateSpots(dataSet)
443 444
        
    def setBrush(self, *args, **kargs):
        """Set the brush(es) used to fill the interior of each spot.

        A single list or array argument assigns one brush per spot (optionally
        filtered through the *mask* keyword first). Any other arguments are
        passed to pg.mkBrush and become the default brush for all spots which
        do not have a brush explicitly set.

        Keyword *update* (default True) controls whether the spots are
        refreshed afterwards; keyword *dataSet* restricts the change to a
        subset of the spot records (default: all records).
        """
        update = kargs.pop('update', True)
        dataSet = kargs.pop('dataSet', self.data)

        perSpot = len(args) == 1 and isinstance(args[0], (np.ndarray, list))
        if perSpot:
            brushes = args[0]
            mask = kargs.get('mask')
            if mask is not None:
                brushes = brushes[mask]
            if len(brushes) != len(dataSet):
                raise Exception("Number of brushes does not match number of points (%d != %d)" % (len(brushes), len(dataSet)))
            dataSet['brush'] = brushes
        else:
            self.opts['brush'] = fn.mkBrush(*args, **kargs)

        ## forget atlas coordinates so the affected symbols are looked up again
        dataSet['sourceRect'] = None
        if update:
            self.updateSpots(dataSet)
468

469
    def setSymbol(self, symbol, update=True, dataSet=None, mask=None):
        """Set the symbol(s) used to draw each spot.

        A list or array assigns one symbol per spot (filtered through *mask*
        first when a mask is given). Any other value becomes the default
        symbol for all spots which do not have a symbol explicitly set.
        *dataSet* restricts the change to a subset of the spot records
        (default: all records); *update* controls whether the spots are
        refreshed afterwards.
        """
        target = self.data if dataSet is None else dataSet

        if isinstance(symbol, (np.ndarray, list)):
            symbols = symbol if mask is None else symbol[mask]
            if len(symbols) != len(target):
                raise Exception("Number of symbols does not match number of points (%d != %d)" % (len(symbols), len(target)))
            target['symbol'] = symbols
        else:
            self.opts['symbol'] = symbol
            self._spotPixmap = None

        ## forget atlas coordinates so the affected symbols are looked up again
        target['sourceRect'] = None
        if update:
            self.updateSpots(target)
    
492
    def setSize(self, size, update=True, dataSet=None, mask=None):
        """Set the size(s) used to draw each spot.

        A list or array assigns one size per spot (filtered through *mask*
        first when a mask is given). Any other value becomes the default
        size for all spots which do not have a size explicitly set.
        *dataSet* restricts the change to a subset of the spot records
        (default: all records); *update* controls whether the spots are
        refreshed afterwards.
        """
        target = self.data if dataSet is None else dataSet

        if isinstance(size, (np.ndarray, list)):
            sizes = size if mask is None else size[mask]
            if len(sizes) != len(target):
                raise Exception("Number of sizes does not match number of points (%d != %d)" % (len(sizes), len(target)))
            target['size'] = sizes
        else:
            self.opts['size'] = size
            self._spotPixmap = None

        ## forget atlas coordinates so the affected symbols are looked up again
        target['sourceRect'] = None
        if update:
            self.updateSpots(target)
514
        
515
    def setPointData(self, data, dataSet=None, mask=None):
        """
        Store *data* in the 'data' field of the spot records.

        A list or array is filtered through *mask* (when given) and must then
        match the number of records in *dataSet* (default: all records);
        any other value is assigned as-is.
        """
        target = self.data if dataSet is None else dataSet

        if isinstance(data, (np.ndarray, list)):
            if mask is not None:
                data = data[mask]
            if len(data) != len(target):
                raise Exception("Length of meta data does not match number of points (%d != %d)" % (len(data), len(target)))

        ## Bug: If data is a numpy record array, then items from that array must be copied to dataSet one at a time.
        ## (otherwise they are converted to tuples and thus lose their field names.
        fields = data.dtype.fields if isinstance(data, np.ndarray) else None
        if fields is not None and len(fields) > 1:
            for i, rec in enumerate(data):
                target['data'][i] = rec
        else:
            target['data'] = data
532
        
533
    def setPxMode(self, mode):
        """Change the 'pxMode' option; cached drawing state is dropped only when the value actually changes."""
        if self.opts['pxMode'] != mode:
            self.opts['pxMode'] = mode
            self.invalidate()
539
        
Luke Campagnola's avatar
Luke Campagnola committed
540 541 542
    def updateSpots(self, dataSet=None):
        """
        Refresh the derived per-spot state for *dataSet* (default: all spots).

        In pxMode, spots without a 'sourceRect' get one from the symbol atlas,
        and 'width', 'targetRect' and the maximum pixel width are recomputed.
        Otherwise the maximum spot sizes are reset and re-measured.
        """
        if dataSet is None:
            dataSet = self.data

        atlasChanged = False
        if not self.opts['pxMode']:
            self._maxSpotWidth = 0
            self._maxSpotPxWidth = 0
            self.measureSpotSizes(dataSet)
        else:
            missing = np.equal(dataSet['sourceRect'], None)
            if np.any(missing):
                atlasChanged = True
                spotOpts = self.getSpotOpts(dataSet[missing])
                dataSet['sourceRect'][missing] = self.fragmentAtlas.getSymbolCoords(spotOpts)

            self.fragmentAtlas.getAtlas() # generate atlas so source widths are available.

            rectWidths = [QtCore.QRectF.width(rect) for rect in dataSet['sourceRect']]
            dataSet['width'] = np.array(rectWidths)/2
            dataSet['targetRect'] = None
            self._maxSpotPxWidth = self.fragmentAtlas.max_width

        if atlasChanged:
            self.invalidate()
565

566
    def getSpotOpts(self, recs, scale=1.0):
        """
        Resolve the drawing options of spot records, substituting the item's
        defaults for values that are unset.

        For a single record (ndim == 0) return the tuple
        (symbol, size*scale, pen, brush). For an array of records return a
        copy in which the 'symbol', 'size', 'pen' and 'brush' fields have been
        filled in and 'size' has been multiplied by *scale*.
        """
        defaults = self.opts
        if recs.ndim == 0:
            symbol = recs['symbol']
            size = recs['size']
            pen = recs['pen']
            brush = recs['brush']
            if symbol is None:
                symbol = defaults['symbol']
            if size < 0:
                size = defaults['size']
            if pen is None:
                pen = defaults['pen']
            if brush is None:
                brush = defaults['brush']
            return (symbol, size*scale, fn.mkPen(pen), fn.mkBrush(brush))

        ## work on a copy so the stored records keep their "unset" markers
        filled = recs.copy()
        filled['symbol'][np.equal(filled['symbol'], None)] = defaults['symbol']
        filled['size'][np.equal(filled['size'], -1)] = defaults['size']
        filled['size'] *= scale
        filled['pen'][np.equal(filled['pen'], None)] = fn.mkPen(defaults['pen'])
        filled['brush'][np.equal(filled['brush'], None)] = fn.mkBrush(defaults['brush'])
        return filled
            
            
        
593
    def measureSpotSizes(self, dataSet):
        """
        Raise the recorded maximum spot width (_maxSpotWidth) and maximum
        pixel width (_maxSpotPxWidth) to cover every record in *dataSet*,
        then reset the cached data bounds.
        """
        for rec in dataSet:
            symbol, size, pen, brush = self.getSpotOpts(rec)
            if self.opts['pxMode']:
                ## everything is measured in pixels
                width = 0
                pxWidth = size + pen.widthF()
            elif pen.isCosmetic():
                ## only the cosmetic pen outline is measured in pixels
                width = size
                pxWidth = 0 + pen.widthF()
            else:
                width = size + pen.widthF()
                pxWidth = 0
            self._maxSpotWidth = max(self._maxSpotWidth, width)
            self._maxSpotPxWidth = max(self._maxSpotPxWidth, pxWidth)
        self.bounds = [None, None]
610 611
    
    
612
    def clear(self):
        """Remove all spots from the scatter plot"""
        spotDtype = self.data.dtype
        ## an empty array with the same record layout replaces the old data
        self.data = np.empty(0, dtype=spotDtype)
        self.bounds = [None, None]
        self.invalidate()
Luke Campagnola's avatar
Luke Campagnola committed
618

619
    def dataBounds(self, ax, frac=1.0, orthoRange=None):
        """
        Return the (min, max) extent of the spot data along one axis.

        *ax*          0 for the x axis, 1 for the y axis.
        *frac*        With frac >= 1.0 the full NaN-ignoring range is returned,
                      padded by the scale-variant spot width. With
                      0 < frac < 1 the central *frac* portion of the finite
                      values is returned (as computed by np.percentile),
                      without padding.
        *orthoRange*  Optional (min, max); only points whose coordinate on
                      the other axis lies within this range are considered.

        Returns (None, None) when there are no points to measure.
        Raises Exception if frac <= 0.

        The full, unrestricted result is cached in self.bounds[ax].
        """
        if frac >= 1.0 and orthoRange is None and self.bounds[ax] is not None:
            return self.bounds[ax]

        if self.data is None or len(self.data) == 0:
            return (None, None)

        if ax == 0:
            d = self.data['x']
            d2 = self.data['y']
        elif ax == 1:
            d = self.data['y']
            d2 = self.data['x']

        if orthoRange is not None:
            mask = (d2 >= orthoRange[0]) * (d2 <= orthoRange[1])
            d = d[mask]
            if d.size == 0:
                # nothing falls inside orthoRange; nanmin/nanmax would raise
                return (None, None)

        if frac >= 1.0:
            pad = self._maxSpotWidth*0.7072
            bounds = (np.nanmin(d) - pad, np.nanmax(d) + pad)
            # Only the unrestricted result may be cached: the cache is read
            # back only for calls without orthoRange, so storing a restricted
            # result would hand later callers the wrong bounds.
            if orthoRange is None:
                self.bounds[ax] = bounds
            return bounds
        elif frac <= 0.0:
            raise Exception("Value for parameter 'frac' must be > 0. (got %s)" % str(frac))
        else:
            d = d[np.isfinite(d)]
            if d.size == 0:
                # no finite values left to take percentiles of
                return (None, None)
            return np.percentile(d, [50 * (1 - frac), 50 * (1 + frac)])

649
    def pixelPadding(self):
        """
        Return the padding, in pixels, derived from the widest pixel-sized spot:
        self._maxSpotPxWidth scaled by 0.7072 (roughly sqrt(2)/2).
        """
        maxPxWidth = self._maxSpotPxWidth
        return maxPxWidth*0.7072
651 652

    def boundingRect(self):
        """
        Return a QRectF spanning the data bounds of both axes, grown on every
        side by the pixel padding converted to local units.

        Missing bounds on an axis (dataBounds() returning None) collapse that
        axis to the range 0..0.
        """
        def _pixelLength(vec):
            # Length of a pixel vector; 0 when unavailable or not representable.
            if vec is None:
                return 0
            try:
                return vec.length()
            except OverflowError:
                return 0

        xmn, xmx = self.dataBounds(ax=0)
        ymn, ymx = self.dataBounds(ax=1)
        if xmn is None or xmx is None:
            xmn = xmx = 0
        if ymn is None or ymx is None:
            ymn = ymx = 0

        px = py = 0.0
        pxPad = self.pixelPadding()
        if pxPad > 0:
            # determine length of pixel in local x, y directions
            xVec, yVec = self.pixelVectors()
            px = _pixelLength(xVec)
            py = _pixelLength(yVec)

            # return bounds expanded by pixel size
            px *= pxPad
            py *= pxPad

        left = xmn-px
        top = ymn-py
        rectWidth = (2*px)+xmx-xmn
        rectHeight = (2*py)+ymx-ymn
        return QtCore.QRectF(left, top, rectWidth, rectHeight)
680

681
    def viewTransformChanged(self):
        """
        React to a changed view transform: announce the geometry change, let
        GraphicsObject do its own handling, then drop the cached data bounds
        and the per-spot target rectangles.
        """
        self.prepareGeometryChange()
        GraphicsObject.viewTransformChanged(self)

        # Both caches are rebuilt lazily (see dataBounds() and paint()).
        self.bounds = [None, None]
        self.data['targetRect'] = None
686

687 688 689
    def setExportMode(self, *args, **kwds):
        """
        Forward all arguments to GraphicsObject.setExportMode(), then call
        invalidate() on this item.
        """
        GraphicsObject.setExportMode(self, *args, **kwds)
        self.invalidate()
Guillaume Poulin's avatar
Guillaume Poulin committed
690 691


Luke Campagnola's avatar
Luke Campagnola committed
692 693
    def mapPointsToDevice(self, pts):
        """
        Map point locations to device coordinates.

        Returns None when deviceTransform() yields no transform. Otherwise the
        points are transformed, shifted by each spot's 'width' value and
        clipped to +/- 2**30.
        """
        transform = self.deviceTransform()
        if transform is None:
            return None

        devPts = fn.transformCoordinates(transform, pts)
        devPts -= self.data['width']
        limit = 2**30
        return np.clip(devPts, -limit, limit)  ## prevent Qt segmentation fault.

    def getViewMask(self, pts):
        """
        Return a bool mask marking the points that fall inside the view box.

        *pts* is expressed in device coordinates (row 0 = x, row 1 = y). Each
        point is tested with a margin of its 'width' value on every side.
        Returns None when the item has no view box.
        """
        vb = self.getViewBox()
        if vb is None:
            return None

        viewBounds = vb.mapRectToDevice(vb.boundingRect())
        w = self.data['width']
        xs = pts[0]
        ys = pts[1]
        insideX = (xs + w > viewBounds.left()) & (xs - w < viewBounds.right())
        insideY = (ys + w > viewBounds.top()) & (ys - w < viewBounds.bottom())
        return insideX & insideY  ## remove out of view points
        
Luke Campagnola's avatar
Luke Campagnola committed
721
        
722
    @debug.warnOnException  ## raising an exception here causes crash
    def paint(self, p, *args):
        """
        Draw all spots with painter *p*.

        In pxMode the spot positions are mapped to device coordinates, points
        outside the view are culled, and each spot is drawn either from the
        pre-rendered fragment atlas (when 'useCache' is set and no export is in
        progress) or individually with drawSymbol(). Outside pxMode the spots
        are recorded once into self.picture, which is then replayed.

        While exporting (self._exportOpts is not False) the 'antialias' and
        'resolutionScale' export options take the place of the item's own
        antialias option and a scale of 1.0.
        """
        if self._exportOpts is not False:
            aa = self._exportOpts.get('antialias', True)
            scale = self._exportOpts.get('resolutionScale', 1.0)  ## exporting to image; pixel resolution may have changed
        else:
            aa = self.opts['antialias']
            scale = 1.0

        if self.opts['pxMode'] is True:
            p.resetTransform()

            # Map point coordinates to device
            pts = np.vstack([self.data['x'], self.data['y']])
            pts = self.mapPointsToDevice(pts)
            if pts is None:
                return

            # Cull points that are outside view
            viewMask = self.getViewMask(pts)
            if viewMask is None:
                # No view box to cull against: treat every point as visible
                # instead of indexing with None below.
                viewMask = np.ones(len(self.data), dtype=bool)

            if self.opts['useCache'] and self._exportOpts is False:
                # Draw symbols from pre-rendered atlas
                atlas = self.fragmentAtlas.getAtlas()

                # Update targetRects if necessary
                updateMask = viewMask & np.equal(self.data['targetRect'], None)
                if np.any(updateMask):
                    updatePts = pts[:,updateMask]
                    width = self.data[updateMask]['width']*2
                    self.data['targetRect'][updateMask] = list(imap(QtCore.QRectF, updatePts[0,:], updatePts[1,:], width, width))

                data = self.data[viewMask]
                if USE_PYSIDE:
                    # PySide path: one drawPixmap call per visible spot
                    for target, source in zip(data['targetRect'], data['sourceRect']):
                        p.drawPixmap(target, atlas, source)
                else:
                    p.drawPixmapFragments(data['targetRect'].tolist(), data['sourceRect'].tolist(), atlas)
            else:
                # render each symbol individually
                p.setRenderHint(p.Antialiasing, aa)

                data = self.data[viewMask]
                pts = pts[:,viewMask]
                for i, rec in enumerate(data):
                    p.resetTransform()
                    p.translate(pts[0,i] + rec['width'], pts[1,i] + rec['width'])
                    drawSymbol(p, *self.getSpotOpts(rec, scale))
        else:
            if self.picture is None:
                self.picture = QtGui.QPicture()
                p2 = QtGui.QPainter(self.picture)
                for rec in self.data:
                    p2.resetTransform()
                    p2.translate(rec['x'], rec['y'])
                    # The scale is handed to getSpotOpts() only; multiplying
                    # rec['size'] here as well would apply it twice.
                    drawSymbol(p2, *self.getSpotOpts(rec, scale))
                p2.end()

            p.setRenderHint(p.Antialiasing, aa)
            self.picture.play(p)
790
        
791
    def points(self):
        """
        Return the 'item' column of the data array, first creating a SpotItem
        for every record that does not have one yet.
        """
        for rec in self.data:
            if rec['item'] is not None:
                continue
            rec['item'] = SpotItem(rec, self)
        return self.data['item']
        
797 798 799 800 801 802
    def pointsAt(self, pos):
        """
        Return the spots whose square hit area strictly contains *pos*.

        The hit area is centred on the spot and is spot.size() wide and high;
        in pxMode the half extents are scaled by pixelWidth() / pixelHeight().
        The result lists the matching spots in reverse order of self.points().
        """
        x = pos.x()
        y = pos.y()
        pw = self.pixelWidth()
        ph = self.pixelHeight()
        hits = []
        for spot in self.points():
            center = spot.pos()
            halfW = halfH = spot.size() * 0.5
            cx = center.x()
            cy = center.y()
            if self.opts['pxMode']:
                halfW *= pw
                halfH *= ph
            if cx-halfW < x < cx+halfW and cy-halfH < y < cy+halfH:
                hits.append(spot)
        hits.reverse()
        return hits
820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
            

    def mouseClickEvent(self, ev):
        """
        Handle a mouse click: a left-button click that hits at least one spot
        stores the hit spots in self.ptsClicked, emits sigClicked and accepts
        the event; every other click is ignored.
        """
        if ev.button() != QtCore.Qt.LeftButton:
            ev.ignore()
            return

        pts = self.pointsAt(ev.pos())
        if len(pts) == 0:
            ev.ignore()
            return

        self.ptsClicked = pts
        self.sigClicked.emit(self, self.ptsClicked)
        ev.accept()


836
class SpotItem(object):
Luke Campagnola's avatar
Luke Campagnola committed
837 838 839 840 841
    """
    Class referring to individual spots in a scatter plot.
    These can be retrieved by calling ScatterPlotItem.points() or 
    by connecting to the ScatterPlotItem's click signals.
    """
842

Luke Campagnola's avatar
Luke Campagnola committed
843
    def __init__(self, data, plot):
844
        #GraphicsItem.__init__(self, register=False)
Luke Campagnola's avatar
Luke Campagnola committed
845 846
        self._data = data
        self._plot = plot
847 848 849
        #self.setParentItem(plot)
        #self.setPos(QtCore.QPointF(data['x'], data['y']))
        #self.updateItem()
850
    
Luke Campagnola's avatar
Luke Campagnola committed
851 852 853 854 855 856 857 858 859 860 861 862
    def data(self):
        """Return the user data associated with this spot."""
        record = self._data
        return record['data']
    
    def