#import some modules for some summary plots at the end
import matplotlib.pyplot as plt

#import some modules that allow us to read .mat files
import scipy.io

# some modules for a GUI
import wxversion
wxversion.select('2.8')
import wx

#import the core modules we are going to use
import os, sys
import cv2.cv as cv
import numpy as np
import time

# Paths of the x / y position files. CvDisplayPanel fills these in (from the
# command line or from file dialogs) before the data is loaded. They are
# initialised here because a module-level `global` statement is a no-op and
# would leave the names undefined until the panel runs.
xFileName = None
yFileName = None



### ------------------------------------------------- ###
# GUI - allows us to pick the x/y position files if they were not given on the command line

class CvDisplayPanel(wx.Panel):
    #global clickPoints, inFileName
    
    def __init__(self, parent):
        global xFileName, yFileName
        wx.Panel.__init__(self, parent, -1)

        #load a file
        if len(sys.argv) > 2:
            xFileName = sys.argv[1]
            yFileName = sys.argv[2]
            print 'X and Y pos files provided at command line ...'
        else:
            filters = 'position files (*.*)|*.*'
            dialog = wx.FileDialog ( None, message = 'Open the x_pos file ....', wildcard = filters, style = wx.OPEN)
            if dialog.ShowModal() == wx.ID_OK:
                xFileName = dialog.GetPath()
                print 'X File selected in GUI:' , xFileName                
            else:
                print 'Nothing was selected. Quitting ...'
                sys.exit()
            dialog.Destroy()

            filters = 'position files (*.*)|*.*'
            dialog = wx.FileDialog ( None, message = 'Open the y_pos file ....', wildcard = filters, style = wx.OPEN)
            if dialog.ShowModal() == wx.ID_OK:
                yFileName = dialog.GetPath()
                print 'Y File selected in GUI:' , yFileName                
            else:
                print 'Nothing was selected. Quitting ...'
                sys.exit()
            dialog.Destroy()

        self.GetParent().Close()

# END GUI code
### ------------------------------------------------- ###


### ------------------------------------------------- ###
# LOAD DATA

# Module-level placeholder lists for the parsed x and y position data
# (start indices, end indices and sample arrays for each trace).
all_x_starts, all_x_ends, all_x_dat = [], [], []
all_y_starts, all_y_ends, all_y_dat = [], [], []

def loadPosData(xFileName, yFileName):
    """Parse a pair of x / y position files into start, end and sample lists.

    Both files are read line by line and blank lines are ignored. Each
    remaining line must have at least three ';'-separated fields::

        start;end;v0,v1,v2,...

    The n-th line of the x file is paired with the n-th line of the y file.
    ``start`` and ``end`` are parsed as ints and must agree between the two
    files. The comma-separated samples become a ``uint8`` numpy array.
    Any fields after the third are ignored.

    Parameters
    ----------
    xFileName, yFileName : str
        Paths of the x and y position files.

    Returns
    -------
    tuple
        ``(x_starts, x_ends, x_dat, y_starts, y_ends, y_dat)``: six lists
        with one entry per line. New lists are built on every call, so
        repeated calls do not accumulate data.

    Exits the program (``SystemExit``) in two cases:
    - the files have a different number of non-blank lines;
    - a pair of lines has different start/end values.
    """
    with open(xFileName) as f:
        raw_x = [line for line in f if line.strip()]
    with open(yFileName) as f:
        raw_y = [line for line in f if line.strip()]

    if len(raw_x) != len(raw_y):
        print('PROBLEM! - x and y not same length! Quitting...\n\n\n')
        sys.exit()

    x_starts, x_ends, x_dat = [], [], []
    y_starts, y_ends, y_dat = [], [], []

    for i, (line_x, line_y) in enumerate(zip(raw_x, raw_y)):
        x_start, x_end, x_values = line_x.strip().split(';')[:3]
        y_start, y_end, y_values = line_y.strip().split(';')[:3]
        x_start, x_end = int(x_start), int(x_end)
        y_start, y_end = int(y_start), int(y_end)

        # compare numerically so e.g. '05' and '5' count as the same index
        if x_start != y_start or x_end != y_end:
            print('PROBLEM! - x start/end not same value in loop %s! Quitting...\n\n\n' % i)
            sys.exit()

        x_starts.append(x_start)
        x_ends.append(x_end)
        x_dat.append(np.array(x_values.split(','), 'uint8'))

        y_starts.append(y_start)
        y_ends.append(y_end)
        y_dat.append(np.array(y_values.split(','), 'uint8'))

    return x_starts, x_ends, x_dat, y_starts, y_ends, y_dat

# END LOAD DATA
### ------------------------------------------------- ###



### ------------------------------------------------- ###
# ANALYSIS - once the data is loaded, let's process it


def runCalibrationPeriod(all_x_starts, all_x_ends, all_x_dat,
                         all_y_starts, all_y_ends, all_y_dat,
                         calib_offset=225, calib_len=420,
                         exp_len=7680, grid_size=150):
    """Plot the calibration (dot-tracking) period and the experiment period.

    Each ``all_*`` argument is a list with one entry per trace, as returned by
    ``loadPosData``. The ``*_starts`` / ``*_ends`` values are indices into the
    corresponding sample arrays in ``all_*_dat``. For every trace this shows:

    1. a scatter of the calibration samples (x against y);
    2. a heat map counting each (x, y) coordinate in those samples, titled
       'all calibration points';
    3. x (red) and y (green) against sample index for the experiment period;
    4. a heat map of the experiment-period coordinates.

    Calibration window: ``calib_len`` samples beginning ``calib_offset``
    samples after the trace start. The defaults match the 2010 (Python)
    sequence: 21.5 s (645 frames) including the first flash sequence, during
    which the participant tracks the dot. In order:

    - flashes 2.5 s (75 frames)
    - instructions 5 s (150 frames)
    - centre 2 s (60)
    - top left 2 s (60)
    - bottom left 2 s (60)
    - bottom right 2 s (60)
    - top right 4 s (120)
    - centre 2 s (60)

    This gives ``calib_offset = 75 + 150 = 225`` and ``calib_len = 420``.

    The 2011 (Visage) sequence is 16.25 s (488 frames):

    - flashes 2.25 s (67 frames)
    - instructions 2 s (60 frames)
    - then six 2 s (60-frame) dot positions, in the same order as above

    For that data pass ``calib_offset=127, calib_len=360``.

    Experiment window: the ``exp_len`` samples ending at the trace's end
    index, clamped at index 0 for short traces.

    ``grid_size`` is the minimum side of the heat-map grids. A grid is
    enlarged automatically when a coordinate would not fit.

    Returns None.
    """
    traces = zip(all_x_starts, all_x_ends, all_x_dat,
                 all_y_starts, all_y_ends, all_y_dat)
    for x_start, x_end, x_dat, y_start, y_end, y_dat in traces:
        # calibration window: skip flashes + instructions, keep the dot tracking
        calib_x = x_dat[x_start + calib_offset:x_start + calib_offset + calib_len]
        calib_y = y_dat[y_start + calib_offset:y_start + calib_offset + calib_len]

        plt.figure()
        plt.plot(calib_x, calib_y, 'bx')
        plt.show()

        # NOTE(review): counts are indexed [x, y] and imshow draws rows on the
        # vertical axis, so x runs vertically here, unlike the scatter plot
        # above. Confirm whether that orientation is intended.
        _plotCoordinateCounts(_coordinateCounts(calib_x, calib_y, grid_size),
                              'all calibration points')

        # experiment window; clamp so a short trace cannot yield a negative
        # (wrap-around) slice start
        exp_x = x_dat[max(x_end - exp_len, 0):x_end]
        exp_y = y_dat[max(y_end - exp_len, 0):y_end]

        plt.figure()
        plt.plot(exp_x, 'r')
        plt.plot(exp_y, 'g')
        plt.show()

        _plotCoordinateCounts(_coordinateCounts(exp_x, exp_y, grid_size))
        plt.show()


def _coordinateCounts(xs, ys, grid_size=150):
    """Return a C-int ('i') grid where ``counts[x, y]`` is how often (x, y) occurs.

    The grid is square with side ``grid_size``. It is enlarged when a
    coordinate would not fit, instead of raising IndexError. Samples beyond
    the shorter of ``xs`` / ``ys`` are ignored.
    """
    n = min(len(xs), len(ys))
    xs = np.asarray(xs[:n], dtype=int)
    ys = np.asarray(ys[:n], dtype=int)
    size = grid_size
    if n:
        size = max(size, int(xs.max()) + 1, int(ys.max()) + 1)
    counts = np.zeros((size, size), 'i')
    for x, y in zip(xs, ys):
        counts[x, y] += 1
    return counts


def _plotCoordinateCounts(counts, title=None):
    """Open a new figure showing ``counts`` as an image with a colour bar (not shown yet)."""
    plt.figure()
    plt.imshow(counts)
    if title is not None:
        plt.title(title)
    plt.colorbar()
            
        
# END ANALYSIS code
### ------------------------------------------------- ###




if __name__ == "__main__":
    # start wx and restore the standard stdout/stderr streams
    app = wx.App()
    app.RestoreStdio()

    # The panel fills in the globals xFileName / yFileName (command line or
    # file dialogs) and then closes this frame.
    pickerFrame = wx.Frame(None, -1, title='close this to continue', size=(350, 240))
    CvDisplayPanel(pickerFrame)
    pickerFrame.Show(True)
    app.MainLoop()

    # parse both position files, then plot calibration and experiment data
    (all_x_starts, all_x_ends, all_x_dat,
     all_y_starts, all_y_ends, all_y_dat) = loadPosData(xFileName, yFileName)
    runCalibrationPeriod(all_x_starts, all_x_ends, all_x_dat,
                         all_y_starts, all_y_ends, all_y_dat)
    


    
