diff --git a/pyqtgraph/graphicsItems/ScatterPlotItem.py b/pyqtgraph/graphicsItems/ScatterPlotItem.py
index 8de985fca51378c7e7bad3e4862a9c4b0ac1fc51..1c11fcf9b80b094811d81d58e2bb3fe2d20e63b6 100644
--- a/pyqtgraph/graphicsItems/ScatterPlotItem.py
+++ b/pyqtgraph/graphicsItems/ScatterPlotItem.py
@@ -228,8 +228,6 @@ class ScatterPlotItem(GraphicsObject):
         GraphicsObject.__init__(self)
         
         self.picture = None   # QPicture used for rendering when pxmode==False
-        self.fragments = None # fragment specification for pxmode; updated every time the view changes.
-        self.target = None
         self.fragmentAtlas = SymbolAtlas()
         
         self.data = np.empty(0, dtype=[('x', float), ('y', float), ('size', float), ('symbol', object), ('pen', object), ('brush', object), ('data', object), ('item', object), ('sourceRect', object), ('targetRect', object), ('width', float)])
@@ -394,6 +392,7 @@ class ScatterPlotItem(GraphicsObject):
             self.setPointData(kargs['data'], dataSet=newData)
             
         self.prepareGeometryChange()
+        self.informViewBoundsChanged()
         self.bounds = [None, None]
         self.invalidate()
         self.updateSpots(newData)
@@ -402,13 +401,10 @@ class ScatterPlotItem(GraphicsObject):
     def invalidate(self):
         ## clear any cached drawing state
         self.picture = None
-        self.fragments = None
-        self.target = None
         self.update()
         
     def getData(self):
-        return self.data['x'], self.data['y']
-    
+        return self.data['x'], self.data['y']    
         
     def setPoints(self, *args, **kargs):
         ##Deprecated; use setData
@@ -554,14 +550,10 @@ class ScatterPlotItem(GraphicsObject):
                 sourceRect = self.fragmentAtlas.getSymbolCoords(opts)
                 dataSet['sourceRect'][mask] = sourceRect
                 
-                
-            #for rec in dataSet:
-                #if rec['fragCoords'] is None:
-                    #invalidate = True
-                    #rec['fragCoords'] = self.fragmentAtlas.getSymbolCoords(*self.getSpotOpts(rec))
-            self.fragmentAtlas.getAtlas()
+            self.fragmentAtlas.getAtlas() # generate atlas so source widths are available.
+            
             dataSet['width'] = np.array(list(imap(QtCore.QRectF.width, dataSet['sourceRect'])))/2
-            dataSet['targetRect'] = list(imap(QtCore.QRectF, repeat(0), repeat(0), dataSet['width']*2, dataSet['width']*2))
+            dataSet['targetRect'] = None
             self._maxSpotPxWidth = self.fragmentAtlas.max_width
         else:
             self._maxSpotWidth = 0
@@ -684,40 +676,42 @@ class ScatterPlotItem(GraphicsObject):
         self.prepareGeometryChange()
         GraphicsObject.viewTransformChanged(self)
         self.bounds = [None, None]
-        self.fragments = None
-        self.target = None
+        self.data['targetRect'] = None
 
     def setExportMode(self, *args, **kwds):
         GraphicsObject.setExportMode(self, *args, **kwds)
         self.invalidate()
 
 
-    def getTransformedPoint(self):
-        # Map point locations to device
-        
-        vb = self.getViewBox()
-        if vb is None:
-            return None, None
+    def mapPointsToDevice(self, pts):
+        # Map point locations to device        
         tr = self.deviceTransform()
         if tr is None:
-            return None, None
+            return None
 
-        pts = np.empty((2,len(self.data['x'])))
-        pts[0] = self.data['x']
-        pts[1] = self.data['y']
+        #pts = np.empty((2,len(self.data['x'])))
+        #pts[0] = self.data['x']
+        #pts[1] = self.data['y']
         pts = fn.transformCoordinates(tr, pts)
         pts -= self.data['width']
         pts = np.clip(pts, -2**30, 2**30) ## prevent Qt segmentation fault.
         
-        ## Remove out of view points
+        return pts
+
+    def getViewMask(self, pts):
+        # Return bool mask indicating all points that are within viewbox
+        # pts is expressed in *device coordiantes*
+        vb = self.getViewBox()
+        if vb is None:
+            return None
         viewBounds = vb.mapRectToDevice(vb.boundingRect())
         w = self.data['width']
         mask = ((pts[0] + w > viewBounds.left()) &
                 (pts[0] - w < viewBounds.right()) &
                 (pts[1] + w > viewBounds.top()) &
                 (pts[1] - w < viewBounds.bottom())) ## remove out of view points 
-        print np.sum(mask)
-        return self.data[mask], pts[:, mask]
+        return mask
+        
         
     @debug.warnOnException  ## raising an exception here causes crash
     def paint(self, p, *args):
@@ -733,27 +727,42 @@ class ScatterPlotItem(GraphicsObject):
             scale = 1.0
             
         if self.opts['pxMode'] is True:
-            atlas = self.fragmentAtlas.getAtlas()
             p.resetTransform()
             
-            data, pts =  self.getTransformedPoint()
-            if data is None:
+            # Map point coordinates to device
+            pts = np.vstack([self.data['x'], self.data['y']])
+            pts = self.mapPointsToDevice(pts)
+            if pts is None:
                 return
             
+            # Cull points that are outside view
+            viewMask = self.getViewMask(pts)
+            #pts = pts[:,mask]
+            #data = self.data[mask]
+            
             if self.opts['useCache'] and self._exportOpts is False:
+                # Draw symbols from pre-rendered atlas
+                atlas = self.fragmentAtlas.getAtlas()
+                
+                # Update targetRects if necessary
+                updateMask = viewMask & np.equal(self.data['targetRect'], None)
+                if np.any(updateMask):
+                    updatePts = pts[:,updateMask]
+                    width = self.data[updateMask]['width']*2
+                    self.data['targetRect'][updateMask] = list(imap(QtCore.QRectF, updatePts[0,:], updatePts[1,:], width, width))
                 
-                if self.target == None:
-                    list(imap(QtCore.QRectF.moveTo, data['targetRect'], pts[0,:], pts[1,:]))
-                    self.target = data['targetRect']
+                data = self.data[viewMask]
                 if USE_PYSIDE:
-                    list(imap(p.drawPixmap, self.target, repeat(atlas), data['sourceRect']))
+                    list(imap(p.drawPixmap, data['targetRect'], repeat(atlas), data['sourceRect']))
                 else:
-                    p.drawPixmapFragments(self.target.tolist(), data['sourceRect'].tolist(), atlas)
+                    p.drawPixmapFragments(data['targetRect'].tolist(), data['sourceRect'].tolist(), atlas)
             else:
+                # render each symbol individually
                 p.setRenderHint(p.Antialiasing, aa)
 
-                for i in range(len(self.data)):
-                    rec = data[i]
+                data = self.data[viewMask]
+                pts = pts[:,viewMask]
+                for i, rec in enumerate(data):
                     p.resetTransform()
                     p.translate(pts[0,i] + rec['width'], pts[1,i] + rec['width'])
                     drawSymbol(p, *self.getSpotOpts(rec, scale))
diff --git a/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py b/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef8271bfe391785001d8a82c0cfc56274b8f4d24
--- /dev/null
+++ b/pyqtgraph/graphicsItems/tests/ScatterPlotItem.py
@@ -0,0 +1,23 @@
+import pyqtgraph as pg
+import numpy as np
+app = pg.mkQApp()
+plot = pg.plot()
+app.processEvents()
+
+# set view range equal to its bounding rect. 
+# This causes plots to look the same regardless of pxMode.
+plot.setRange(rect=plot.boundingRect())
+
+
+def test_modes():
+    for i, pxMode in enumerate([True, False]):
+        for j, useCache in enumerate([True, False]):
+            s = pg.ScatterPlotItem()
+            s.opts['useCache'] = useCache
+            plot.addItem(s)
+            s.setData(x=np.array([10,40,20,30])+i*100, y=np.array([40,60,10,30])+j*100, pxMode=pxMode)
+            s.addPoints(x=np.array([60, 70])+i*100, y=np.array([60, 70])+j*100, size=[20, 30])
+
+
+if __name__ == '__main__':
+    test_modes()