from dv3dAnalyzeReader import *
import gzip, sys
from struct import unpack, pack, calcsize
from numpy import *
import vtk


def _read_analyze_voxels(img_path, dtype_code):
    """Read the raw voxel array of an Analyze ``.img`` file.

    :param img_path: full path to the ``.img`` file.
    :param dtype_code: Analyze datatype code taken from the header.
    :returns: ``(data, vtk_scalar_type)`` where ``data`` is a flat numpy
        array and ``vtk_scalar_type`` is the matching ``vtk.VTK_*`` constant.
    :raises ValueError: if ``dtype_code`` is not a supported datatype.
    """
    # Analyze datatype code -> (numpy dtype, VTK scalar type). The numpy and
    # VTK types are the same C type, so the buffer size VTK is told about
    # always matches the array that was actually read.
    type_map = {
        2: ('B', vtk.VTK_UNSIGNED_CHAR),   # unsigned char, 1 byte
        4: ('h', vtk.VTK_SHORT),           # signed short, 2 bytes
        8: ('i', vtk.VTK_INT),             # signed int, 4 bytes
        16: ('f', vtk.VTK_FLOAT),          # float, 4 bytes
        64: ('d', vtk.VTK_DOUBLE),         # double, 8 bytes
    }
    try:
        numpy_type, vtk_type = type_map[dtype_code]
    except KeyError:
        raise ValueError('Unsupported Analyze datatype code: %r' % (dtype_code,))

    # NOTE(review): samples are read in the machine's native byte order; files
    # written with the opposite endianness are not byte-swapped here.
    with open(img_path, 'rb') as img_file:
        data = fromfile(img_file, dtype=numpy_type)
    return data, vtk_type


def analyze2vtkImageData(filename):
    """Read in an image in analyze format and return a vtkImageData source.

    :param filename: path to the Analyze image, with or without a trailing
        ``.hdr`` / ``.img`` extension. The header is parsed by
        ``AnalyzeReader`` and the voxels are read from ``<base>.img``.
    :returns: an 11-tuple
        ``(vol, scalar_max, scalar_min, bounds,
        x_slice_pos, y_slice_pos, z_slice_pos, 0, 0, 0, 0)`` where

        - ``vol`` is the updated ``vtk.vtkImageImport`` holding a copy of the
          voxels;
        - ``scalar_max`` / ``scalar_min`` come from the output's scalar range
          (note: maximum first);
        - ``bounds`` is the output's 6-element bounds with the x pair ordered
          so the smaller value comes first;
        - the slice positions are the integer midpoints of each dimension;
        - the trailing four zeros are kept for compatibility with existing
          callers (their meaning is defined by the caller, not here).
    :raises ValueError: if the header's datatype code is not supported.
    """
    # Accept either component file name; both share the same base name.
    if filename.endswith(('.hdr', '.img')):
        filename = filename[:-4]

    ar = AnalyzeReader(filename)
    # Spacing is used as-is (not abs'd), so negative pixdim values pass to VTK.
    spacing = ar.pixdim
    nx, ny, nz = ar.dim[0], ar.dim[1], ar.dim[2]

    data, data_type = _read_analyze_voxels(filename + '.img', ar.dtype[0])

    vol = vtk.vtkImageImport()
    # The byte count comes from the array itself, so it can never disagree
    # with the dtype that was actually read.
    vol.CopyImportVoidPointer(data, data.nbytes)
    vol.SetDataScalarType(data_type)
    vol.SetNumberOfScalarComponents(1)

    # Build the extent from the importer's current extent origin and
    # the header dimensions.
    extent = vol.GetDataExtent()
    full_extent = (extent[0], extent[0] + nx - 1,
                   extent[2], extent[2] + ny - 1,
                   extent[4], extent[4] + nz - 1)
    vol.SetDataExtent(*full_extent)
    vol.SetWholeExtent(*full_extent)

    vol.SetDataSpacing(spacing[0], spacing[1], spacing[2])
    vol.Update()

    bounds = vol.GetOutput().GetBounds()
    # The original author observed the x pair coming back reversed on some
    # platforms (possibly a negative pixdim[0] spacing; TODO confirm), so
    # normalise it to (min, max).
    if bounds[1] < bounds[0]:
        bounds = (bounds[1], bounds[0]) + tuple(bounds[2:])

    # Default plane positions: the middle of the dataset along each axis.
    x_slice_pos = int(nx / 2)
    y_slice_pos = int(ny / 2)
    z_slice_pos = int(nz / 2)

    scalar_min, scalar_max = vol.GetOutput().GetScalarRange()

    return (vol, scalar_max, scalar_min, bounds,
            x_slice_pos, y_slice_pos, z_slice_pos, 0, 0, 0, 0)