
import gzip, sys
from struct import unpack, pack, calcsize
from numpy import *

class nifti_reader:
    """Minimal reader/writer for single-file gzipped NIfTI-1 images (.nii.gz).

    Only the header fields this class needs are parsed, and only a subset of
    voxel data types is supported (see get_image_data).  The gzip handle
    opened by open() stays open so get_image_data() can read the voxels
    later; call close() when finished with the file.
    """

    # NIfTI datatype codes; only a subset of these are supported
    # (DT_NONE and DT_UNKNOWN intentionally share code 0).
    DT_NONE = 0
    DT_UNKNOWN = 0
    DT_BINARY = 1
    DT_UNSIGNED_CHAR = 2
    DT_SIGNED_SHORT = 4
    DT_SIGNED_INT = 8
    DT_FLOAT = 16
    DT_COMPLEX = 32
    DT_DOUBLE = 64
    DT_RGB = 128
    DT_ALL = 255

    # Required sizeof_hdr value; also used to detect the file's byte order.
    _NIFTI1_HDR_SIZE = 348

    # Supported datatype codes -> numpy element types used to decode voxels.
    _NUMPY_TYPES = {
        DT_UNSIGNED_CHAR: uint8,
        DT_SIGNED_SHORT: int16,
        DT_SIGNED_INT: int32,
        DT_FLOAT: float32,
        DT_DOUBLE: float64,
    }

    def __init__(self, filename=None):
        """Create a reader, opening *filename* immediately if one is given."""
        self.fid = None
        if filename is not None:
            self.open(filename)

    @staticmethod
    def _nii_gz_name(filename):
        """Return *filename* with '.nii.gz' appended unless it already ends so."""
        if not filename.endswith('.nii.gz'):
            filename += '.nii.gz'
        return filename

    def __openfile(self, filename):
        """Open a .nii.gz file and detect its byte order.

        Sets self._EndianChar ('>' or '<') and self.sizeof_hdr, and returns the
        gzip file object rewound to the start of the file.

        Raises IOError if neither byte order yields the expected header size;
        the file is closed before the error propagates.
        """
        fid = gzip.open(self._nii_gz_name(filename), 'rb')
        try:
            raw = fid.read(calcsize('>i'))
            for endian in ('>', '<'):
                if (len(raw) == calcsize('>i') and
                        unpack(endian + 'i', raw)[0] == self._NIFTI1_HDR_SIZE):
                    self._EndianChar = endian
                    self.sizeof_hdr = self._NIFTI1_HDR_SIZE
                    break
            else:
                raise IOError('Unrecognised file type or endianess.')
            fid.seek(0)
        except BaseException:
            # Don't leak the handle when the file is rejected.
            fid.close()
            raise
        return fid

    def close(self):
        """Close the underlying gzip file, if any. Safe to call repeatedly."""
        fid = getattr(self, 'fid', None)
        if fid is not None:
            fid.close()
        self.fid = None

    def open(self, filename):
        """Open *filename* and parse the header fields used by this class.

        Sets dim (three spatial sizes), dtype (datatype code), pixdim,
        vox_offset, cal_max/cal_min, fov, hdr (raw header bytes), hdr_ext
        (raw bytes between header and voxel data) and srow_x/y/z.
        Any previously opened file is closed first.
        """
        self.close()
        self.fid = self.__openfile(filename)
        self.hdr = self.fid.read(self.sizeof_hdr)
        endian = self._EndianChar

        # dim[1..4] start at byte 42; keep only the three spatial sizes.
        self.dim = unpack(endian + 'hhhh', self.hdr[42:50])[0:3]
        self.dtype = unpack(endian + 'H', self.hdr[70:72])[0]
        # voxel sizes pixdim[1..3]
        self.pixdim = unpack(endian + 'fff', self.hdr[80:92])
        self.vox_offset = int(unpack(endian + 'f', self.hdr[108:112])[0])
        self.cal_max, self.cal_min = unpack(endian + 'ff', self.hdr[124:132])

        # field-of-view size along each spatial axis
        self.fov = [round(size * count)
                    for size, count in zip(self.pixdim, self.dim)]

        # Header extension: everything between the header and the voxel data.
        # Clamp so a bogus vox_offset cannot trigger read(-n), i.e. read-to-EOF.
        self.fid.seek(self.sizeof_hdr)
        self.hdr_ext = self.fid.read(max(self.vox_offset - self.sizeof_hdr, 0))

        # s-form matrix rows, used for comparing / translating between two
        # datasets (AG 2007 - YNIC_DV3D)
        self.srow_x = unpack(endian + 'ffff', self.hdr[280:296])
        self.srow_y = unpack(endian + 'ffff', self.hdr[296:312])
        self.srow_z = unpack(endian + 'ffff', self.hdr[312:328])

    def __flip_data(self, data):
        """Return a 1-d copy of *data* mirrored left-right.

        The flat array is viewed as (nz, ny, nx) in C order, so the x axis
        (dim[0]) is the last, fastest-varying one; that axis is reversed.
        """
        nx, ny, nz = self.dim
        volume = reshape(data, (nz, ny, nx))
        return reshape(volume[:, :, ::-1], (-1,))

    def get_image_data(self, lrflip=False):
        """Return the voxel data as a writable 1-d array in native byte order.

        Parameters:
            lrflip: if true, mirror the data along the x axis.

        Raises:
            IOError: no file is open.
            TypeError: the datatype code is unsupported, or the number of
                voxels does not match the three spatial dimensions (e.g. a
                time series).
        """
        if self.fid is None:
            raise IOError('File not opened.')

        try:
            element_type = self._NUMPY_TYPES[self.dtype]
        except KeyError:
            raise TypeError('Unsupported NIfTI data type code %d.' % self.dtype)

        self.fid.seek(self.vox_offset)   # absolute (from start of file)
        raw = frombuffer(self.fid.read(), dtype=element_type)

        # Convert from the file's byte order to the host's. frombuffer returns a
        # read-only view, so both branches produce an owned, writable copy.
        file_order = 'big' if self._EndianChar == '>' else 'little'
        if file_order != sys.byteorder:
            data = raw.byteswap()
        else:
            data = raw.copy()

        try:
            reshape(data, self.dim)     # don't actually keep this, just used for check
        except ValueError:
            raise TypeError('Wrong data type for file! (possible internal error) ... make sure you are not trying to load a timeseries dataset!')

        if lrflip:
            data = self.__flip_data(data)
        return data

    def duplicate(self, new_filename, imdata, dtype=None, lrflip=False):
        """Write a new .nii.gz file that reuses this file's header and extension.

        Parameters:
            new_filename: output path; '.nii.gz' is appended if missing.
            imdata: 1-d array of voxel values (in native byte order).
            dtype: pass DT_FLOAT to store the data as float32; the header's
                datatype and bitpix are updated and imdata is cast to match.
                Otherwise imdata is written with its own element type and the
                original datatype field is kept.
            lrflip: if true, mirror the data along the x axis before writing.

        cal_max/cal_min are set to the data's max/min.  Voxels are written in
        the header's byte order so the output is readable on any host.
        """
        if dtype == nifti_reader.DT_FLOAT:
            imdata = asarray(imdata, dtype=float32)
        if lrflip:
            imdata = self.__flip_data(imdata)

        endian = self._EndianChar
        new_hdr = self.hdr
        if dtype == nifti_reader.DT_FLOAT:
            # datatype (70:72) and bitpix (72:74) must agree.
            new_hdr = (new_hdr[0:70] +
                       pack(endian + 'Hh', nifti_reader.DT_FLOAT, 32) +
                       new_hdr[74:])
        cal_max = float(imdata.max())
        cal_min = float(imdata.min())
        new_hdr = new_hdr[0:124] + pack(endian + 'ff', cal_max, cal_min) + new_hdr[132:]

        # Store voxels in the header's byte order, not the host's.
        file_data = imdata.astype(imdata.dtype.newbyteorder(endian))

        ofid = gzip.open(self._nii_gz_name(new_filename), 'wb')
        try:
            ofid.write(new_hdr)
            ofid.write(self.hdr_ext)
            ofid.write(file_data.tobytes())
        finally:
            ofid.close()

        
