#!/usr/bin/python

# Module to read NIfTI data into DV3D and convert it to a format (ImageData) that VTK can use.

from nifti import *
from numpy import *
import vtk


def niftiToVTKImageData(niftiFilename):
    """Load a NIfTI file and wrap its voxels in a ``vtk.vtkImageImport``.

    Parameters
    ----------
    niftiFilename : str
        Path of the NIfTI file; passed straight to ``NiftiImage``.

    Returns
    -------
    tuple
        ``(vol, dMax, dMin, bounds, x_slice_pos, y_slice_pos, z_slice_pos,
        srow_x, srow_y, srow_z, spacing)`` where

        * ``vol`` is the updated ``vtkImageImport`` holding a copy of the voxels,
        * ``dMax`` / ``dMin`` are the largest / smallest voxel values,
        * ``bounds`` is ``vol.GetOutput().GetBounds()`` with the x pair put
          in (min, max) order,
        * ``x_slice_pos`` / ``y_slice_pos`` / ``z_slice_pos`` are the middle
          indices used as default plane positions,
        * ``srow_x`` / ``srow_y`` / ``srow_z`` are rows 0-2 of ``getSForm()``,
        * ``spacing`` is the ``(x, y, z)`` spacing handed to VTK (x negated).

    Raises
    ------
    KeyError
        If the voxel dtype has no entry in ``determineRequiredVTKDataType``.
    """
    nim = NiftiImage(niftiFilename)
    data = nim.data

    # Only the VTK scalar type is needed; the import size is taken from the
    # array itself below.
    dataTypeVTK = determineRequiredVTKDataType(str(data.dtype))[0]

    # Scalar range of the data.
    dMax = data.max()
    dMin = data.min()

    # Flatten to a 1 x n, C-ordered buffer; VTK unpacks it with the extent set
    # below. ravel() only copies when the array is not already contiguous.
    flat = data.ravel()

    vol = vtk.vtkImageImport()
    # Use the buffer's real byte size so VTK can never read past its end,
    # even if a per-sample size in the lookup table were wrong.
    vol.CopyImportVoidPointer(flat, flat.nbytes)
    vol.SetDataScalarType(dataTypeVTK)
    vol.SetNumberOfScalarComponents(1)
    extent = vol.GetDataExtent()

    # Reversed NIfTI extent. The original author could not recall why, and it
    # may no longer be necessary. The x extent below is read from dim[2], so
    # dim is presumably ordered (z, y, x) for a 3-D image -- TODO confirm
    # against NiftiImage.extent.
    dim = list(nim.extent)
    dim.reverse()

    # Tell VTK how to unpack the 1 x n buffer; data and whole extents match.
    fullExtent = (extent[0], extent[0] + dim[2] - 1,
                  extent[2], extent[2] + dim[1] - 1,
                  extent[4], extent[4] + dim[0] - 1)
    vol.SetDataExtent(*fullExtent)
    vol.SetWholeExtent(*fullExtent)

    # Euclidean spacing between voxels. The x component is negated (original
    # behaviour, presumably to flip the x axis for display -- TODO confirm).
    spacing = (-nim.pixdim[0], nim.pixdim[1], nim.pixdim[2])
    vol.SetDataSpacing(*spacing)

    # The sform rows are returned for use by the caller.
    sform = nim.getSForm()
    srow_x, srow_y, srow_z = sform[0], sform[1], sform[2]

    # TODO: set the data origin (needs a fix for FreeSurfer). Earlier attempts,
    # kept only as hints:
    #   vol.SetDataOrigin(ar.fov[0]*ar.pixdim[0], o[1], o[2])
    #   vol.SetDataOrigin(srow_x[3], srow_y[3], srow_z[3])

    vol.Update()
    bounds = vol.GetOutput().GetBounds()

    # On some platforms the x pair comes back as (max, min), plausibly because
    # of the negative x spacing; normalise it to (min, max).
    if bounds[1] < bounds[0]:
        bounds = (bounds[1], bounds[0]) + tuple(bounds[2:6])

    # Default plane positions: the middle of each of the first three dims.
    # NOTE(review): x_slice_pos comes from dim[0] while the x extent above
    # uses dim[2]; kept unchanged because callers may rely on it -- verify.
    x_slice_pos, y_slice_pos, z_slice_pos = [int(d / 2) for d in dim[:3]]

    return (vol, dMax, dMin, bounds, x_slice_pos, y_slice_pos, z_slice_pos,
            srow_x, srow_y, srow_z, spacing)


def determineRequiredVTKDataType(input_dtype):
    """Look up the VTK scalar type and bytes-per-sample for a numpy dtype.

    Parameters
    ----------
    input_dtype : str or numpy.dtype
        Name of the numpy dtype, e.g. ``'int16'``. Non-string values are
        converted with ``str()``, so a ``numpy.dtype`` is accepted as well.

    Returns
    -------
    list
        ``[vtk_scalar_type, bytes_per_sample]``.

    Raises
    ------
    KeyError
        If the dtype has no VTK equivalent in the lookup table.
    """
    # column 1 is the VTK scalar type constant
    # column 2 is bytes_per_sample
    vtkDtypeDict = {
        'uint8': [vtk.VTK_UNSIGNED_CHAR, 1],
        'uint16': [vtk.VTK_UNSIGNED_SHORT, 2],
        # VTK_UNSIGNED_INT rather than VTK_UNSIGNED_LONG: C `long` is 8 bytes
        # on LP64 platforms (64-bit Linux/macOS), which would make VTK read
        # 8-byte samples out of a 4-byte-per-sample buffer.
        'uint32': [vtk.VTK_UNSIGNED_INT, 4],
        'uint64': [vtk.VTK_UNSIGNED_LONG_LONG, 8],
        # VTK_SIGNED_CHAR: the signedness of plain `char` (VTK_CHAR) is
        # platform dependent, while int8 is always signed.
        'int8': [vtk.VTK_SIGNED_CHAR, 1],
        'int16': [vtk.VTK_SHORT, 2],
        'int32': [vtk.VTK_INT, 4],
        # numpy's default integer on 64-bit platforms.
        'int64': [vtk.VTK_LONG_LONG, 8],
        # Legacy key: str(numpy dtype) never yields 'int'; kept for callers
        # that pass it explicitly.
        'int': [vtk.VTK_LONG, 4],
        'float32': [vtk.VTK_FLOAT, 4],
        'float64': [vtk.VTK_DOUBLE, 8],
        # NOTE(review): a complex element is a (real, imag) pair. Mapping it to
        # a single real scalar reinterprets the bits rather than converting
        # them. 'complex32' is not a numpy dtype name. Both entries are kept
        # unchanged for compatibility.
        'complex32': [vtk.VTK_FLOAT, 4],
        'complex64': [vtk.VTK_DOUBLE, 8],
    }
    key = str(input_dtype)
    try:
        return vtkDtypeDict[key]
    except KeyError:
        raise KeyError("no VTK scalar type known for numpy dtype %r" % key)

