
# some modules for a GUI
import wxversion
wxversion.select('2.8')
import wx

#import the core modules we are going to use
import os, sys
import cv2.cv as cv
import numpy as np
import time

#import some modules for some summary plots at the end
import matplotlib.pyplot as plt

#import some modules that allows us to read .mat files
import scipy.io


### ------------------------------------------------- ###
# TEST CODE - some hard-coded variables for testing
#input setting for initial testing
#inFileName = '/home/ag/pupil_detection/sample_avis/20100427-095732.avi'
#inFileName = '/home/ag/pupil_detection/sample_avis/20110623-150633.avi'
#inFileName = '/home/ag/pupil_detection/sample_avis/20110215-111741.avi'

#x_min, y_min, x_size, y_size = 100,65,100,100
#x_min, y_min, x_size, y_size = 80,120,100,100
#x_min, y_min, x_size, y_size = 105,70,100,100  
# END TEST CODE
### ------------------------------------------------- ###


# -----------
# Module-level state shared between the GUI and the analysis code.
# Further globals (clickPoints, finalPoints, inFileName, x_start, x_end) are
# not initialised here; they are created at runtime by CvDisplayPanel and by
# the matplotlib click handlers through their own `global` declarations.
# NOTE: the former module-level `global ...` statement was removed. At module
# scope it has no effect, and because it named variables already assigned
# above it only produced a SyntaxWarning (a SyntaxError on Python 3).
matData = None           # frame stack loaded from a .mat file, else None
showFramesYesNo = None   # set to 1 / 0 by the "view video frames?" dialog


# kernel size handed to the first cv.Smooth call in runPupilAnalysis and
# recorded in the name of the saved .mat file
gsmth = 3


# for correlation with background to auto-detect flashes
# (70 samples; currently only referenced from commented-out code)
flash_target = np.array(
   [47,55,61,62,62,62,62,62,54,47,47,47,47,47,
    47,47,54,61,62,62,62,62,62,55,47,47,47,47,
    47,47,47,54,61,62,62,62,62,62,57,47,47,47,
    47,47,47,47,53,61,62,62,62,62,62,58,47,47,
    47,47,47,47,47,52,61,62,62,62,62,62,59,47])

### ------------------------------------------------- ###
# OPEN CV - engine room of the processing

def onclickStart(event):
    """Matplotlib 'button_press_event' handler that records the start frame.

    Stores the x data coordinate of the click in the module-level global
    ``x_start``. Clicks that land outside the plot axes carry no data
    coordinates (``event.xdata`` is None); they are reported and ignored so
    the callback does not fail while formatting None with ``%f``.

    event -- the matplotlib mouse event (button, x, y, xdata, ydata are read).
    """
    global x_start
    if event.xdata is None or event.ydata is None:
        print('click was outside the plot axes - start frame not changed')
        return
    print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (
        event.button, event.x, event.y, event.xdata, event.ydata))
    x_start = event.xdata
    print('you gave set the start time to frame: %s' % x_start)

def onclickEnd(event):
    """Matplotlib 'button_press_event' handler that records the end frame.

    Stores the x data coordinate of the click in the module-level global
    ``x_end``. Clicks that land outside the plot axes carry no data
    coordinates (``event.xdata`` is None); they are reported and ignored so
    the callback does not fail while formatting None with ``%f``.

    event -- the matplotlib mouse event (button, x, y, xdata, ydata are read).
    """
    global x_end
    if event.xdata is None or event.ydata is None:
        print('click was outside the plot axes - end frame not changed')
        return
    print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (
        event.button, event.x, event.y, event.xdata, event.ydata))
    x_end = event.xdata
    print('you gave set the end time to frame: %s' % x_end)

def runPupilAnalysis(coordinates, inFileName, matData=None):
    """Track pupil area and position frame by frame, then offer to save.

    For every frame the pupil ROI and the background ROI are cut out of the
    grayscale image, a marker-based watershed plus flood fills produce a
    binary pupil image, and the pupil area (white pixel count), the mean
    background intensity and the pupil centre of mass are recorded. After
    the last frame summary plots are shown, the user clicks the start and
    end of the flash sequence (see onclickStart / onclickEnd) and is asked
    whether to append the results to text files in /tmp and to write the
    clipped ROI video to a .mat file in /var/tmp.

    coordinates -- (x, y, w, h, bg_x, bg_y, bg_w, bg_h): pupil ROI followed
                   by background ROI, each as top-left corner and size.
    inFileName  -- path of the input file; opened as a video when matData
                   is None and used to name the saved output.
    matData     -- optional array indexed [row, col, frame]; when given,
                   frames are taken from it instead of the video capture.

    Reads the module globals gsmth, showFramesYesNo, x_start and x_end.
    At most 25000 frames are processed.
    """
    x, y, w, h, bg_x, bg_y, bg_w, bg_h = coordinates
    x_min, y_min, x_size, y_size = x, y, w, h

    # upper bound on processed frames; also the depth of the clipped buffer
    max_frames = 25000

    for window_name in ("VideoFrames", "BG_ROI", "3_channel_main",
                        "XMarkers", "MarkerPos", "watershed transform",
                        "watershed fast", "pupil_centre",
                        'Final pupil and CR', "PUPIL_ROI"):
        cv.NamedWindow(window_name, cv.CV_WINDOW_AUTOSIZE)

    capture = cv.CreateFileCapture(inFileName)

    cnt = 0  # number of frames processed so far

    # `is None` rather than `== None`: comparing a numpy array with `==`
    # is an elementwise operation and has no single truth value.
    if matData is None:
        # get the first frame
        frame = cv.QueryFrame(capture)
        if frame is None:
            print('Problem! - could not read any frames from %s' % inFileName)
            return
        gray_img = cv.CreateImage((frame.width, frame.height), 8, 1)
    else:
        matTotalFrames = matData.shape[2]

        # get the first frame data
        matFrame = matData[:, :, cnt].copy()  # NB need the .copy() for correct handling by cv
        cvdat = cv.fromarray(matFrame)

        gray_img = cv.CreateImage((matFrame.shape[1], matFrame.shape[0]), 8, 1)
        frame = cv.CreateImage((matFrame.shape[1], matFrame.shape[0]), 8, 3)

        # replicate the single channel so the loop can treat it like video
        cv.Merge(cvdat, cvdat, cvdat, None, frame)

    # images in memory to pass matrices to
    bg_roi_img = cv.CreateImage((bg_w, bg_h), 8, 1)
    pupil_roi_img = cv.CreateImage((w, h), 8, 1)
    smooth_img = cv.CreateImage((w, h), 8, 1)

    # images to hold the final output
    tmp_pupil = cv.CreateImage((w, h), 8, 1)
    cv.Zero(tmp_pupil)
    tmp_CR = cv.CreateImage((w, h), 8, 1)
    cv.Zero(tmp_CR)
    final_pupil_CR = cv.CreateImage((w, h), 8, 1)
    cv.Zero(final_pupil_CR)

    marker_positions = cv.CreateImage((w, h), 8, 1)
    pupil_positions = cv.CreateImage((w, h), 8, 1)

    markers = cv.CreateImage((w, h), cv.IPL_DEPTH_32S, 1)

    main_img_3channel = cv.CreateImage((w, h), 8, 3)

    single_fast_wshed = cv.CreateImage((w, h), 8, 1)

    # lets keep hold of some values so that we can save our data
    pupil_vox_count = []
    img_bg_mean = []
    img_bg_mean_float = []  # for plotting
    pupil_pos_X = []
    pupil_pos_Y = []
    pupil_pos_accum = np.zeros((h, w), 'f')

    # lets keep the clipped video data in a matrix
    # (allocated once here, filled per frame, trimmed after the loop)
    clipped_vid_data = np.zeros((h, w, max_frames), 'uint8')

    # weights for the centre-of-mass calculation; they only depend on the
    # ROI size, so build them once. NB the names are historical: x_accum
    # varies along rows and y_accum along columns, both starting at 1.
    x_pts = np.arange(1, pupil_roi_img.width + 1, 1)
    y_pts = np.arange(1, pupil_roi_img.height + 1, 1)
    y_accum, x_accum = np.meshgrid(x_pts, y_pts)

    # offsets of the small cross drawn on the pupil centre
    cross_offsets = ((0, 0), (-1, 0), (1, 0), (0, -1), (0, 1),
                     (-2, 0), (2, 0), (0, -2), (0, 2))

    def contour_iterator(contour):
        # walk the linked list of contours returned by cv.FindContours
        while contour:
            yield contour
            contour = contour.h_next()

    while frame is not None and cnt < max_frames:
        cnt += 1
        print(cnt)

        print(cv.GetCaptureProperty(capture, cv.CV_CAP_PROP_FPS))
        print(cv.GetCaptureProperty(capture, cv.CV_CAP_PROP_POS_FRAMES))

        # cast capture to 1 channel grayscale
        cv.CvtColor(frame, gray_img, cv.CV_BGR2GRAY)

        # get some background info restrict to roi
        my_bg = cv.GetSubRect(gray_img, (bg_x, bg_y, bg_w, bg_h))
        my_bg_avg = cv.Avg(my_bg)[0]

        # lets restrict our pupil data to a smaller FOV
        my_roi = cv.GetSubRect(gray_img, (x_min, y_min, x_size, y_size))

        # copy data to arrays we can work with
        cv.Copy(my_bg, bg_roi_img)
        cv.Copy(my_roi, pupil_roi_img)

        curr_vid_clipped = np.fromstring(pupil_roi_img.tostring(), 'uint8')
        # not x*y but rows*cols when using images
        curr_vid_clipped = curr_vid_clipped.reshape(pupil_roi_img.height, pupil_roi_img.width)
        # cnt was already incremented, so this frame's 0-based index is
        # cnt-1. That keeps the video stack aligned with the per-frame
        # lists (and with the frame numbers clicked in the plots) and
        # keeps the last allowed frame inside the buffer.
        clipped_vid_data[:, :, cnt - 1] = curr_vid_clipped

        ### now calculate the probable position of the pupil
        # try filtering to remove the corneal reflection
        # NOTE(review): the gsmth-sized result is overwritten straight away
        # by the 25x25 smoothing, so gsmth does not influence smooth_img.
        cv.Smooth(pupil_roi_img, smooth_img, cv.CV_GAUSSIAN, gsmth, gsmth)
        cv.Smooth(pupil_roi_img, smooth_img, cv.CV_GAUSSIAN, 25, 25)

        roi_py = np.fromstring(smooth_img.tostring(), 'uint8')
        roi_py_max = np.fromstring(pupil_roi_img.tostring(), 'uint8')
        roi_py = roi_py.reshape(pupil_roi_img.height, pupil_roi_img.width)
        roi_py_max = roi_py_max.reshape(pupil_roi_img.height, pupil_roi_img.width)

        # look through each column/row and find those that have max/min variance
        # logic: a col/row with most variance has bg+pupil data
        x_Imax = np.argmax(np.max(roi_py_max, 0) - np.min(roi_py_max, 0))
        y_Imax = np.argmax(np.max(roi_py_max, 1) - np.min(roi_py_max, 1))
        # NOTE(review): min - min is all zeros, so these two are always 0
        # and the "min" marker always ends up near the top-left corner.
        x_Imin = np.argmin(np.min(roi_py, 0) - np.min(roi_py, 0))
        y_Imin = np.argmin(np.min(roi_py, 1) - np.min(roi_py, 1))
        x_Imin_pup = np.argmin(np.min(roi_py, 0))  # find pupil?
        y_Imin_pup = np.argmin(np.min(roi_py, 1))  # find pupil?

        pos_min = (x_Imin, y_Imin)
        pos_max = (x_Imax, y_Imax)

        # now to do watershedding, we need a 3 channel image
        # TODO: decide whether to use a smoothed input or raw
        cv.Merge(smooth_img, smooth_img, smooth_img, None, main_img_3channel)

        # clear the image that will contain our markers
        cv.Zero(marker_positions)

        # mark our max variance position in our image
        cv.Set2D(marker_positions, pos_max[1], pos_max[0], (255, 255, 255, 0))

        # mark our min variance in our image (will probably be the top left corner)
        cv.Set2D(marker_positions, pos_min[1] + 2, pos_min[0] + 2, (128, 128, 128, 0))

        storage = cv.CreateMemStorage(0)
        contours = cv.FindContours(marker_positions, storage, cv.CV_RETR_CCOMP, cv.CV_CHAIN_APPROX_SIMPLE)

        # paint every contour into the marker image with its own label
        cv.Zero(markers)
        comp_count = 0
        for c in contour_iterator(contours):
            cv.DrawContours(markers,
                            c,
                            cv.ScalarAll(comp_count + 1),
                            cv.ScalarAll(comp_count + 1),
                            -1,
                            -1,
                            8)
            comp_count += 1

        print('contours found: %s' % comp_count)
        cv.Watershed(main_img_3channel, markers)

        wshed = cv.CloneImage(main_img_3channel)
        img_gray = cv.CloneImage(main_img_3channel)
        cv.CvtColor(marker_positions, img_gray, cv.CV_GRAY2BGR)

        # paint the watershed image in a single matrix calculation
        # (a per-pixel python loop was far too slow)
        cv.Convert(markers, single_fast_wshed)
        cv.ConvertScale(single_fast_wshed, single_fast_wshed, 1.0, 1.0)
        cv.ConvertScale(markers, markers, 1.0, 1.0)
        cv.Normalize(single_fast_wshed, single_fast_wshed, 0, 255, cv.CV_MINMAX)
        cv.FloodFill(single_fast_wshed, (2, 2), (0, 0, 0), 0, 1)
        cv.FloodFill(single_fast_wshed, pos_max, (255, 0, 255), 0, 1)

        ##################
        # make a 'final' image to hold the pupil and corneal reflection only
        # gather up the corneal reflection
        cv.Copy(pupil_roi_img, tmp_CR)
        cv.FloodFill(tmp_CR, pos_max, (255, 255, 255), (35, 35, 35), (35, 35, 35), cv.CV_FLOODFILL_FIXED_RANGE)
        cv.Threshold(tmp_CR, tmp_CR, 254, 255, cv.CV_THRESH_BINARY)

        # gather up the pupil
        cv.Copy(pupil_roi_img, tmp_pupil)
        cv.Smooth(tmp_pupil, tmp_pupil, cv.CV_GAUSSIAN, 15, 15)
        cv.FloodFill(tmp_pupil, (x_Imin_pup, y_Imin_pup), (255, 255, 255), (150, 150, 150), (10, 10, 10), cv.CV_FLOODFILL_FIXED_RANGE)
        cv.Threshold(tmp_pupil, tmp_pupil, 254, 128, cv.CV_THRESH_BINARY)

        # now combine tmp pupil and corneal reflection
        cv.Add(tmp_pupil, tmp_CR, final_pupil_CR)

        ##################
        cv.Smooth(single_fast_wshed, single_fast_wshed, cv.CV_GAUSSIAN, 9, 9)
        cv.Threshold(single_fast_wshed, single_fast_wshed, 25, 255, cv.CV_THRESH_BINARY)

        cv.ShowImage("watershed fast", single_fast_wshed)

        cv.AddWeighted(wshed, 0.5, img_gray, 0.5, 0, wshed)
        cv.ShowImage("watershed transform", wshed)

        cv.ShowImage("VideoFrames", gray_img)
        cv.ShowImage("BG_ROI", bg_roi_img)
        cv.ShowImage("PUPIL_ROI", pupil_roi_img)

        cv.ShowImage("MarkerPos", marker_positions)
        cv.ShowImage("XMarkers", markers)
        cv.ShowImage("3_channel_main", main_img_3channel)

        cv.ShowImage("Final pupil and CR", final_pupil_CR)

        # count number of white voxels
        voxCount = cv.Sum(single_fast_wshed)[0] / 255.0  # effectively pupil area

        print('pixels %s' % voxCount)
        pupil_vox_count.append(int(voxCount))
        img_bg_mean.append(int(my_bg_avg))
        img_bg_mean_float.append(my_bg_avg)

        ## PUPIL POSITION CODE ------------
        # centre-of-mass calculation on the binarised pupil blob image
        pos_roi = np.fromstring(tmp_pupil.tostring(), 'uint8')
        # tmp_pupil holds 0 or 128 after the threshold, so this is a 0/1 mask
        pos_roi = pos_roi.reshape(tmp_pupil.height, tmp_pupil.width) / 128.0

        mass = pos_roi.sum()
        if mass > 0:
            x_pos = int((pos_roi * x_accum).sum() / mass)
            y_pos = int((pos_roi * y_accum).sum() / mass)
        else:
            # no pupil pixels found (the division would give nan): use 3,3
            x_pos = 3
            y_pos = 3

        pupil_pos_X.append(x_pos)
        pupil_pos_Y.append(y_pos)
        # the centre is 1-based and can equal the ROI size, so clamp the
        # index used for the heat map to stay inside the accumulator
        pupil_pos_accum[min(x_pos, h - 1), min(y_pos, w - 1)] += 1
        ## PUPIL POSITION CODE ------------

        if showFramesYesNo == 1:
            # show a figure to check our result
            cv.Zero(pupil_positions)

            try:
                # draw a small cross on the centre of mass
                for d_row, d_col in cross_offsets:
                    cv.Set2D(pupil_roi_img, x_pos + d_row, y_pos + d_col, (128, 128, 128, 0))
                cv.ShowImage("pupil_centre", pupil_roi_img)
                cv.WaitKey(1)
            except Exception:
                print('Centre of mass at edge of image - probably a blink')

        # fetch the next frame
        if matData is None:
            frame = cv.QueryFrame(capture)
            # with OGG, OPENCV pick up two fields for each frame; just
            # ignore every second frame. #TODO FIX??
            if inFileName[-3:] == 'ogg':
                frame = cv.QueryFrame(capture)
        else:
            if cnt < matTotalFrames:  # check we haven't run out of frames
                matFrame = matData[:, :, cnt].copy()  # NB need the .copy() for correct handling by cv
                cvdat = cv.fromarray(matFrame)

                gray_img = cv.CreateImage((matFrame.shape[1], matFrame.shape[0]), 8, 1)
                frame = cv.CreateImage((matFrame.shape[1], matFrame.shape[0]), 8, 3)
                cv.Merge(cvdat, cvdat, cvdat, None, frame)
            else:  # we're out of data
                frame = None

    area_data_final = np.array(pupil_vox_count, 'i')
    plt.plot(area_data_final)
    plt.title('pupila area in voxels')
    plt.show()

    # once done .. do some plots
    fig = plt.figure()
    plt.plot(pupil_pos_X[:550], pupil_pos_Y[:550], 'rx')

    fig1 = plt.figure()
    plt.plot(pupil_pos_X[550:], pupil_pos_Y[550:], 'bx')

    fig2 = plt.figure()
    plt.imshow(pupil_pos_accum)
    plt.show()

    # let the user click the start of the flash sequence ...
    fig = plt.figure()
    plt.plot(range(len(pupil_vox_count)), pupil_vox_count)
    plt.plot(range(len(pupil_vox_count)), np.array(img_bg_mean_float) * 20.0)
    cid = fig.canvas.mpl_connect('button_press_event', onclickStart)
    plt.title('click the start of the first flash sequence')
    plt.show()

    # ... and the end of it
    fig = plt.figure()
    plt.plot(range(len(pupil_vox_count)), pupil_vox_count)
    plt.plot(range(len(pupil_vox_count)), np.array(img_bg_mean_float) * 20.0)
    cid = fig.canvas.mpl_connect('button_press_event', onclickEnd)
    plt.title('click the start of the end flash sequence')
    plt.show()

    # x_start / x_end only exist once the click handlers have run; fall
    # back to the whole recording instead of failing after all the work.
    last_frame = max(len(pupil_vox_count) - 1, 0)
    start_sel = globals().get('x_start')
    end_sel = globals().get('x_end')
    if start_sel is None:
        print('no start frame was clicked - using the first frame')
        start_sel = 0
    if end_sel is None:
        print('no end frame was clicked - using the last frame')
        end_sel = last_frame
    # the clicks are float data coordinates: convert to valid frame indices
    start_frame = min(max(int(start_sel), 0), last_frame)
    end_frame = min(max(int(end_sel), 0), last_frame)

    clipped_vid_data = clipped_vid_data[:, :, start_frame:end_frame + 1]

    dlg = wx.MessageDialog(None,
        "Save the data?",
        "Save the data? Yes to save, No to quit without saving ... ", wx.YES|wx.NO|wx.ICON_QUESTION)
    result = dlg.ShowModal()
    dlg.Destroy()
    if result == wx.ID_YES:
        # append one line per run to each of the text files
        frame_prefix = str(start_frame) + ';' + str(end_frame) + ';'
        outputs = (
            ('/tmp/pupil_meta.txt',
             str(inFileName) + ',start_frame=%s=end_frame=%s\n' % (start_frame, end_frame)),
            ('/tmp/pupil_area.txt', str(pupil_vox_count[:])[1:-1] + '\n'),
            ('/tmp/pupil_bgMean.txt', str(img_bg_mean[:])[1:-1] + '\n'),
            ('/tmp/pupil_posX.txt', frame_prefix + str(pupil_pos_X[:])[1:-1] + '\n'),
            ('/tmp/pupil_posY.txt', frame_prefix + str(pupil_pos_Y[:])[1:-1] + '\n'),
        )
        for out_path, out_line in outputs:
            with open(out_path, 'a') as out_file:
                out_file.write(out_line)

        # save the clipped video matrix
        mat_out = {'data': clipped_vid_data, 'coordinates': coordinates}
        mat_path = '/var/tmp/%s_gs%s_clipped.mat' % (
            os.path.basename(str(inFileName))[:-4], str(gsmth))
        scipy.io.savemat(mat_path, mat_out, do_compression=True)

        print('Data saved ...  processed %s \n\n\n' % inFileName)

    else:
        print('quitting without saving any data ... was processing %s \n\n\n' % inFileName)




# END OPEN CV stuff
### ------------------------------------------------- ###


### ------------------------------------------------- ###
# GUI - allows us to pick points and input

class CvDisplayPanel(wx.Panel):
    """Panel that shows one preview frame and collects the two ROIs.

    Mouse bindings on the displayed bitmap:
      left click   -- record a corner point (LeftClick); the clicks are
                      used as background corner 1, background corner 2,
                      pupil corner 1, pupil corner 2
      double click -- show a later frame and start over (NextFrame)
      right click  -- turn the last four clicks into finalPoints and close
                      the GUI (ReturnCoords)

    Results are handed on through the module globals clickPoints,
    finalPoints, inFileName, matData and showFramesYesNo.
    """

    _INITIAL_FRAME = 4000   # frames skipped before the first preview
    _FRAME_STEP = 120       # frames skipped on every double click

    def __init__(self, parent):
        """Pick the input file, load a preview frame and build the widgets.

        The file name comes from sys.argv[1] when given, otherwise from a
        file dialog. Exits the program when nothing is selected, the file
        type is not recognised or no frame can be read.
        """
        global clickPoints, finalPoints, inFileName, matData, showFramesYesNo
        wx.Panel.__init__(self, parent, -1)

        # load a file
        if len(sys.argv) > 1:
            inFileName = sys.argv[1]
            print('Video file provided at command line: %s' % inFileName)
        else:
            filters = 'avi files (*.avi)|*.avi|matlab files (*.mat)|*.mat|ogg files (*.ogg)|*.ogg|mpg files (*.mpg)|*.mpg'
            dialog = wx.FileDialog(None, message='Open a video or matlab file ....', wildcard=filters, style=wx.OPEN)
            selected = dialog.ShowModal() == wx.ID_OK
            if selected:
                inFileName = dialog.GetPath()
            dialog.Destroy()
            if not selected:
                print('Nothing was selected. Quitting ...')
                sys.exit()
            print('Video file selected in GUI: %s' % inFileName)

        # determine the data type we are loading: avi, mpg, ogg or mat
        ext = inFileName[-3:]
        if ext in ('avi', 'mpg', 'ogg'):
            self.dataType = ext
            self.capture = cv.CreateFileCapture(inFileName)
            self.vidframe_orig = None
            # short videos simply show their last readable frame
            if not self._advance_video(self._INITIAL_FRAME):
                print('Problem! - could not read any frames from %s ... Quitting ...' % inFileName)
                sys.exit()

        elif ext == 'mat':
            self.dataType = 'mat'
            print('loading data ...')
            mat = scipy.io.loadmat(inFileName)
            self.matData = np.array(mat['data'], 'uint8')
            matData = self.matData  # for global parsing
            print('loading done ...')
            # stay inside the stack when it has fewer frames than the offset
            self.matPos = min(self._INITIAL_FRAME, self.matData.shape[2] - 1)
            self.vidframe_orig = cv.CreateImage((self.matData.shape[1], self.matData.shape[0]), 8, 3)
            self._load_mat_frame()

        else:
            print('Problem! - file type not recognised (not .mat or .avi) ... Quitting ...')
            sys.exit()

        # working copy that the selection rectangles are drawn on
        self.vidframe_copy = cv.CreateImage((self.vidframe_orig.width, self.vidframe_orig.height), 8, 3)
        cv.Copy(self.vidframe_orig, self.vidframe_copy)

        # NOTE(review): this opens a second, empty frame owned by the
        # panel; kept because its purpose is unclear - confirm before removing.
        frame = wx.Frame(self, -1, size=(self.vidframe_copy.width, self.vidframe_copy.height))
        frame.Show(True)

        # Convert the raw image data to something wxpython can handle.
        self.bmp = wx.BitmapFromBuffer(self.vidframe_copy.width, self.vidframe_copy.height,
                                       self.vidframe_copy.tostring())
        # Display the resulting image
        self.sbmp = wx.StaticBitmap(self, -1, bitmap=self.bmp)

        self.sbmp.Bind(wx.EVT_LEFT_DOWN, self.LeftClick)
        self.sbmp.Bind(wx.EVT_RIGHT_DOWN, self.ReturnCoords)
        self.sbmp.Bind(wx.EVT_LEFT_DCLICK, self.NextFrame)

        clickPoints = []

    def _advance_video(self, n_frames):
        """Read up to n_frames frames, keeping the last one as RGB.

        The frame is converted into self.vidframe_orig (created on first
        use). Stops early when the capture runs out of frames and returns
        the number of frames actually read.
        """
        n_read = 0
        for i in range(n_frames):
            grabbed = cv.QueryFrame(self.capture)
            if grabbed is None:
                break
            if self.vidframe_orig is None:
                self.vidframe_orig = cv.CreateImage((grabbed.width, grabbed.height), 8, 3)
            cv.CvtColor(grabbed, self.vidframe_orig, cv.CV_BGR2RGB)
            n_read += 1
        return n_read

    def _load_mat_frame(self):
        """Copy frame self.matPos of self.matData into self.vidframe_orig."""
        matFrame = self.matData[:, :, self.matPos].copy()  # NB need the .copy() for correct handling by cv
        cvdat = cv.fromarray(matFrame)
        cv.Merge(cvdat, cvdat, cvdat, None, self.vidframe_orig)

    def _show_image(self, image):
        """Display a 3-channel image in the existing static bitmap."""
        self.bmp = wx.BitmapFromBuffer(image.width, image.height, image.tostring())
        # reuse the widget so the mouse bindings made in __init__ stay active
        self.sbmp.SetBitmap(self.bmp)

    @staticmethod
    def _rect_from_corners(pt1, pt2):
        """Return (x, y, w, h) for two opposite corners in any order."""
        x, y = min(pt1[0], pt2[0]), min(pt1[1], pt2[1])
        w, h = abs(pt2[0] - pt1[0]), abs(pt2[1] - pt1[1])
        return x, y, w, h

    def LeftClick(obj, evt):
        """Record the clicked point; after every 4th click draw both ROIs.

        obj is the panel instance, evt the wx mouse event.
        """
        global clickPoints
        print('click')
        print(evt.GetPosition())
        clickPoints.append(evt.GetPosition())

        # start from a clean frame so old rectangles disappear
        cv.Copy(obj.vidframe_orig, obj.vidframe_copy)

        if len(clickPoints) % 4 == 0:
            bg_pt1, bg_pt2, pup_pt1, pup_pt2 = clickPoints[-4:]
            cv.Rectangle(obj.vidframe_copy, (int(pup_pt1[0]), int(pup_pt1[1])), (int(pup_pt2[0]), int(pup_pt2[1])), (255, 255, 0))
            cv.Rectangle(obj.vidframe_copy, (int(bg_pt1[0]), int(bg_pt1[1])), (int(bg_pt2[0]), int(bg_pt2[1])), (148, 148, 148))

        obj._show_image(obj.vidframe_copy)

    def ReturnCoords(obj, evt):
        """Build finalPoints from the last four clicks and close the GUI.

        Sets the global finalPoints to (x, y, w, h, bg_x, bg_y, bg_w, bg_h)
        and, depending on the dialog answer, showFramesYesNo to 1 or 0
        before destroying the parent window. Cancel keeps the GUI open.
        Does nothing but print a hint when fewer than four points exist or
        a region has no area.
        """
        global clickPoints, finalPoints, showFramesYesNo
        print('got coords command .. closing GUI and moving on to analysis')
        if len(clickPoints) < 4:
            print('need four clicks first: 2 background corners, then 2 pupil corners')
            return

        bg_pt1, bg_pt2, pt1, pt2 = clickPoints[-4:]
        # accept the corners in any order so the sizes are never negative
        x, y, w, h = obj._rect_from_corners(pt1, pt2)
        bg_x, bg_y, bg_w, bg_h = obj._rect_from_corners(bg_pt1, bg_pt2)
        if 0 in (w, h, bg_w, bg_h):
            print('a selected region has zero width or height - please click again')
            return

        finalPoints = x, y, w, h, bg_x, bg_y, bg_w, bg_h

        dlg = wx.MessageDialog(obj,
            "Would you like to see the frames?",
            "View video frames?", wx.YES|wx.NO|wx.CANCEL|wx.ICON_QUESTION)
        result = dlg.ShowModal()
        dlg.Destroy()
        if result == wx.ID_YES:
            showFramesYesNo = 1
            obj.GetParent().Destroy()
        elif result == wx.ID_NO:
            showFramesYesNo = 0
            obj.GetParent().Destroy()
        # anything else (cancel): keep the GUI open

    def NextFrame(obj, evt):
        """Show a frame further into the recording and reset the clicks.

        obj is the panel instance, evt the wx mouse event.
        """
        global clickPoints
        print('trying a different frame')
        # get a different frame from the file
        if obj.dataType in ('avi', 'mpg', 'ogg'):
            if not obj._advance_video(obj._FRAME_STEP):
                print('no more frames available in the video')
        elif obj.dataType == 'mat':
            obj.matPos = min(obj.matPos + obj._FRAME_STEP, obj.matData.shape[2] - 1)
            obj._load_mat_frame()

        cv.Copy(obj.vidframe_orig, obj.vidframe_copy)
        obj._show_image(obj.vidframe_orig)
        # points picked on the previous frame no longer apply
        clickPoints = []
        

# END GUI code
### ------------------------------------------------- ###


if __name__=="__main__":
    # run the selection GUI first; MainLoop returns once its window is gone
    app = wx.App()
    app.RestoreStdio()
    frame = wx.Frame(None, -1, size=(800,520))
    CvDisplayPanel(frame)
    frame.Show(True)
    app.MainLoop()

    # finalPoints is only created by CvDisplayPanel.ReturnCoords (right
    # click); closing the window without selecting regions leaves it unset.
    selection = globals().get('finalPoints')
    if selection is None:
        print('No regions were selected. Quitting ...')
        sys.exit()

    print('using this point data:  %s' % (selection,))
    runPupilAnalysis(selection, inFileName, matData)
    











