
#!/usr/bin/python
import vtk
from vtk.util.colors import *
from struct import unpack, calcsize #from surface
from numpy import *
import scipy as Sci
import scipy.linalg
import sys
import os
from subprocess import Popen, PIPE
import time


# Module-level coordinate constants, all starting at 0. Nothing in this
# section reads them; presumably other parts of the program do -- confirm.
x_const = y_const = z_const = 0


def get_path(path):
    """Return *path* joined onto the directory that contains this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, path)

# --------- File handling ------------

def my_file_dialog(processing_type):
    """Run the wx dialog helper script in a child process and return its output.

    Launches ``pythonw ynicDV3D_wxdialogs.py <processing_type>``, where the
    script is looked up next to this module, and waits for it to exit.

    Args:
        processing_type: value passed (as a string) as the script's only
            command-line argument.

    Returns:
        Everything the helper wrote to stdout, with surrounding whitespace
        stripped. An empty result presumably means the dialog was cancelled
        -- confirm against the helper script.
    """
    # An argument list (no shell) keeps a module path containing spaces
    # intact instead of letting the shell split it.
    args = ['pythonw', get_path('ynicDV3D_wxdialogs.py'), str(processing_type)]
    proc = Popen(args, stdout=PIPE)
    # communicate() drains stdout while waiting. Calling wait() first with
    # stdout=PIPE can deadlock once the child fills the pipe buffer.
    stdout_data, _ = proc.communicate()
    return stdout_data.strip()


# --------- DTI-Processing ------------

def _read_coreg_matrix(mat_filename):
    """Parse a whitespace-separated matrix file with 4 numbers per row.

    Presumably an FSL-style 4x4 affine ``.mat`` file -- confirm.

    Returns:
        (linear, translation): *linear* is a flat list of the first three
        values of every parsed row; *translation* is the list of fourth
        values. Rows that do not begin with four numbers (blank lines, text)
        are skipped as a whole, so the two lists always stay aligned.
    """
    linear = []
    translation = []
    with open(mat_filename) as f:
        for line in f:
            try:
                row = [float(v) for v in line.split()[:4]]
            except ValueError:
                continue
            if len(row) < 4:
                continue
            linear.extend(row[:3])
            translation.append(row[3])
    return linear, translation


def _fibers_to_padded_array(lines, pad_value=100000.0):
    """Convert fiber text lines to an array of shape (n_fibers, 3, max_points).

    Each line must hold ``n`` followed by ``n`` x values, ``n`` y values and
    ``n`` z values (the layout the slicing below relies on). Fibers shorter
    than the longest one are padded at the end with *pad_value*, an
    implausible coordinate that later code uses to find each fiber's length.
    """
    rows = [[float(v) for v in line.split()] for line in lines]
    max_points = int(max(row[0] for row in rows))
    data = ones((len(rows), 3, max_points)) * pad_value
    for j, row in enumerate(rows):
        n = int(row[0])
        for axis in range(3):
            start = 1 + n * axis
            data[j, axis, :n] = row[start:start + n]
    return data


def create_dti_array(): ## returns dti_data_array, dti_x_const, dti_y_const, dti_z_const
    """Load the demo DTI fibers, map them into structural space, build actors.

    Reads the fiber file and the co-registration matrix from hard-coded paths
    under ``./DV3D_examples/dti``, which are resolved against the current
    working directory. It then:

    1. keeps the first 1000 fibers and pads them into a numpy array;
    2. removes the reference volume's offset;
    3. multiplies by the 3x3 linear part of the registration matrix;
    4. hands the result to ``extract_dti_fiber_set`` with +/-5000 seed and
       target boxes.

    Side effect: rebinds the module global ``coreg_transformation`` to the
    flat list of the first three values of each parsed matrix row.

    Returns:
        Whatever ``extract_dti_fiber_set`` returns: a list of VTK actors, or
        None when no fiber crosses the boxes.
    """
    global coreg_transformation

    dti_filename = './DV3D_examples/dti/all_20k.out'
    mat_filename = './DV3D_examples/dti/dti_FA_to_betT1.nii.mat'

    # Translation of the reference volume (dti_FA.nii.gz). It is hard-coded
    # rather than read from the image header; presumably copied from its
    # qform -- confirm. The 4th entry is never used.
    orig_offset = [132.158020, -101.527100, -8.95230, 1.000000]

    coreg_transformation, coreg_corrections = _read_coreg_matrix(mat_filename)
    # Expects exactly 4 parsed rows; keep the upper-left 3x3 linear part.
    t = array(coreg_transformation).reshape(4, 3)[:3, :3]

    dti_x_const = coreg_corrections[0]
    dti_y_const = coreg_corrections[1]
    dti_z_const = coreg_corrections[2]

    with open(dti_filename) as f:
        fiber_lines = f.readlines()
    print('restricted to 1000 fibers for demo')
    fibers = _fibers_to_padded_array(fiber_lines[:1000])

    # Reset fiber coordinates so their origin is 0,0,0 (remove the offset).
    # The extra 176 (x) and 8 (z) terms are unexplained manual corrections --
    # TODO confirm. The padding values are shifted too, but stay far above
    # the < 50000 threshold used downstream.
    fibers[:, 0, :] -= orig_offset[0] - 176
    fibers[:, 1, :] -= orig_offset[1]
    fibers[:, 2, :] -= orig_offset[2] + 8

    # dot((3,3), (n_fibers, 3, n_points)) contracts the xyz axis and yields
    # shape (3, n_fibers, n_points): one coordinate plane per output axis.
    dti_data_array_transformed = dot(t, fibers)

    profileList = extract_dti_fiber_set(
        dti_data_array_transformed, dti_x_const, dti_y_const, dti_z_const,
        -5000, 5000, -5000, 5000, -5000, 5000,
        -5000, 5000, -5000, 5000, -5000, 5000, 0)

    return profileList

## The fiber array is built above; the function below selects the fibers that cross a region and builds their VTK actors.

def _fibers_crossing_box(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax):
    """Return one boolean per fiber: True if any of its points is in the box.

    x, y and z are 2-D arrays indexed [fiber, point], already expressed in
    the space of the bounds. All bounds are inclusive.
    """
    inside = ((x >= xmin) & (x <= xmax) &
              (y >= ymin) & (y <= ymax) &
              (z >= zmin) & (z <= zmax))
    return inside.any(axis=1)


def _build_fiber_actor(xs, ys, zs, number_of_output_points=25):
    """Return a yellow, pickable vtkLODActor tracing a spline through one fiber.

    xs, ys and zs are the fiber's point coordinates in order. Each coordinate
    is interpolated by its own cardinal spline, parameterised by the point
    index. The curve is resampled at *number_of_output_points* evenly spaced
    parameters and drawn as a thin tube.
    """
    number_of_input_points = len(xs)

    # A fresh set of splines for every fiber, so no points from a previously
    # processed (possibly longer) fiber leak into this fiber's interpolation.
    spline_x = vtk.vtkCardinalSpline()
    spline_y = vtk.vtkCardinalSpline()
    spline_z = vtk.vtkCardinalSpline()
    for m, (x, y, z) in enumerate(zip(xs, ys, zs)):
        spline_x.AddPoint(m, x)
        spline_y.AddPoint(m, y)
        spline_z.AddPoint(m, z)

    points = vtk.vtkPoints()
    for i in range(number_of_output_points):
        t = (number_of_input_points - 1.0) / (number_of_output_points - 1.0) * i
        points.InsertPoint(i, spline_x.Evaluate(t), spline_y.Evaluate(t),
                           spline_z.Evaluate(t))

    # A single polyline cell connecting the resampled points in order.
    lines = vtk.vtkCellArray()
    lines.InsertNextCell(number_of_output_points)
    for i in range(number_of_output_points):
        lines.InsertCellPoint(i)

    profile_data = vtk.vtkPolyData()
    profile_data.SetPoints(points)
    profile_data.SetLines(lines)

    # Add thickness to the resulting line.
    tubes = vtk.vtkTubeFilter()
    tubes.SetNumberOfSides(1)
    tubes.SetInput(profile_data)
    tubes.SetRadius(0.1)

    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInput(tubes.GetOutput())

    # LOD actor for faster interactive rendering.
    actor = vtk.vtkLODActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetDiffuseColor(yellow)
    actor.GetProperty().SetSpecular(.3)
    actor.GetProperty().SetSpecularPower(30)
    actor.VisibilityOn()
    actor.SetPickable(1)
    actor.my_fiber_group = 0
    return actor


def extract_dti_fiber_set(dti_data_array, dti_x_const, dti_y_const, dti_z_const, xmin_a, xmax_a, ymin_a, ymax_a, zmin_a, zmax_a, xmin_b, xmax_b, ymin_b, ymax_b, zmin_b, zmax_b, import_data_x_offset):
    """Select the fibers that cross a seed (and optional target) box; build actors.

    Args:
        dti_data_array: array indexed [coordinate, fiber, point] with
            coordinate 0/1/2 = x/y/z. Unused trailing points are padding
            values >= 50000.
        dti_x_const, dti_y_const, dti_z_const: per-axis shifts. x is
            shifted by ``- dti_x_const + import_data_x_offset``; y and z add
            their constants.
        xmin_a .. zmax_a: inclusive bounds of the seed box.
        xmin_b .. zmax_b: inclusive bounds of the target box. Pass
            ``xmin_b='None'`` (or None) to use the seed box alone. Otherwise
            a fiber must cross both boxes.
        import_data_x_offset: extra x offset (see above).

    Returns:
        A list with one vtkLODActor per selected fiber, in ascending fiber
        order, or None (after printing a message) when no fiber matches.
    """
    x_all = dti_data_array[0, :, :] - dti_x_const + import_data_x_offset
    y_all = dti_data_array[1, :, :] + dti_y_const
    z_all = dti_data_array[2, :, :] + dti_z_const

    matching = _fibers_crossing_box(x_all, y_all, z_all,
                                    xmin_a, xmax_a, ymin_a, ymax_a, zmin_a, zmax_a)

    # The string 'None' is the established "no target box" marker; a real
    # None is accepted as well.
    if not (xmin_b is None or xmin_b == 'None'):
        matching = matching & _fibers_crossing_box(
            x_all, y_all, z_all, xmin_b, xmax_b, ymin_b, ymax_b, zmin_b, zmax_b)

    indices_to_extract = matching.nonzero()[0]

    if len(indices_to_extract) == 0:
        # shape[1] is the number of fibers; axis 0 is the xyz axis.
        print('None of the %s supplied fibers cross this area ...\n' % dti_data_array.shape[1])
        return
    print('\n%s of the supplied fibers cross this area ...\n' % len(indices_to_extract))

    profileList = []
    for line_to_get in indices_to_extract:
        # Real points precede the padding. The threshold is 50000, not the
        # 100000 pad value, because earlier transforms may rescale the padding.
        number_of_input_points = int((dti_data_array[0, line_to_get, :] < 50000).sum())
        # The drawn x coordinate carries an extra -10 that the box test above
        # does not apply; looks like a manual display correction -- TODO confirm.
        xs = x_all[line_to_get, :number_of_input_points] - 10
        ys = y_all[line_to_get, :number_of_input_points]
        zs = z_all[line_to_get, :number_of_input_points]
        profileList.append(_build_fiber_actor(xs, ys, zs))

    return profileList
        
