from dv3dNiftiReader import *
import vtk
from numpy import *

def nifti2vtkImageData(filename):
    """Load a NIfTI image and wrap its voxel data in a ``vtkImageImport``.

    The header is parsed with ``nifti_reader``; the voxel buffer is copied
    into a ``vtk.vtkImageImport`` whose scalar type, extent and spacing are
    taken from that header.  The importer is updated before returning so
    that its output bounds can be queried.

    Args:
        filename: Path handed straight to ``nifti_reader`` (the original
            docstring says ``.nii.gz``).

    Returns:
        An 11-tuple, in this order:

        * ``vol`` -- the configured ``vtk.vtkImageImport`` (note: the
          importer itself, not its ``vtkImageData`` output).
        * ``data.max()`` and ``data.min()`` of the voxel buffer.
        * ``bounds`` -- six-element tuple from the importer's output, with
          the x pair reordered so that the smaller value comes first.
        * ``x_slice_pos, y_slice_pos, z_slice_pos`` -- half of each of
          ``ar.dim[0..2]``, truncated to ``int``.
        * ``srow_x, srow_y, srow_z`` -- the first four entries of the
          header's ``srow_*`` fields as tuples of ``float``.
        * ``spacing`` -- ``ar.pixdim`` exactly as the reader exposes it.
    """
    # Parse the header and fetch the voxel buffer.
    ar = nifti_reader(filename)
    spacing = ar.pixdim
    data = ar.get_image_data()

    # Map the NIfTI datatype onto a VTK scalar type and a per-sample size.
    if ar.dtype == nifti_reader.DT_SIGNED_SHORT:
        bytes_per_sample = 2
        # HACK (kept from the original): signed shorts are imported as
        # unsigned shorts rather than vtk.VTK_SHORT.
        data_type = vtk.VTK_UNSIGNED_SHORT
    elif ar.dtype == nifti_reader.DT_FLOAT:
        bytes_per_sample = 4
        data_type = vtk.VTK_FLOAT
    elif ar.dtype == nifti_reader.DT_SIGNED_INT:
        bytes_per_sample = 4
        data_type = vtk.VTK_INT
    else:
        # Anything else is treated as one unsigned byte per sample.
        bytes_per_sample = 1
        data_type = vtk.VTK_UNSIGNED_CHAR

    data_max = data.max()
    data_min = data.min()

    vol = vtk.vtkImageImport()
    # NOTE(review): this assumes len(data) is the total number of samples,
    # i.e. that get_image_data() returns a flat buffer -- TODO confirm.
    size = len(data) * bytes_per_sample
    vol.CopyImportVoidPointer(data, size)
    vol.SetDataScalarType(data_type)
    vol.SetNumberOfScalarComponents(1)
    extent = vol.GetDataExtent()

    # The original reversed ar.dim and indexed [2], [1], [0]; that is the
    # same as taking the last three entries of ar.dim in their own order.
    dims = list(ar.dim)
    x_last = extent[0] + dims[-3] - 1
    y_last = extent[2] + dims[-2] - 1
    z_last = extent[4] + dims[-1] - 1
    vol.SetDataExtent(extent[0], x_last, extent[2], y_last, extent[4], z_last)
    vol.SetWholeExtent(extent[0], x_last, extent[2], y_last, extent[4], z_last)

    # Changed for Kinect stuff: the x spacing is deliberately negated.
    vol.SetDataSpacing(-ar.pixdim[0], ar.pixdim[1], ar.pixdim[2])

    srow_x = tuple(float(ar.srow_x[i]) for i in range(4))
    srow_y = tuple(float(ar.srow_y[i]) for i in range(4))
    srow_z = tuple(float(ar.srow_z[i]) for i in range(4))

    # TODO: fix for freesurfer -- the data origin is never set here.  Open
    # question from the original: is an x offset of fov[0] * pixdim[0], or
    # an origin of (srow_x[3], srow_y[3], srow_z[3]), required?
    vol.Update()
    bounds = vol.GetOutput().GetBounds()
    # Pre-formatted so the line parses, and prints the same text, on both
    # Python 2 and Python 3.
    print('bounds %s' % (bounds,))

    # Hack fix: across platforms the x pair of the bounds can come back in
    # either order, so put the smaller value first.
    if bounds[1] < bounds[0]:
        bounds = (bounds[1], bounds[0],
                  bounds[2], bounds[3], bounds[4], bounds[5])

    # Default slice-plane positions: the middle of the dataset on each axis.
    x_slice_pos = int(ar.dim[0] / 2)
    y_slice_pos = int(ar.dim[1] / 2)
    z_slice_pos = int(ar.dim[2] / 2)

    return (vol, data_max, data_min, bounds,
            x_slice_pos, y_slice_pos, z_slice_pos,
            srow_x, srow_y, srow_z, spacing)

