from numpy import *
import sys
import matplotlib.pyplot as plt

import mdp



normalise = 1  # convert data to % signal change for each individual?

# The dataset identifier (e.g. 'ActCen_C100') is the first command-line
# argument; it selects the input files and the trial exclusion list.
# Fail with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.stderr.write('usage: %s <dType>   (e.g. ActCen_C100, ActPer_A025)\n'
                     % sys.argv[0])
    sys.exit(2)
dType = sys.argv[1]
#dType = 'ActCen_C100'
#dType = 'ActPer_A100'
#dType = 'ActPer_A025'
#dType = 'ActPer_A012'
#dType = 'ActPer_A006'

#dType = 'ActCen_C100'
#dType = 'ActCen_A100'
#dType = 'ActCen_A025'
#dType = 'ActCen_A012'
#dType = 'ActCen_A006'

def _read_result_lines(kind):
    """Return every line of the 2011 results file for `dType` and `kind`.

    `kind` is the file-name suffix: 'area', 'bgMean' or 'meta'.
    The file is opened in a `with` block so the handle is closed even if
    reading raises.
    """
    path = ('/groups/Projects/P1174/data/analysis/2011_results/'
            '2011_%s_pupil_%s.txt' % (dType, kind))
    with open(path) as handle:
        return handle.readlines()


# One line per trial in each of the three files.
pupil_area_dat = _read_result_lines('area')
pupil_bg_dat = _read_result_lines('bgMean')
pupil_meta = _read_result_lines('meta')

# Sanity check: the three files must describe the same number of trials.
if not (len(pupil_area_dat) == len(pupil_bg_dat) == len(pupil_meta)):
    print('metadata, area data and pupila data all need to be the same length! .. Quitting... \n\n')
    # Non-zero status so calling shells/scripts can see the failure.
    sys.exit(1)


def _nan_matrix(n_rows, n_cols=7680):
    """Return an (n_rows, n_cols) float32 array with every element NaN."""
    return nan * ones((n_rows, n_cols), 'f')


# One row per trial, one column per video frame. Rows that are never
# filled in stay NaN.
all_area_dat = _nan_matrix(len(pupil_area_dat))
all_area_dat_raw = _nan_matrix(len(pupil_area_dat))
all_indiv_fourier = _nan_matrix(len(pupil_area_dat))
all_bg_dat = _nan_matrix(len(pupil_area_dat))

#all_area_dat = zeros((40,7680),'f')
#all_area_dat_raw = zeros((40,7680),'f')
#all_indiv_fourier = zeros((40,7680),'f')
#all_bg_dat = zeros((40,7680),'f')

# this dataset has a big dc shift that causes an artifact

# Trial (row) indices to skip, keyed by dataset identifier.
_EXCLUDED_TRIALS = {
    'ActCen_A100': [45, 47, 48],
    'ActCen_A025': [31],
    'ActCen_A012': [],
    'ActCen_A006': [27],
    'ActCen_C100': [31],
    'ActPer_A100': [0, 6, 9, 10, 20, 21, 22, 33, 36, 37, 42, 43, 45, 46,
                    54, 57, 58],
    'ActPer_A025': [3, 6, 9, 10, 11, 24, 28, 29, 30, 31],
    'ActPer_A012': [],
    'ActPer_A006': [11, 19, 20, 25, 27, 29],
    'ActPer_C100': [0, 13],
}

# Unknown dataset identifiers exclude nothing.
excludes = _EXCLUDED_TRIALS.get(dType, [])


excl_Rnum = ['R2105']   # participant ids whose trials are always skipped
show_indiv = 0          # 1 -> plot each cleaned trace as it is processed
plot_indiv = 0          # 1 -> plot the amplitude spectrum of each trace

# How samples flagged as blinks (set to 0 below) are finally represented.
# Defined once, ahead of the loop, so code after the loop can read them
# even when every trial is skipped.
do_interpolate_blinks_period = 1
do_Nan_blinks_period = 0

trials_used_Rid = []    # [participant id, row index] for every trial kept
all_Rid = []            # participant id of every trial kept


def _zero_after_jumps(raw, cleaned, lag, threshold=50):
    """Zero the two samples of `cleaned` that follow each large jump in `raw`.

    A jump is a position k where abs(raw[k] - raw[k + lag]) > threshold.
    Only genuine (non-wrapping) pairs are compared, every detected jump is
    used, and target indices past the end of the trace are ignored.
    `cleaned` is modified in place.
    """
    n = len(raw)
    jump_idx = nonzero(abs(raw[:-lag] - raw[lag:]) > threshold)[0]
    for offset in (1, 2):
        targets = jump_idx + offset
        cleaned[targets[targets < n]] = 0


def _zero_around_spikes(raw, cleaned, limit=1000, halfwidth=2):
    """Zero `cleaned` within +/- `halfwidth` samples of each raw[j] > limit.

    The window start is clamped at 0; a negative slice start would wrap
    round and select nothing. `cleaned` is modified in place.
    """
    for j in nonzero(raw > limit)[0]:
        lo = j - halfwidth
        if lo < 0:
            lo = 0
        cleaned[lo:j + halfwidth + 1] = 0


def _interpolate_zero_runs(cleaned, sentinel=100):
    """Return a copy of `cleaned` with each run of zeros linearly bridged.

    Every run of zeros is replaced by a straight line between the samples
    immediately before and after it. The last sample of a working copy is
    forced to `sentinel` so that every run terminates; a run starting at
    index 0 takes its left anchor from index -1, i.e. that same sentinel.
    Runs touching either end therefore ramp from/to `sentinel`, and the
    final sample itself is never rewritten.
    """
    filled = cleaned.copy()
    work = cleaned.copy()
    work[-1] = sentinel
    pos = 0
    n = len(work)
    while pos < n:
        if work[pos] != 0:
            pos += 1
            continue
        start = pos
        # Cannot run off the end: the sentinel at work[-1] is non-zero.
        while work[pos + 1] == 0:
            pos += 1
        end = pos
        pos += 1
        # Two extra points for the anchors, which are then dropped.
        filled[start:end + 1] = linspace(work[start - 1], work[end + 1],
                                         end - start + 3)[1:-1]
    return filled


for i, meta_line in enumerate(pupil_meta):
    print(i)
    if i in excludes:
        continue
    # The participant id is read from fixed columns of the metadata line.
    rid = meta_line[71:76]
    if rid in excl_Rnum:
        print('skipping %s' % rid)
        continue

    print(rid)
    trials_used_Rid.append([rid, i])
    all_Rid.append(rid)

    # Get the first frame position;
    # for 2011 data this is 7680 frames before the end flash.
    cf = int(meta_line.strip().split(',')[1].split('=')[-1]) - 7680
    area_data = array(pupil_area_dat[i].strip().split(', '), 'f')
    area_data = array(area_data[cf:cf + 7680], 'float32')
    bg_data = array(pupil_bg_dat[i].strip().split(', '), 'i')
    bg_data = bg_data[cf:cf + 7680]

    # Generate a time series with the blinks removed.
    # First use the gradient change over 1, 2 and 3 frames to detect
    # probable blinks; work on a copy so the raw trace is kept.
    area_data_copy = area_data.copy()
    for lag in (1, 2, 3):
        _zero_after_jumps(area_data, area_data_copy, lag)
    # Final tidy - amplitude based (tested on the raw trace, per sample).
    _zero_around_spikes(area_data, area_data_copy)

    if do_interpolate_blinks_period == 1:
        area_data_final = _interpolate_zero_runs(area_data_copy)
    else:
        area_data_final = area_data_copy.copy()
        if do_Nan_blinks_period == 1:
            # Convert all zero values to NaNs.
            area_data_final[area_data_final == 0] = nan

    if show_indiv == 1:
        plt.plot(area_data_final, 'g')
        plt.title(str(i) + '_' + rid)
        plt.show()

    signal = area_data_final
    fourier = fft.fft(signal)

    if plot_indiv == 1:
        plt.figure()
        n = signal.size
        timestep = (1/30.0)
        freq = fft.fftfreq(n, d=timestep)
        plt.plot(freq, abs(fourier))
        plt.ylim(-100, 50000)
        plt.xlim(-0.0, 0.33)
        # Vertical marker lines at the frequencies of interest.
        fois = [0.03125, 0.0625, 0.125]
        for val in fois:
            x = ones(50000)*val
            y = arange(50000)
            plt.plot(x, y)
        plt.title('data run %s' % str(i + 1))
        plt.legend(('abs(fft)', '1/32Hz', '1/16Hz', '1/8Hz'))
        plt.show()

    # Store this trial's results in row i of the per-trial arrays.
    all_area_dat[i, :] = area_data_final
    all_area_dat_raw[i, :] = area_data
    all_indiv_fourier[i, :] = abs(fourier)
    all_bg_dat[i, :] = bg_data
        
        #fig = plt.figure()
        #plt.plot(bg_data)
        #plt.title(str(i))
        ##meanBGsc = bg_data.reshape(8,7680/8.0)
        ##meanBGsc = meanBGsc.mean(0)
        ##fig = plt.figure()
        ##plt.plot(meanBGsc)
        ##plt.title(str(i))
        #plt.show()


# Show the kept trials in processing order, then again after an in-place
# sort of the [participant id, row index] pairs.
print(trials_used_Rid)
trials_used_Rid.sort()
print(trials_used_Rid)


# Tidy up the tails: hold the first 100 frames at the value of frame 100
# and the last 100 frames at the value of frame -100. Each `trace` is a
# view of one row, so the assignments write into all_area_dat.
for trace in all_area_dat:
    trace[:100] = trace[100]
    trace[-100:] = trace[-100]


# Process by grouping all trials per individual, to generate one averaged
# trace per participant.
# First detect the distinct participants among the trials that were kept.
indiv_Rids = unique(all_Rid)
indiv_SCs = []

for rid in indiv_Rids:
    print(rid)
    # Collect the cleaned trace of every kept trial of this participant.
    curr_indiv_dat = [all_area_dat[row]
                      for trial_rid, row in trials_used_Rid
                      if trial_rid == rid]
    curr_indiv_arr = array(curr_indiv_dat, 'f')
    print(curr_indiv_arr.shape)

    # NaN-ignoring mean across this participant's trials.
    signal = ma.masked_invalid(curr_indiv_arr).mean(0)
    # Fold the 7680 frames into 8 equal segments and average them.
    # The segment length must be an int: reshape does not accept a float
    # dimension, hence integer division rather than 7680/8.0.
    scSignal = signal.data
    scSignal = scSignal.reshape(8, 7680 // 8)
    scSignal = scSignal.mean(0)

    if normalise == 1:
        # Fractional change about the trace's own mean.
        baseline = scSignal.mean(0)
        scSignal = (scSignal - baseline)/baseline

    fig = plt.figure()
    plt.plot(scSignal)
    plt.title(rid)
    plt.ylabel('Pupil area in pixels')
    plt.xlabel('Video frame sample number (at 30Hz)')
    plt.show()

    indiv_SCs.append(scSignal)
    
    

# Across-participant mean and standard error of the per-participant
# traces, both scaled by 100.
indiv_SCs_arr = array(indiv_SCs, 'f')
n_indiv = len(indiv_SCs)
scSignal_mean = indiv_SCs_arr.mean(0)*100.0
scSignal_StdErr = indiv_SCs_arr.std(0)/sqrt(n_indiv)*100.0

# Shaded band (mean +/- standard error) with the mean drawn on top.
band_colour = [0.75, 0.75, 0.75, 1.0]
frame_axis = arange(0, 7680/8.0, 1.0)
upper_band = scSignal_mean + scSignal_StdErr
lower_band = scSignal_mean - scSignal_StdErr

norm_fig = plt.figure()
plt.fill_between(frame_axis, upper_band, lower_band,
                 facecolor=band_colour, edgecolor=band_colour)
plt.plot(scSignal_mean, 'r')
plt.xlabel('Video frame sample number (at 30Hz)')
plt.ylabel('% change in pupil area')
plt.title('Normalised SC of TS of all runs')

# Vertical red marker line at frame 480.
y = arange(-5, 5, 0.05)
x = ones(len(y))*480
plt.plot(x, y, 'r')
plt.ylim(-10, 10)

plt.show()
norm_fig.savefig('norm_%s_sc.png' % dType)

# Processing as independent trials across all individual inputs.
# Dump the cleaned per-trial matrix as comma-separated text.
all_area_dat.tofile('/tmp/pupilArea.txt', ',')

#signal = mean(all_area_dat,0)

# Mean across trials, ignoring NaN entries (rows of skipped trials).
signal = ma.masked_invalid(all_area_dat).mean(0)

# String for the figure titles, describing how blink samples were treated.
if do_interpolate_blinks_period == 1:
    my_str = 'blinkPeriod=Interpolate'
elif do_Nan_blinks_period == 1:
    my_str = 'blinkPeriod=Nan'
else:
    # Neither option enabled: blink samples are left at 0. Define the
    # label anyway so the plot titles below do not raise NameError.
    my_str = 'blinkPeriod=Zeroed'


# Figure 1: the across-trial mean time series.
fig = plt.figure()
plt.plot(signal.data)
plt.xlabel('Video frame sample number (at 30Hz)')
plt.ylabel('Pupil area in pixels')
plt.title('Mean TS of all runs with %s' % my_str)
#plt.ylim(80,140)

# Figure 2: magnitude of the FFT of that mean, against frequency computed
# for a sample spacing of 1/30.
fourier = fft.fft(signal)
freq = fft.fftfreq(signal.size, d=(1/30.0))
fft_fig = plt.figure()
plt.plot(freq, abs(fourier))
plt.ylim(-100, 45000)
plt.xlim(-0.0, 0.33)

# One vertical marker line per frequency of interest.
fois = [0.03125, 0.0625, 0.125]
y = arange(50000)
for foi in fois:
    x = ones(50000)*foi
    plt.plot(x, y)

plt.legend(('abs(fft)', '1/32Hz','1/16Hz', '1/8Hz'))
plt.title('FFT of mean TS of all runs with %s' % my_str)
plt.xlabel('f')
plt.ylabel('Abs')




# Plot the background signal averaged across trials (NaN entries ignored).
meanBG = ma.masked_invalid(all_bg_dat).mean(0)
fig = plt.figure()
plt.plot(meanBG)
plt.title('Mean BG of all runs with %s' %my_str)

# Fold the 7680 frames into 8 equal segments and average them.
# The segment length must be an int: reshape does not accept a float
# dimension, hence integer division rather than 7680/8.0.
meanBGsc = meanBG.reshape(8, 7680 // 8)
meanBGsc = meanBGsc.mean(0)
fig = plt.figure()
plt.plot(meanBGsc)
plt.title('Mean single cycle of BG of all runs with %s' %my_str)


# Plot a single cycle of the mean time series: fold the 7680 frames into
# 8 equal segments and average them.
# The segment length must be an int: reshape does not accept a float
# dimension, hence integer division rather than 7680/8.0.
scSignal = signal.data
scSignal = scSignal.reshape(8, 7680 // 8)
scSignal = scSignal.mean(0)
fig = plt.figure()
#scSignal.tofile('/tmp/dat.txt', sep=",", format="%s")
plt.plot(scSignal)
plt.title('SC of TS of all runs')
plt.ylabel('Pupil area in pixels')
plt.xlabel('Video frame sample number (at 30Hz)')
plt.ylim(300,400)

# Vertical red marker line at frame 480, spanning the range of the trace.
y = arange(min(scSignal),max(scSignal),0.05)
x = ones(len(y))*480
plt.plot(x,y,'r')

# For some stats: normalise all signals first, expressing every trial as
# fractional change about its own mean.
norm_area_dat = all_area_dat.copy()
# matrix is n x 7680 here
for row in norm_area_dat:
    # `row` is a view, so the in-place operations update norm_area_dat.
    curr_mean = row.mean(0)
    row -= curr_mean    # demean
    row /= curr_mean    # normalise

# Mean across trials, ignoring NaN entries (rows of skipped trials).
norm_signal = ma.masked_invalid(norm_area_dat).mean(0)

# Fold the 7680 frames into 8 equal segments.
# The segment length must be an int: reshape does not accept a float
# dimension, hence integer division rather than 7680/8.0.
scSignalNorm = norm_signal.data
scSignalNorm = scSignalNorm.reshape(8, 7680 // 8)
scSignalNormMean = scSignalNorm.mean(0)*100.0
# NOTE(review): this is the std across the 8 segments of the mean signal,
# divided by sqrt(number of rows, skipped trials included). Confirm that
# this is the intended standard error.
scSignalNormStdError = scSignalNorm.std(0)/sqrt(len(norm_area_dat))*100.0

# Shaded band (mean +/- error) with the mean drawn on top.
norm_fig = plt.figure()
plt.fill_between(arange(0,7680/8.0,1.0),scSignalNormMean+scSignalNormStdError, scSignalNormMean-scSignalNormStdError,facecolor=[0.75,0.75,0.75,1.0],edgecolor=[0.75,0.75,0.75,1.0])
plt.plot(scSignalNormMean,'r')
#plt.plot(scSignalNormMean+scSignalNormStdError,'r')
#plt.plot(scSignalNormMean-scSignalNormStdError,'r')

plt.title('Normalised SC of TS of all runs')
plt.ylabel('% change in pupil area')
plt.xlabel('Video frame sample number (at 30Hz)')
# Vertical red marker line at frame 480.
y = arange(-5,5,0.05)
x = ones(len(y))*480
plt.plot(x,y,'r')
plt.ylim(-10,10)

# Save the FFT figure, then display every open figure.
fft_fig.savefig('norm_%s_fft.png' %dType)
#norm_fig.savefig('norm_%s_sc.png' %dType)

plt.show()


 ##plot mean FFT
#fig = plt.figure()
#plt.plot(freq, mean(all_indiv_fourier,0))
#plt.ylim(-100,25000)
#plt.xlim(-0.0,0.33)
#fois = [0.03125, 0.0625, 0.125]
#for val in fois:
#    x = ones(25000)*val
#    y = arange(25000)
#    plt.plot(x,y)
#
#plt.title('Mean of FFTs of all runs with %s' %my_str)
#plt.legend(('abs(fft)', '1/32Hz','1/16Hz', '1/8Hz'))
#plt.ylabel('Abs')
#plt.xlabel('f')
   
    

##plot mean FFT
#fig = plt.figure()
#plt.plot(freq, mean(all_indiv_fourier,0))
#plt.ylim(-100,25000)
#plt.xlim(-0.0,0.33)
#fois = [0.03125, 0.0625, 0.125]
#for val in fois:
#    x = ones(25000)*val
#    y = arange(25000)
#    plt.plot(x,y)
#
#plt.title('Mean of FFTs of all runs with %s' %my_str)
#plt.legend(('abs(fft)', '1/32Hz','1/16Hz', '1/8Hz'))
#plt.ylabel('Abs')
#plt.xlabel('f')

    
    
    
