from meg_reader import *

def Run_MEG_Reader(source_file):
    """Ask the user which trigger codes to use from *source_file* and return them.

    The file is opened with ``MegReader`` using fixed placeholder timings
    (100 pre-trigger, 200 epoch duration) only so that its trigger codes can
    be listed.  The unique codes are offered to ``contour_options_dialog`` as
    a comma-separated string; the dialog's reply is split on ``';'`` and only
    its third field (the trigger selection) is used here.

    Parameters
    ----------
    source_file : str
        Path of the MEG data file to read.

    Returns
    -------
    list
        All unique trigger codes (sorted) if the user answered ``'all'``,
        otherwise the comma-separated codes the user typed, converted to int.
    """
    # Placeholder timings: only the trigger table is needed here.
    pretrig_dur = 100
    epoch_dur = 200
    # Bug fix: the reader was previously given the undefined name
    # ``filename`` instead of this function's own argument.
    m = MegReader(source_file, pretrig_dur, epoch_dur)

    # Column 1 of TriggerCodes is what is treated as the trigger value.
    codes = sorted(set(array(m.TriggerCodes)[:, 1]))
    # Format each scalar individually so the text stays "1,2,3" even on
    # NumPy versions whose scalar repr looks like "np.int64(1)".
    offered = ",".join(str(code) for code in codes)

    args = contour_options_dialog(offered).split(';')
    trig_field = args[2].strip()
    if trig_field == 'all':
        return codes
    # Ignore stray blanks and empty items (e.g. a trailing comma).
    return [int(code) for code in trig_field.split(',') if code.strip()]

 



def _read_m4d_coil_coords(coil_filename, coil_transform, num_coils=248):
    """Return transformed sensor positions read from a ``.m4d`` coil file.

    Lines between ``MSI.Meg_Position_Information.Begin:`` and
    ``MSI.Meg_Position_Information.End:`` that start with ``'A'`` are parsed
    as ``A<channel> x y z ...`` with a 1-based channel number.  Channels that
    do not appear get zero coordinates before the transform is applied.

    *coil_transform* is a text file of 16 whitespace-separated numbers, read
    row-major into a 4x4 matrix and applied to every position in homogeneous
    coordinates (into head coordinate space, per the original code's comment).

    Returns a list of ``num_coils`` ``[x, y, z]`` lists, index = channel - 1.
    Raises ValueError if either marker line is missing.
    """
    with open(coil_filename) as f:
        lines = f.readlines()
    start = lines.index('MSI.Meg_Position_Information.Begin:\n')
    end = lines.index('MSI.Meg_Position_Information.End:\n')

    raw = empty((num_coils, 3), float)
    raw.fill(0.0)  # never expose uninitialised memory for missing channels
    for line in lines[start + 1:end]:
        if not line.startswith('A'):
            continue
        fields = line.split()
        raw[int(fields[0][1:]) - 1] = [float(v) for v in fields[1:4]]

    transform = reshape(fromfile(coil_transform, sep=' '), (4, 4))
    coords = []
    for x, y, z in raw.tolist():
        hx, hy, hz = dot(transform, array([x, y, z, 1.0]))[:3]
        coords.append([hx, hy, hz])
    return coords


def _read_meg_value_table(path):
    """Read a text table of floats, one row per non-blank line.

    Returns ``(rows, lo, hi)``: the rows as lists of floats plus the smallest
    and largest value over all rows.  If the file holds no values, ``lo`` and
    ``hi`` keep their starting sentinels 10000 and -10000.
    """
    lo, hi = 10000, -10000
    rows = []
    with open(path) as f:
        for line in f:
            row = [float(v) for v in line.split()]
            if not row:
                continue  # blank line: previously crashed on max([])
            rows.append(row)
            row_lo, row_hi = min(row), max(row)
            if row_hi > hi:
                hi = row_hi
            if row_lo < lo:
                lo = row_lo
    return rows, lo, hi


def ContourPlot3d(values_file='/home/andre/scratch/DOWNLOADS/meglst.txt'):
    """Average selected MEG epochs and draw a 3-D contour map over the sensors.

    Interactive steps:
      1. ask for the MEG data file (``my_file_dialog(51, 0, 0)``), then ask
         ``contour_options_dialog`` for ``pretrig;epoch;triggers;band``;
      2. re-read the file with those timings, apply ``RemoveDC`` (and
         ``MegFilter`` when a band is given) to every selected epoch, average
         them and hand the average to ``x_surf_Thread`` (butterfly plot);
      3. ask for a ``.m4d`` coil file and a 4x4 transform file, triangulate
         the transformed sensor positions and colour the surface with the
         first row of *values_file*;
      4. add the surface and its contour lines to the global ``ren`` and
         render the global ``renWin``.

    NOTE(review): the epoch average only feeds the butterfly thread; the
    contour colours come from *values_file*, not from the loaded data.

    Parameters
    ----------
    values_file : str
        Whitespace-separated table of per-sensor values; the first row must
        hold at least 248 values.  Defaults to the path previously hard-coded.

    Returns
    -------
    The ``MegReader`` instance, or None if a dialog is cancelled or no epoch
    is left to average.

    Side effects: sets the module globals ``projectF``, ``vlp``, ``vl``,
    ``s_scalars``, ``s_meshActor`` and ``s_meshActor2``.
    """
    global projectF, vlp, vl, s_scalars, s_meshActor, s_meshActor2
    global ren, renWin
    # TODO - exclude rejected coils
    #        titles for axes / plots

    filename = my_file_dialog(51, 0, 0)
    if filename == '':
        print('Load cancelled\n')
        return

    # HACK: open once with placeholder timings only to list the trigger
    # codes; the file is re-read below with the user's timings.
    m = MegReader(filename, 100, 200)
    codes = sorted(set(array(m.TriggerCodes)[:, 1]))
    args = contour_options_dialog(",".join(str(c) for c in codes)).split(';')

    pretrig_dur, epoch_dur = int(args[0]), int(args[1])

    if args[2].strip() == 'all':
        trig_list = list(codes)
    else:
        trig_list = [int(c) for c in args[2].split(',') if c.strip()]

    # args[3] is the last field and may carry the dialog's trailing newline.
    freqs = args[3].strip()
    if freqs == 'broadband':
        fband = []
    else:
        edges = freqs.split(',')
        fband = [int(edges[0]), int(edges[1])]

    m = MegReader(filename, pretrig_dur, epoch_dur)
    sample_rate = m.sample_rate

    epochs = []
    for epoch_id in m.GetEpochList(trig_list, [], []):
        B = m.GetEpoch_MEG(epoch_id)
        if not fband:
            epochs.append(RemoveDC(B, len(B)))
        # NOTE(review): a band with a 0 edge (e.g. "0,40") matches neither
        # branch, so such epochs are skipped -- kept from the original.
        elif fband[0] and fband[1]:
            epochs.append(MegFilter(RemoveDC(B, len(B)), sample_rate,
                                    fband[0], fband[1]))
    if not epochs:
        print('No epochs to average - check trigger codes and frequency band\n')
        return
    davg = dstack(epochs).mean(axis=2)

    # x-axis (presumably ms).  Float arithmetic: with an integer sample rate
    # ``sample_rate/1000`` truncates under Python 2 (678 Hz -> 0 points).
    x_axis = linspace(-pretrig_dur, epoch_dur - pretrig_dur,
                      int(sample_rate * epoch_dur / 1000.0))

    my_thread = x_surf_Thread(arg1=x_axis, arg2=trig_list, arg3=fband, arg4=davg)
    my_thread.start()

    # Sensor coordinates; cancelling used to crash on unpacking None.
    coil_filename = my_file_dialog(47, 0, 0)
    if coil_filename == '':
        print('Load cancelled\n')
        return
    coil_transform = my_file_dialog(48, 0, 0)
    if coil_transform == '':
        print('Load cancelled\n')
        return
    p = array(_read_m4d_coil_coords(coil_filename, coil_transform))

    x_min, x_max = min(p[:, 0]), max(p[:, 0])
    y_min, y_max = min(p[:, 1]), max(p[:, 1])
    z_min, z_max = min(p[:, 2]), max(p[:, 2])
    # Projection centre: middle of the x/y extent, just above the top sensor.
    mx, my, mz = (x_min + x_max) / 2, (y_min + y_max) / 2, z_max * 1.0001

    # Source landmarks are the sensors; target landmarks are the sensors
    # pushed away from the centre in x/y and flattened onto z_min.
    scf = 0.001
    points = vtk.vtkPoints()
    pointst = vtk.vtkPoints()
    for x, y, z in p.tolist():
        points.InsertNextPoint(x, y, z)
        pointst.InsertNextPoint(mx + scf * (mz - z) * (x - mx),
                                my + scf * (mz - z) * (y - my),
                                z_min)

    # The colour range spans every row of the table, not only the row shown.
    vl, lo, hi = _read_meg_value_table(values_file)
    s_scalars = [vl[0][i] for i in range(248)]
    vlp = 55

    profile = vtk.vtkPolyData()
    profile.SetPoints(points)

    thin = vtk.vtkThinPlateSplineTransform()
    thin.SetSourceLandmarks(points)
    thin.SetTargetLandmarks(pointst)
    thin.SetBasisToR2LogR()

    # 2D Delaunay triangulation performed in the flattened (target) space.
    delny = vtk.vtkDelaunay2D()
    delny.SetInput(profile)
    delny.SetTransform(thin)
    delny.SetTolerance(0.001)

    smoother = vtk.vtkSmoothPolyDataFilter()
    smoother.SetInput(delny.GetOutput())
    smoother.SetNumberOfIterations(150)
    smoother.Update()
    model = smoother.GetOutput()

    projectF = vtk.vtkProgrammableFilter()
    projectF.SetInput(model)

    def project():
        """Copy the mesh and attach the global s_scalars as point scalars."""
        input = projectF.GetPolyDataInput()
        numPts = input.GetNumberOfPoints()
        newPts = vtk.vtkPoints()
        derivs = vtk.vtkFloatArray()
        for i in range(0, numPts):
            x, y, z = input.GetPoint(i)
            newPts.InsertPoint(i, x, y, z)
            derivs.InsertValue(i, s_scalars[i])
        projectF.GetPolyDataOutput().CopyStructure(input)
        projectF.GetPolyDataOutput().SetPoints(newPts)
        projectF.GetPolyDataOutput().GetPointData().SetScalars(derivs)

    projectF.SetExecuteMethod(project)

    cf = vtk.vtkContourFilter()
    cf.SetInput(projectF.GetOutput())
    cf.GenerateValues(100, lo, hi)

    # Sometimes the contouring algorithm can create a volume whose gradient
    # vector and ordering of polygon (using the right hand rule) are
    # inconsistent. vtkReverseSense cures this problem.
    reverse = vtk.vtkReverseSense()
    reverse.SetInput(cf.GetOutput())
    reverse.ReverseCellsOn()
    reverse.ReverseNormalsOn()

    s_lut = vtk.vtkLookupTable()
    mapMesh = vtk.vtkPolyDataMapper()
    mapMesh.SetLookupTable(s_lut)
    mapMesh.SetInput(projectF.GetOutput())
    mapMesh.SetScalarRange(lo, hi)
    s_meshActor = vtk.vtkActor()
    s_meshActor.SetMapper(mapMesh)
    s_meshActor.GetProperty().SetColor(.1, .2, .4)

    mapMesh2 = vtk.vtkPolyDataMapper()
    mapMesh2.SetInput(reverse.GetOutput())
    s_meshActor2 = vtk.vtkActor()
    s_meshActor2.SetMapper(mapMesh2)
    s_meshActor2.GetProperty().SetColor(0, 0, 0)

    s_lut.SetHueRange(0.0, 0.667)
    s_lut.SetNumberOfColors(256)
    s_lut.Build()

    ren.AddActor(s_meshActor)
    ren.AddActor(s_meshActor2)
    renWin.Render()
    return m