From a20e732f650a9d6dd9bbcbd2e6dfee624953583a Mon Sep 17 00:00:00 2001
From: Luke Campagnola <luke.campagnola@gmail.com>
Date: Thu, 4 Jul 2013 11:21:50 -0400
Subject: [PATCH] Added GL picking, matrix retrieval methods

---
 pyqtgraph/opengl/GLViewWidget.py | 67 +++++++++++++++++++++++---------
 1 file changed, 49 insertions(+), 18 deletions(-)

diff --git a/pyqtgraph/opengl/GLViewWidget.py b/pyqtgraph/opengl/GLViewWidget.py
index 12984c86..d8f70055 100644
--- a/pyqtgraph/opengl/GLViewWidget.py
+++ b/pyqtgraph/opengl/GLViewWidget.py
@@ -80,23 +80,26 @@ class GLViewWidget(QtOpenGL.QGLWidget):
         #self.update()
 
     def setProjection(self, region=None):
+        m = self.projectionMatrix(region)
+        glMatrixMode(GL_PROJECTION)
+        glLoadIdentity()
+        a = np.array(m.copyDataTo()).reshape((4,4))
+        glMultMatrixf(a.transpose())
+
+    def projectionMatrix(self, region=None):
         # Xw = (Xnd + 1) * width/2 + X
         if region is None:
             region = (0, 0, self.width(), self.height())
-        ## Create the projection matrix
-        glMatrixMode(GL_PROJECTION)
-        glLoadIdentity()
-        #w = self.width()
-        #h = self.height()
+        
         x0, y0, w, h = self.getViewport()
         dist = self.opts['distance']
         fov = self.opts['fov']
         nearClip = dist * 0.001
         farClip = dist * 1000.
-        
+
         r = nearClip * np.tan(fov * 0.5 * np.pi / 180.)
         t = r * h / w
-        
+
         # convert screen coordinates (region) to normalized device coordinates
         # Xnd = (Xw - X0) * 2/width - 1
         ## Note that X0 and width in these equations must be the values used in viewport
@@ -104,21 +107,46 @@ class GLViewWidget(QtOpenGL.QGLWidget):
         right = r * ((region[0]+region[2]-x0) * (2.0/w) - 1)
         bottom = t * ((region[1]-y0) * (2.0/h) - 1)
         top    = t * ((region[1]+region[3]-y0) * (2.0/h) - 1)
-        
-        glFrustum( left, right, bottom, top, nearClip, farClip)
-        #glFrustum(-r, r, -t, t, nearClip, farClip)
+
+        tr = QtGui.QMatrix4x4()
+        tr.frustum(left, right, bottom, top, nearClip, farClip)
+        return tr
         
     def setModelview(self):
         glMatrixMode(GL_MODELVIEW)
         glLoadIdentity()
-        glTranslatef( 0.0, 0.0, -self.opts['distance'])
-        glRotatef(self.opts['elevation']-90, 1, 0, 0)
-        glRotatef(self.opts['azimuth']+90, 0, 0, -1)
+        m = self.viewMatrix()
+        a = np.array(m.copyDataTo()).reshape((4,4))
+        glMultMatrixf(a.transpose())
+        
+    def viewMatrix(self):
+        tr = QtGui.QMatrix4x4()
+        tr.translate( 0.0, 0.0, -self.opts['distance'])
+        tr.rotate(self.opts['elevation']-90, 1, 0, 0)
+        tr.rotate(self.opts['azimuth']+90, 0, 0, -1)
         center = self.opts['center']
-        glTranslatef(-center.x(), -center.y(), -center.z())
+        tr.translate(-center.x(), -center.y(), -center.z())
+        return tr
+
+    def itemsAt(self, region=None):
+        #buf = np.zeros(100000, dtype=np.uint)
+        buf = glSelectBuffer(100000)
+        try:
+            glRenderMode(GL_SELECT)
+            glInitNames()
+            glPushName(0)
+            self._itemNames = {}
+            self.paintGL(region=region, useItemNames=True)
+            
+        finally:
+            hits = glRenderMode(GL_RENDER)
+
+        items = [(h.near, h.names[0]) for h in hits]
+        items.sort(key=lambda i: i[0])
         
+        return [self._itemNames[i[1]] for i in items]
         
-    def paintGL(self, region=None, viewport=None):
+    def paintGL(self, region=None, viewport=None, useItemNames=False):
         """
         viewport specifies the arguments to glViewport. If None, then we use self.opts['viewport']
         region specifies the sub-region of self.opts['viewport'] that should be rendered.
@@ -131,9 +159,9 @@ class GLViewWidget(QtOpenGL.QGLWidget):
         self.setProjection(region=region)
         self.setModelview()
         glClear( GL_DEPTH_BUFFER_BIT | GL_COLOR_BUFFER_BIT )
-        self.drawItemTree()
+        self.drawItemTree(useItemNames=useItemNames)
         
-    def drawItemTree(self, item=None):
+    def drawItemTree(self, item=None, useItemNames=False):
         if item is None:
             items = [x for x in self.items if x.parentItem() is None]
         else:
@@ -146,6 +174,9 @@ class GLViewWidget(QtOpenGL.QGLWidget):
             if i is item:
                 try:
                     glPushAttrib(GL_ALL_ATTRIB_BITS)
+                    if useItemNames:
+                        glLoadName(id(i))
+                        self._itemNames[id(i)] = i
                     i.paint()
                 except:
                     import pyqtgraph.debug
@@ -168,7 +199,7 @@ class GLViewWidget(QtOpenGL.QGLWidget):
                     tr = i.transform()
                     a = np.array(tr.copyDataTo()).reshape((4,4))
                     glMultMatrixf(a.transpose())
-                    self.drawItemTree(i)
+                    self.drawItemTree(i, useItemNames=useItemNames)
                 finally:
                     glMatrixMode(GL_MODELVIEW)
                     glPopMatrix()
-- 
GitLab