from meg_reader import *
import re
from mplot import mplot
import threading
from pylab import *
import vtk
from numpy import *
from dv3dObjectsToRender import *
from dv3dPropertyFrames import *


class x_surf_Thread( threading.Thread ):
    """Background thread that draws a butterfly plot of averaged MEG data.

    All inputs arrive as keyword arguments (a missing one raises KeyError):
      arg1 -- x-axis values (time points) passed to ``mplot``
      arg2 -- trigger codes the epochs were selected with (stored only)
      arg3 -- frequency band the epochs were filtered with (stored only)
      arg4 -- the data array passed to ``mplot``
    """
    def __init__ (self, group=None, target=None, name=None, *args, **kwargs):
        # Only ``name`` is forwarded: ``run`` is overridden, so a ``target``
        # would never be invoked, and threading requires ``group`` to be None.
        threading.Thread.__init__(self, name=name)
        self.arg1 = kwargs['arg1']
        self.arg2 = kwargs['arg2']
        self.arg3 = kwargs['arg3']
        self.arg4 = kwargs['arg4']

    def run(self):
        """Open a new figure, plot arg4 against arg1 and show it."""
        figure()
        mplot(self.arg4, '', self.arg1)
        # Label the current axes. Previously the string was assigned to a
        # local named ``xlabel``, which shadowed pylab's function instead of
        # calling it, so the axis was never labelled.
        xlabel('time in ms')
        show()


def Extract_triggers(source_file):
    """Return the distinct trigger codes found in *source_file* as a string.

    The codes are taken from column 1 of ``MegReader.TriggerCodes``, sorted
    ascending, and joined with commas and no spaces (e.g. ``'1,2,5'``). This
    is the same comma-separated form that ContourPlot3d parses from its
    ``trigger_string`` argument. An empty string is returned when there are
    no triggers.
    """
    # Placeholder epoch timing needed only to construct the reader; just the
    # trigger table is read here.
    pretrig_dur = 100
    epoch_dur = 200
    reader = MegReader(source_file, pretrig_dur, epoch_dur)

    codes = sorted(set(array(reader.TriggerCodes)[:, 1]))
    # Format each code with str() instead of slicing str(list): a list of
    # NumPy scalars renders as '[np.int64(1), ...]' under NumPy >= 2, which
    # would produce an unparseable trigger string.
    return ','.join(str(code) for code in codes)

 



# Number of MEG sensor channels read from the .m4d coil file and used to
# index the averaged data.
_NUM_MEG_CHANNELS = 248


def _parse_trigger_list(trigger_string, trigger_codes):
    """Return the trigger codes selected by *trigger_string*.

    'all' (any letter case, surrounding whitespace ignored) selects every
    distinct value in column 1 of *trigger_codes*. Anything else is parsed as
    a comma-separated list of integers such as '1,2,5'; a malformed entry
    raises ValueError from int().
    """
    if trigger_string.strip().lower() == 'all':
        return list(set(array(trigger_codes)[:, 1]))
    return [int(code) for code in trigger_string.split(',')]


def _parse_filter_band(filter_string):
    """Return [] for 'broadband', otherwise [low, high] parsed from 'low,high'.

    Raises ValueError when either band edge is zero. Such a band matches no
    processing branch, so every epoch would be skipped and averaging would
    fail later with an opaque error.
    """
    if str(filter_string).strip().lower() == 'broadband':
        return []
    edges = filter_string.split(',')
    band = [int(edges[0]), int(edges[1])]
    if not (band[0] and band[1]):
        raise ValueError('Filter band edges must be non-zero, got %r' % (filter_string,))
    return band


def _average_epochs(reader, trig_list, fband):
    """Return mean / std across the epochs selected by *trig_list*.

    Each epoch has its DC offset removed with remove_DC. When *fband* is
    non-empty, the epoch is also passed through meg_filter with the two band
    edges. The epochs are then stacked along a third axis and reduced over
    that axis. Raises ValueError when no epoch matches.

    NOTE(review): with a single epoch the std is 0, so the result is inf/nan.
    """
    epoch_list = reader.GetEpochList(trig_list, [], [])
    if len(epoch_list) == 0:
        raise ValueError('No epochs found for triggers %s' % (trig_list,))
    epochs = []
    for epoch in epoch_list:
        data = reader.GetEpoch_MEG(epoch)
        data = remove_DC(data, len(data))
        if fband:
            data = meg_filter(data, reader.sample_rate, fband[0], fband[1])
        epochs.append(data)
    stacked = dstack(epochs)
    return stacked.mean(axis=2) / stacked.std(axis=2)


def _read_coil_positions(coil_filename, transform_filename, num_channels=_NUM_MEG_CHANNELS):
    """Read sensor positions from a 4D .m4d file and apply a 4x4 transform.

    Lines beginning with 'A<n>' between the
    'MSI.Meg_Position_Information.Begin:' and '...End:' marker lines give
    channel n (1-based) followed by its x y z position. Each position is
    multiplied, in homogeneous coordinates, by the 4x4 matrix read as
    whitespace-separated numbers from *transform_filename*.

    Returns a (num_channels, 3) array of transformed positions. A channel that
    is absent from the file keeps position (0, 0, 0) before the transform.
    """
    with open(coil_filename) as coil_file:
        lines = coil_file.readlines()
    start = lines.index('MSI.Meg_Position_Information.Begin:\n')
    end = lines.index('MSI.Meg_Position_Information.End:\n')

    positions = zeros((num_channels, 3))
    for line in lines[start + 1:end]:
        if not line.startswith('A'):
            continue
        fields = line.split()
        # Drop the leading 'A' to get the 1-based channel number.
        positions[int(fields[0][1:]) - 1] = [float(v) for v in fields[1:4]]

    transform = reshape(fromfile(transform_filename, sep=' '), (4, 4))
    homogeneous = column_stack((positions, ones(num_channels)))
    # (T . p)^T == p^T . T^T, so every row is transformed in one product.
    return dot(homogeneous, transform.T)[:, :3]


def ContourPlot3d(filename, pretrig, duration, trigger_string, filter_string, the_parent_frame):
    """Build a 3D contour-surface plot of averaged MEG data and add it to the GUI.

    Steps:
    1. Read *filename* with MegReader using *pretrig* and *duration*.
    2. Select epochs with *trigger_string* ('all' or e.g. '1,2').
    3. Filter them per *filter_string* ('broadband' or 'low,high').
    4. Average them and divide by the across-epoch standard deviation.
    5. Draw a butterfly plot of the result in a background thread.
    6. Ask the user for the coil-transform file and map the sensor positions
       in '<filename>.m4d' through it.
    7. Triangulate and smooth the positions, mirror them left-right and
       colour them by the value at time index 122.
    8. Register the resulting actor with *the_parent_frame* (renderer, object
       list, tree and a new property frame).

    Returns the MegReader, or None if the transform-file dialog is cancelled.
    """
    # TODO - exclude rejected coils
    #        titles for axes / plots

    m = MegReader(filename, pretrig, duration)
    trig_list = _parse_trigger_list(trigger_string, m.TriggerCodes)
    fband = _parse_filter_band(filter_string)
    sample_rate = m.sample_rate

    davg = _average_epochs(m, trig_list, fband)

    # x-axis values running from -pretrig to duration - pretrig. Use float
    # division because an integer sample rate used to truncate '/1000'
    # (e.g. 678 Hz gave 0 points), and linspace needs an integer count.
    num_samples = int(sample_rate * duration / 1000.0)
    w = linspace(-pretrig, duration - pretrig, num_samples)

    # Butterfly plot of the averaged data, drawn in a background thread.
    my_thread = x_surf_Thread(arg1=w, arg2=trig_list, arg3=fband, arg4=davg)
    my_thread.start()

    dlg = wx.FileDialog(the_parent_frame, "Choose the transform file for this session", '', "", "*", wx.OPEN)
    try:
        if dlg.ShowModal() == wx.ID_OK:
            coil_transform = dlg.GetPath()
        else:
            coil_transform = ''
    finally:
        dlg.Destroy()
    if coil_transform == '':
        # Previously execution continued here and failed with an
        # UnboundLocalError (or on reading the file ''); stop cleanly instead.
        print('Load cancelled\n')
        return None

    coords = _read_coil_positions(filename + '.m4d', coil_transform)

    # Projection centre: middle of the sensor array in x/y, at z_max * 1.0001.
    # Each sensor is moved towards it in x/y by an amount proportional to its
    # depth below mz (scale scf) and placed on the plane z = z_min. The thin-
    # plate spline below maps the 3D positions onto these 2D targets so that
    # Delaunay2D can triangulate the curved array.
    x_min, y_min, z_min = coords.min(axis=0)
    x_max, y_max, z_max = coords.max(axis=0)
    mx, my, mz = (x_min + x_max) / 2, (y_min + y_max) / 2, z_max * 1.0001
    scf = 0.001

    points = vtk.vtkPoints()
    pointst = vtk.vtkPoints()
    for x1, y1, z1 in coords.tolist():
        points.InsertNextPoint((x1, y1, z1))
        pointst.InsertNextPoint((mx + scf * (mz - z1) * (x1 - mx),
                                 my + scf * (mz - z1) * (y1 - my),
                                 z_min))

    vl = davg.transpose()
    hi = davg.max()
    lo = davg.min()

    # Time index whose values colour the surface. s_scalars is a list object
    # shared with project() below and stored on the scene object, so holders
    # of that object can update the colouring in place (presumably followed
    # by projectF.Modified() -- confirm against the property frame).
    vlp = 122
    s_scalars = [vl[i][vlp] for i in range(_NUM_MEG_CHANNELS)]

    profile = vtk.vtkPolyData()
    profile.SetPoints(points)

    thin = vtk.vtkThinPlateSplineTransform()
    thin.SetSourceLandmarks(points)
    thin.SetTargetLandmarks(pointst)
    thin.SetBasisToR2LogR()

    # 2D Delaunay triangulation, performed in the spline-projected space.
    delny = vtk.vtkDelaunay2D()
    delny.SetInput(profile)
    delny.SetTransform(thin)
    delny.SetTolerance(0.001)

    smoother = vtk.vtkSmoothPolyDataFilter()
    smoother.SetInput(delny.GetOutput())
    smoother.SetNumberOfIterations(150)
    smoother.Update()
    model = smoother.GetOutput()

    projectF = vtk.vtkProgrammableFilter()
    projectF.SetInput(model)

    def project():
        """Mirror the surface left-right and attach the current s_scalars."""
        poly_in = projectF.GetPolyDataInput()
        new_pts = vtk.vtkPoints()
        derivs = vtk.vtkFloatArray()
        for i in range(poly_in.GetNumberOfPoints()):
            x2, y2, z2 = poly_in.GetPoint(i)
            new_pts.InsertPoint(i, -x2, y2, z2)  # LR flip
            derivs.InsertValue(i, s_scalars[i])
        poly_out = projectF.GetPolyDataOutput()
        poly_out.CopyStructure(poly_in)
        poly_out.SetPoints(new_pts)
        poly_out.GetPointData().SetScalars(derivs)

    projectF.SetExecuteMethod(project)

    # Contour-line pipeline; its actor (s_meshActor2) is currently not added
    # to the renderer (see TODO below).
    cf = vtk.vtkContourFilter()
    cf.SetInput(projectF.GetOutput())
    cf.GenerateValues(100, lo, hi)

    # Sometimes the contouring algorithm can create a volume whose gradient
    # vector and ordering of polygon (using the right hand rule) are
    # inconsistent. vtkReverseSense cures this problem.
    reverse = vtk.vtkReverseSense()
    reverse.SetInput(cf.GetOutput())
    reverse.ReverseCellsOn()
    reverse.ReverseNormalsOn()

    s_lut = vtk.vtkLookupTable()
    mapMesh = vtk.vtkPolyDataMapper()
    mapMesh.ImmediateModeRenderingOn()
    mapMesh.SetLookupTable(s_lut)
    mapMesh.SetInput(projectF.GetOutput())
    mapMesh.SetScalarRange(lo, hi)
    s_meshActor = vtk.vtkActor()
    s_meshActor.SetMapper(mapMesh)
    s_meshActor.GetProperty().SetColor(.1, .2, .4)

    mapMesh2 = vtk.vtkPolyDataMapper()
    mapMesh2.ImmediateModeRenderingOn()
    mapMesh2.SetInput(reverse.GetOutput())
    s_meshActor2 = vtk.vtkActor()
    s_meshActor2.SetMapper(mapMesh2)
    s_meshActor2.GetProperty().SetColor(0, 0, 0)

    s_lut.SetHueRange(0.0, 0.667)
    s_lut.SetNumberOfColors(256)
    s_lut.Build()

    the_parent_frame.ren.AddActor(s_meshActor)
    # the_parent_frame.ren.AddActor(s_meshActor2)  TODO - re-add the contours?

    AddVTKObjectWithAttributes(s_meshActor,
                               the_parent_frame,
                               '3D_contour_plot',
                               '3D contour plot',
                               None,
                               None,
                               0,
                               1,
                               1,
                               'structural',
                               None,
                               filename,
                               1,
                               '3DContourPlotProperties',
                               None)

    # Capture the newly registered object once so that later steps cannot
    # accidentally target a different list entry.
    obj = the_parent_frame.ListOfObjects[-1]
    # NOTE(review): obj is presumably s_meshActor itself; re-adding an actor
    # already in the renderer is expected to be a no-op.
    the_parent_frame.ren.AddActor(obj)
    obj.vlp = vlp
    obj.vl = vl
    obj.s_scalars = s_scalars
    obj.pretrig = pretrig
    obj.duration = duration
    obj.projectF = projectF

    new_surface = the_parent_frame.tree.AppendItem(the_parent_frame.tree_structural, obj.my_label)
    new_surface.my_index_into_objects = len(the_parent_frame.ListOfObjects) - 1
    # Add the newly created tree item as an attribute of the object.
    obj.my_treeitem = new_surface
    the_parent_frame.tree.Refresh()

    # Create a properties window for the new object and attach it.
    frame = MyPropertyFrame(the_parent_frame, obj, 0)
    frame.Show(True)
    obj.my_property_frame_instance = frame

    the_parent_frame.widget.Render()

    return m
    
