import sys,os
from math import floor
from numpy import *

from scipy.signal import lfilter, butter
from scipy import flipud



def _filtfilt(b, a, input_vector):
    """Run the filter (b, a) over the data forwards and then backwards.

    input_vector has shape (n,num_channels); filtering is along axis 0.
    The second pass is applied to the time-reversed output of the first,
    and the result is flipped back into the original sample order.
    Written by Andrew Straw (http://article.gmane.org/gmane.comp.python.scientific.user/1164/)
    """
    # first pass, natural sample order
    forward_pass = lfilter(b, a, input_vector, axis=0)

    # second pass, over the reversed signal
    backward_pass = lfilter(b, a, flipud(forward_pass), axis=0)

    # undo the reversal so the output lines up with the input
    return flipud(backward_pass)




class _MegError(Exception):
    """Class for reporting general errors while parsing MEG data.

    The message (or other payload) passed to the constructor is kept in
    the ``value`` attribute; ``str()`` of the exception is ``repr(value)``.
    """

    def __init__(self, value):
        # Pass the value on to Exception too, so that ``args`` is populated
        # (needed by pickling and by code that inspects ``e.args``).
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return repr(self.value)

    
def _ParseArg(lst):
    """Normalise a trigger/group/epoch argument into a list.

    Accepted forms:
      ''        -> []  (callers treat an empty list as a wildcard)
      'a-b'     -> [a, a+1, ..., b]  (inclusive; requires a < b)
      integer   -> [integer]
      otherwise -> returned unchanged (assumed to already be a sequence)

    Raises _MegError for a string that is not a valid 'a-b' range.
    A range whose parts are not integers raises ValueError from int().
    """
    if isinstance(lst, str):

        if lst == '':
            return []

        # attempt to convert a string range to a list
        temp = lst.split('-')
        if (len(temp) == 2) and (int(temp[0]) < int(temp[1])):
            # list() so that a real list is returned on Python 3 as well,
            # matching the other branches
            return list(range(int(temp[0]), int(temp[1]) + 1))
        raise _MegError('Invalid string range passed: ' + lst)

    # a single integer (bool is excluded: it was never treated as one)
    if isinstance(lst, int) and not isinstance(lst, bool):
        return [lst]               # turn it into a list

    return lst



def _ParseLists(list1, list2, total_epochs):
    """Return the epoch numbers whose code appears in list1.

    list2 is a sequence of (epoch_number, code) pairs. For each wanted code
    in list1 (in order) the epoch numbers of all matching pairs are emitted
    in list2 order. An empty list1 acts as a wildcard and selects every
    epoch from 1 to total_epochs inclusive.
    """
    if not list1:
        # wildcard: all possible epoch numbers
        return range(1, total_epochs + 1)

    return [pair[0] for wanted in list1 for pair in list2 if wanted == pair[1]]


def MegFilter(data, sample_rate, lowf, highf, order=4):
    """Apply a Butterworth bandpass filter to data. (Implementation does a low pass and high pass)

    data        -- array filtered along axis 0 (samples by channels)
    sample_rate -- sampling frequency in Hz
    lowf, highf -- pass band edges in Hz, 0 <= lowf <= highf <= sample_rate/2
    order       -- order of each Butterworth section (default 4)

    Each section is run forwards and backwards via _filtfilt(). A band edge
    that sits on the end of the spectrum (lowf == 0 or highf == Nyquist)
    needs no filtering, so that pass is skipped; if both are skipped the
    input is returned unchanged.

    Raises _MegError if the frequency range is invalid.

    NB Ensure that you have removed the dc offset from your data before filtering to avoid large end-effects.
    """
    nyquist = sample_rate / 2.0

    if (lowf < 0) or (highf > nyquist) or (lowf > highf):
        raise _MegError('Invalid frequency range for filter: %g - %g Hz' % (lowf, highf))

    filt_data = data

    # low-pass at the upper band edge (butter() needs 0 < Wn < 1, so a cutoff
    # at Nyquist cannot be designed - and would remove nothing anyway)
    if highf < nyquist:
        (butter_low_b, butter_low_a) = butter(order, highf / nyquist, btype='low')
        filt_data = _filtfilt(butter_low_b, butter_low_a, filt_data)

    # high-pass at the lower band edge (same reasoning for a cutoff at 0 Hz)
    if lowf > 0:
        (butter_hi_b, butter_hi_a) = butter(order, lowf / nyquist, btype='high')
        filt_data = _filtfilt(butter_hi_b, butter_hi_a, filt_data)

    return filt_data



def RemoveDC(data, num_points):
    """Subtract a dc estimate from data and return the result.

    The dc level is the mean over the first num_points points along axis 0
    (one value per column when data has more than one dimension).
    """
    if data.ndim > 1:
        baseline = data[:num_points, :]
    else:
        baseline = data[:num_points]
    return data - mean(baseline, 0)
    

class MegReader:
    """Class to allow import of 4D MEG data via a .m4d file.

    The text header <filename>.m4d describes the raw binary file <filename>
    (channel order, scaling, events). Open() parses the header and opens the
    raw file; the GetEpoch_* / GetSliceRange_* methods then return float32
    arrays scaled by units-per-bit * scale / gain. Epoch numbers are 1-based.

    NB raw samples are currently always read as 16 bit shorts, whatever
    MSI.Format declares (see the !!TODO notes in the read methods).
    """

    MAX_NUM_MEG_CHANNELS = 248    # The number of MEG channels when they are all working
    MAX_NUM_EEG_CHANNELS = 96  #?  # The number of EEG channels when they are all working
    MAX_NUM_REF_CHANNELS = 23     # The number of REF channels when they are all working
    SIZE_OF_4D_SHORT = 2      # bytes per sample: 16 bit data acquisition
    SIZE_OF_4D_FLOAT = 4      # bytes per sample for FLOAT format files
    SHORT_TYPE = 0
    FLOAT_TYPE = 1

    # Reference channel names; a channel's position here is its column in REF data
    REF_CHANNELS = ['GxxA','GyxA','GyyA','GzxA','GzyA','MCxA','MCxaA','MCyA','MCyaA','MCzA','MCzaA',    \
                    'MLxA','MLxaA','MLyA','MLyaA','MLzA','MLzaA','MRxA','MRxaA','MRyA','MRyaA','MRzA','MRzaA']


    def __init__(self, filename = None, pretrig_dur = None, epoch_dur = None, post_trigger = False, slice_file = ''):
        """Create a reader; if filename is given, Open() it with the remaining arguments."""
        if filename:
            self.Open(filename, pretrig_dur, epoch_dur, post_trigger, slice_file)


    def Close(self):
        """Close the raw data file if one is open. Safe to call more than once."""
        data_in = getattr(self, 'data_in', None)
        if data_in is not None:
            data_in.close()
            self.data_in = None


    def __ParseEvents(self, ev_loc_str, ev_code_str):
        """Private method to zip together event locations and values.

        Both arguments are comma separated strings of integers. Returns a
        list of (location, code) tuples, one per location.
        """
        evi = [int(x) for x in ev_loc_str.split(',')]
        evc = [int(x) for x in ev_code_str.split(',')]
        return [(loc, evc[i]) for i, loc in enumerate(evi)]


    def __FindEpoch(self, indx, post_trigger=False):
        """Find which (1-based) epoch the indx time slice belongs to.

        If post_trigger is True the event belongs to the closest trigger at
        or before indx; otherwise it belongs to the closest trigger at or
        after indx. Returns -1 if there is no such trigger.
        """
        if post_trigger:
            if self.trig_events[0][0] - indx > 0:
                return -1               # nothing precedes indx: not valid in post_trigger
            for i in range(1, self.trig_event_count):
                if self.trig_events[i][0] - indx > 0:
                    return i           # trigger i-1 precedes indx; epochs are 1-based so that is epoch i
            return self.trig_event_count

        if self.trig_events[-1][0] - indx < 0:
            return -1                   # nothing follows indx: not valid when not in post_trigger
        for i in range(self.trig_event_count-2, -1, -1):
            if self.trig_events[i][0] - indx < 0:
                return i+2              # trigger i+1 follows indx; epochs are 1-based
        return 1


    def __GetEpoch(self, epoch_num):
        """Read an epoch of raw data from the file into self.unordered_data.

        The result has shape (slices_per_epoch, total_channels), with columns
        in file order. Raises _MegError if epoch_num is out of range or, in
        continuous mode, if the epoch would start before the file start.
        """
        if (epoch_num < 1) or (epoch_num > self.num_epochs):
            raise _MegError('SetEpoch() requested out of range 1-%d' % self.num_epochs)

        # seek to start of requested epoch
        if self.continuous_mode:
            # epochs are cut out around each trigger, pre_trig slices before it
            pos = (self.trig_events[epoch_num-1][0] - self.pre_trig) * self.total_channels * self.size_of_format
            if pos < 0:
                raise _MegError('Tried to seek before file start in SetEpoch()')
            self.data_in.seek(pos)
        else:
            # epochs are stored back to back
            self.data_in.seek((epoch_num - 1) * self.num_bytes_per_epoch)

        # load the data
        # !!TODO take care of different data types (float32)
        raw = fromfile(self.data_in, dtype=short, count=self.slices_per_epoch*self.total_channels)
        self.unordered_data = reshape(raw, [self.slices_per_epoch, self.total_channels])


    def __GetSliceRange(self, startp, endp):
        """Read slices startp..endp-1 of raw data from the file into self.unordered_data.

        Raises _MegError unless 0 <= startp < endp <= total_slices.
        """
        if (startp < 0) or (endp > self.total_slices) or (startp >= endp):
            raise _MegError('Invalid start and end points for GetSliceRange(): %d, %d' % (startp, endp))

        self.data_in.seek(self.num_bytes_per_slice * startp)

        # load the data
        # !!TODO take care of different data types (float32)
        num_slices = endp - startp
        raw = fromfile(self.data_in, dtype=short, count=num_slices*self.total_channels)
        self.unordered_data = reshape(raw, [num_slices, self.total_channels])

        self.last_epoch_read = -1   # force a reread of the data on next GetEpoch


    def __rdc(self, data):
        """Remove dc according to the current value of self.dcremove.

        A false value leaves data untouched; 1 uses the pre-trigger points;
        anything else uses the whole epoch.
        NOTE(review): self.dcremove is not assigned anywhere in this class,
        so it has to be set on the instance before this is used.
        """
        if not self.dcremove:
            return data
        if self.dcremove == 1:
            num_points = self.pre_trig
        else:
            num_points = self.slices_per_epoch

        return RemoveDC(data, num_points)


    def __GetSliceRange_XXX(self, startp, endp, chan_col, upbsg, transp):
        """Read a slice range and return the scaled channels listed in chan_col.

        chan_col holds the 1-based file column of each output channel (-1 for
        a channel that is not present, which is left as zeros) and upbsg the
        matching scale factors. The result is (slices, channels), or
        (channels, slices) when transp is true.
        """
        self.__GetSliceRange(startp, endp)

        num_slices = endp - startp

        # get the required columns, in order, in a contiguous block of memory
        if transp:
            data = zeros([len(chan_col), num_slices], float32)
            for i, col in enumerate(chan_col):
                if col > -1:
                    data[i,:] = self.unordered_data[:, col-1] * upbsg[i]
        else:
            data = zeros([num_slices, len(chan_col)], float32)
            for i, col in enumerate(chan_col):
                if col > -1:
                    data[:,i] = self.unordered_data[:, col-1] * upbsg[i]

        # might want to set self.unordered_data=None here to free up some memory??

        return data


    def GetSliceRange_MEG(self, startp, endp, transp=False):
        """Return a contiguous array of MEG data for the requested slice range."""

        return self.__GetSliceRange_XXX( startp, endp, self.meg_chan_col, self.meg_upbsg, transp)


    def GetSliceRange_EEG(self, startp, endp, transp=False):
        """Return a contiguous array of EEG data for the requested slice range."""

        return self.__GetSliceRange_XXX( startp, endp, self.eeg_chan_col, self.eeg_upbsg, transp)


    def GetSliceRange_REF(self, startp, endp, transp=False):
        """Return a contiguous array of REF data for the requested slice range."""

        return self.__GetSliceRange_XXX( startp, endp, self.ref_chan_col, self.ref_upbsg, transp)


    def __GetTrigRespChannels(self):
        """Read in entire data set in chunks, and extract the trigger and response channels for the whole run."""

        self.data_in.seek(0)

        self.trigger_channel_data = empty( (self.total_slices), short)
        self.response_channel_data = empty( (self.total_slices), short)

        # load the data
        # !!TODO take care of different data types (float32)
        max_chunk = 10000
        for startp in range(0, self.total_slices, max_chunk):
            # last chunk will (almost always) be smaller.
            # NB deliberately not min(): "from numpy import *" may shadow the builtin.
            remaining = self.total_slices - startp
            if remaining < max_chunk:
                num_slices = remaining
            else:
                num_slices = max_chunk

            raw = fromfile(self.data_in, dtype=short, count=num_slices*self.total_channels)
            data = reshape(raw, [num_slices, self.total_channels])
            self.trigger_channel_data[startp:startp+num_slices] = data[:, self.trigger_column-1]
            self.response_channel_data[startp:startp+num_slices] = data[:, self.response_column-1]


    def GetTriggerChannel(self):
        """Return an array of trigger channel data for the entire run (read once, then cached)."""

        # "is None", not "== None": once loaded this is an array, and an
        # elementwise comparison has no single truth value
        if self.trigger_channel_data is None:
            self.__GetTrigRespChannels()

        return self.trigger_channel_data


    def GetResponseChannel(self):
        """Return an array of response channel data for the entire run (read once, then cached)."""

        if self.response_channel_data is None:
            self.__GetTrigRespChannels()

        return self.response_channel_data


    def __GetEpoch_XXX(self, epoch_num, chan_col, upbsg):
        """Return the scaled channels listed in chan_col for one epoch.

        The raw epoch is only re-read from file when epoch_num differs from
        the last epoch read. Channels whose column is -1 are left as zeros.
        """
        if epoch_num != self.last_epoch_read:
            self.__GetEpoch(epoch_num)
            self.last_epoch_read = epoch_num

        # get the required columns, in order, in a contiguous block of memory
        data = zeros([self.slices_per_epoch, len(chan_col)], float32)

        for i, col in enumerate(chan_col):
            if col > -1:
                data[:,i] = self.unordered_data[:, col-1] * upbsg[i]

        return data


    def GetEpoch_MEG(self, epoch_num):
        """Return an array of MEG data for the requested epoch, with columns ordered A1-A248.

        The data is scaled but not filtered; columns of channels that are
        not present in the file are zero.
        """
        return self.__GetEpoch_XXX(epoch_num, self.meg_chan_col, self.meg_upbsg)


    def GetEpoch_EEG(self, epoch_num):
        """Return an array of EEG data for the requested epoch, with columns ordered E1-E96.

        The data is scaled but not filtered; columns of channels that are
        not present in the file are zero.
        """
        return self.__GetEpoch_XXX(epoch_num, self.eeg_chan_col, self.eeg_upbsg)


    def GetEpoch_REF(self, epoch_num):
        """Return an array of REF data for the requested epoch, with columns ordered GxxA-MRzaA
        - MegReader.REF_CHANNELS has the full list.

        The data is scaled but not filtered; columns of channels that are
        not present in the file are zero.
        """
        return self.__GetEpoch_XXX(epoch_num, self.ref_chan_col, self.ref_upbsg)


    def GetEpochList(self, trig_list, group_list, reject_list = None):
        """Return a sorted list of epochs which have trigger code in trig_list and group code(s) in group_list,
        and which are not in the (optional) reject_list.

        Each list can be a normal python list of integers, a single integer,
        or a string containing a range, e.g. '3-7', which means all values in
        the range inclusive of end values. An empty trig_list or group_list
        ('' or []) is a wildcard matching every epoch.
        """
        trig_list = _ParseArg(trig_list)
        group_list = _ParseArg(group_list)
        if reject_list is None:
            rejected = set()
        else:
            rejected = set(_ParseArg(reject_list))

        # get lists of epochs which satisfy the trig_list and group_list requirements separately
        t_list = _ParseLists(trig_list, self.TriggerCodes, self.num_epochs)
        g_list = _ParseLists(group_list, self.GroupCodes, self.num_epochs)

        # only want epochs that are in both, and not in the reject list
        in_group = set(g_list)
        return sorted(x for x in t_list if (x in in_group) and (x not in rejected))


    def Open(self, filename, pretrig_dur = None, epoch_dur = None, post_trigger = False, slice_file = ''):
        """Open and parse the .m4d file associated with the raw data file <filename>.

        filename     -- raw data file; the header is read from filename + '.m4d'
        pretrig_dur  -- pre-trigger duration in ms (only used when epoch_dur is given)
        epoch_dur    -- total epoch length in ms. If given, epochs are cut out
                        around each trigger (continuous mode); otherwise the
                        epoch layout stored in the file is used.
        post_trigger -- use post_trigger=True if you have group codes which are
                        presented after a trigger rather than before it (default
                        is post_trigger=False, which assumes that group codes
                        are presented before triggers)
        slice_file   -- if not '', the checks that epochs do not overlap and
                        that the pre-trigger period fits before the first
                        trigger are skipped

        Raises IOError if either file cannot be opened, TypeError for an
        unknown MSI.Format, and _MegError for inconsistent epoch settings.
        """
        # release the raw file of any previous Open() on this instance
        self.Close()

        self.data_filename = filename
        self.m4d_filename = filename + '.m4d'

        try:
            f = open(self.m4d_filename, 'r')
        except (IOError, OSError):
            raise IOError('Cannot open associated .m4d file: ' + self.m4d_filename)
        try:
            d = f.readlines()
        finally:
            f.close()

        self.total_events = 0
        self.trig_event_count = 0
        self.group_event_count = 0
        self.resp_event_count = 0
        self.latency = None     # stays None if the header has no MSI.FirstLatency line

        meg_chan_names_str = meg_chan_index_str = None
        eeg_chan_names_str = eeg_chan_index_str = None
        ref_chan_names_str = ref_chan_index_str = None

        self.trigger_channel_data = None
        self.response_channel_data = None

        # each header line has the form "<label>: <info>"
        for x in d:

            sp = x.find(':')
            label = x[:sp]
            info = x[sp+1:].strip()

            if label == "MSI.Meg_Position_Information.Begin":
                #self.ParseCoilInfo(&fin)
                pass
            elif label == "MSI.TotalEpochs":
                self.num_epochs = int(info)
            elif label == "MSI.SampleFrequency":
                self.sample_rate = float(info)
            elif label == "MSI.FirstLatency":
                self.latency = float(info)
            elif label == "MSI.SlicesPerEpoch":
                self.slices_per_epoch = int(info)
            elif label == "MSI.TotalChannels":
                self.total_channels = int(info)
            elif label == "MSI.ChannelOrder":
                chan_names_str = info
            elif label == "MSI.ChannelScale":
                chan_scale_str = info
            elif label == "MSI.ChannelGain":
                chan_gain_str = info
            elif label == "MSI.ChannelUnitsPerBit":
                chan_upb_str = info
            elif label == "MSI.MegChanCount":
                self.meg_chan_count = int(info)
            elif label == "MSI.MegChanNames":
                meg_chan_names_str = info
            elif label == "MSI.MegChanIndex":
                meg_chan_index_str = info
            elif label == "MSI.RefChanCount":
                self.ref_chan_count = int(info)
            elif label == "MSI.RefChanNames":
                ref_chan_names_str = info
            elif label == "MSI.RefChanIndex":
                ref_chan_index_str = info
            elif label == "MSI.EegChanCount":
                self.eeg_chan_count = int(info)
            elif label == "MSI.EegChanNames":
                eeg_chan_names_str = info
            elif label == "MSI.EegChanIndex":
                eeg_chan_index_str = info
            elif label == "MSI.TriggerIndex":
                self.trigger_column = int(info)
            elif label == "MSI.ResponseIndex":
                self.response_column = int(info)
            elif label == "MSI.TotalEvents":
                self.total_events = int(info)
            elif label == "MSI.Events":
                events_str = info
            elif label == "MSI.EventCodes":
                event_code_str = info
            elif label == "MSI.TrigEventCount":
                self.trig_event_count = int(info)
            elif label == "MSI.TrigEvents":
                trig_events_str = info
            elif label == "MSI.TrigEventCodes":
                trig_event_code_str = info
            elif label == "MSI.GroupEventCount":
                self.group_event_count = int(info)
            elif label == "MSI.GroupEvents":
                group_events_str = info
            elif label == "MSI.GroupEventCodes":
                group_event_code_str = info
            elif label == "MSI.RespEventCount":
                self.resp_event_count = int(info)
            elif label == "MSI.RespEvents":
                resp_events_str = info
            elif label == "MSI.RespEventCodes":
                resp_event_code_str = info
            elif label == "MSI.Format":
                if info == "SHORT":
                    self.format = MegReader.SHORT_TYPE
                    self.size_of_format = MegReader.SIZE_OF_4D_SHORT
                elif info == "FLOAT":
                    self.format = MegReader.FLOAT_TYPE
                    self.size_of_format = MegReader.SIZE_OF_4D_FLOAT
                else:
                    raise TypeError('Unknown data format in MSI.Format field of: ' + self.m4d_filename)

        # size of the file in slices; computed before slices_per_epoch can be overridden below
        self.total_slices  = self.slices_per_epoch * self.num_epochs

        if not epoch_dur:
            if self.trig_event_count != 0:
                # "is None" so that a genuine latency of 0.0 is accepted
                if self.latency is None:
                    raise _MegError('Please specify pretrig duration and total epoch length in ms')
                self.pre_trig = int(-self.latency * self.sample_rate)   #latency is in secs
            else:
                self.pre_trig = 0   #continuous mode, but no triggers, so treat as one epoch of data
            self.continuous_mode = 0      # assume epoch mode
        else:
            self.pre_trig = int(pretrig_dur * self.sample_rate / 1000)
            self.slices_per_epoch = int(epoch_dur * self.sample_rate / 1000)
            self.continuous_mode = 1

        self.num_bytes_per_slice = self.total_channels * self.size_of_format
        self.num_bytes_per_epoch = self.num_bytes_per_slice * self.slices_per_epoch

        #
        # Relate MEG channel names to columns in raw data file
        #

        # Initialise the channel column array for MEG channels
        self.meg_chan_col = [-1] * MegReader.MAX_NUM_MEG_CHANNELS     # -1 indicates the channel is not working

        # Parse the lists of MEG channel names and corresponding data columns
        if meg_chan_names_str:
            mcnl = [int(x.strip()[1:]) for x in meg_chan_names_str.split(',')]  # the working meg channel numbers
            mcil = [int(x) for x in meg_chan_index_str.split(',')]              # corresponding columns in data

            for i, chan_num in enumerate(mcnl):
                self.meg_chan_col[chan_num - 1] = mcil[i]       # A<n> names are 1-based

        #
        # Relate EEG channel names to columns in raw data file
        #

        # Initialise the channel column array for EEG channels
        self.eeg_chan_col = [-1] * MegReader.MAX_NUM_EEG_CHANNELS     # -1 indicates the channel is not working

        # Parse the lists of EEG channel names and corresponding data columns
        if eeg_chan_names_str:
            ecnl = [int(x.strip()[1:]) for x in eeg_chan_names_str.split(',')]  # the working eeg channel numbers
            ecil = [int(x) for x in eeg_chan_index_str.split(',')]              # corresponding columns in data

            for i, chan_num in enumerate(ecnl):
                self.eeg_chan_col[chan_num - 1] = ecil[i]       # E<n> names are 1-based

        #
        # Relate REF channel names to columns in raw data file
        #

        # first create a dictionary to enumerate the ref channel names since they have no natural ordering
        self.ref_dict = {}
        for i, name in enumerate(MegReader.REF_CHANNELS):
            self.ref_dict[name] = i

        # Initialise the channel column array for REF channels
        self.ref_chan_col = [-1] * MegReader.MAX_NUM_REF_CHANNELS     # -1 indicates the channel is not working

        # Parse the lists of REF channel names and corresponding data columns
        if ref_chan_names_str:
            rcnl = [self.ref_dict[x.strip()] for x in ref_chan_names_str.split(',')]    # the working ref channel numbers
            rcil = [int(x) for x in ref_chan_index_str.split(',')]                      # corresponding columns in data

            for i, chan_num in enumerate(rcnl):
                # NB ref_dict values are already zero-based (no "- 1" here), so
                # that ref_chan_col lines up with ref_upbsg and REF_CHANNELS
                self.ref_chan_col[chan_num] = rcil[i]

        # Initialise units-per-bit_times_scale_divided_by_gain
        self.meg_upbsg = [0] * MegReader.MAX_NUM_MEG_CHANNELS
        self.eeg_upbsg = [0] * MegReader.MAX_NUM_EEG_CHANNELS
        self.ref_upbsg = [0] * MegReader.MAX_NUM_REF_CHANNELS

        cnl = [x.strip() for x in chan_names_str.split(',')]  # all channel names, incl MEG, EEG, Ref, etc.
        csl = [float(x) for x in chan_scale_str.split()]   # corresponding channel scale
        cgl = [float(x) for x in chan_gain_str.split()]   # corresponding channel gain
        cul = [float(x) for x in chan_upb_str.split()]   # corresponding channel units-per-bit

        for i, name in enumerate(cnl):
            prefix = name[0]

            if prefix == 'A':               # an MEG channel
                channel_num = int(name[1:])
                self.meg_upbsg[channel_num - 1] = cul[i] * csl[i] / cgl[i]

            elif prefix == 'E':             # an EEG channel
                channel_num = int(name[1:])
                self.eeg_upbsg[channel_num - 1] = cul[i] * csl[i] / cgl[i]

            elif (prefix == 'M') or (prefix == 'G'):    # a REF channel
                channel_num = self.ref_dict[name]               # NB this one is zero-based
                self.ref_upbsg[channel_num] = cul[i] * csl[i] / cgl[i]

        # Parse events
        if self.total_events > 0:
            self.events = self.__ParseEvents(events_str, event_code_str)

        if self.trig_event_count > 0:
            self.trig_events = self.__ParseEvents(trig_events_str, trig_event_code_str)

            if slice_file == '':

                # Test to see if the epoch duration is short enough to only contain one trigger
                if (self.trig_event_count > 1) and (self.slices_per_epoch > self.trig_events[1][0] - self.trig_events[0][0]):
                    raise _MegError('Multiple triggers occur in specified epoch duration')

                # Test to see if the pretrigger duration will fit before first trigger in data
                if self.pre_trig > self.trig_events[0][0]:
                    raise _MegError('Pre-trigger duration too large for first trigger in data')

            if self.continuous_mode:
                self.num_epochs = self.trig_event_count

        if self.group_event_count > 0:
            self.group_events = self.__ParseEvents(group_events_str, group_event_code_str)

        if self.resp_event_count > 0:
            self.resp_events = self.__ParseEvents(resp_events_str, resp_event_code_str)

        # Check if acquisition was stopped before last epoch was finished...(!)
        if((self.trig_event_count > 0) and
                (self.trig_events[self.trig_event_count-1][0] - self.pre_trig + \
                                                self.slices_per_epoch-1  > self.total_slices)):
            # silently dump last epoch...
            self.trig_events = self.trig_events[:-1]
            self.trig_event_count -= 1
            if self.group_event_count > 0:
                self.group_events = self.group_events[:-1]
                self.group_event_count -= 1
            if self.resp_event_count > 0:
                self.resp_events = self.resp_events[:-1]
                self.resp_event_count -= 1
            # ...and keep the epoch count in step, so the dropped epoch
            # cannot be requested later
            if self.continuous_mode:
                self.num_epochs = self.trig_event_count

        # NB triggers define the epochs, first epoch is number 1 (not 0)
        self.TriggerCodes = [ (i+1, self.trig_events[i][1]) for i in range(self.trig_event_count)]

        self.GroupCodes = [ (self.__FindEpoch(self.group_events[i][0], post_trigger=post_trigger), self.group_events[i][1])  \
                                                            for i in range(self.group_event_count)]

        for x in self.GroupCodes:
            if x[0] == -1:
                # single parenthesised argument: prints identically on Python 2 and 3
                print('\nWARNING: failed to find associated epoch for at least one group code. This can be caused when the group value on the trigger line does not occur inside the epoch you have specified.\n')
                break

        # Response codes also carry the timing of the response (in ms, relative to the trigger point)
        self.ResponseCodes = []
        for i in range(self.resp_event_count):
            ep_num = self.__FindEpoch(self.resp_events[i][0], post_trigger=True)
            if ep_num == -1:
                continue
            self.ResponseCodes.append( (ep_num, self.resp_events[i][1], \
                    (self.resp_events[i][0] - self.trig_events[ep_num-1][0])*1000/self.sample_rate) )

        # Open the actual data file for reading
        try:
            self.data_in = open(self.data_filename, 'rb')
        except (IOError, OSError):
            raise IOError('Unable to open raw data file: ' + self.data_filename)

        # indicate that we have not yet read in any epoch data
        self.last_epoch_read = -1
        


if __name__ == '__main__':
    # Ad-hoc check against a locally mounted data set: cut epochs with a
    # 500 ms pre-trigger and 1500 ms total length, then list the epochs for
    # trigger codes 504/760 in any group, passing [760] as the reject list.
    m = MegReader('/mnt/megdata/zb101/1somato/04%05%06@12:15/2/c,rfhp1.0Hz', 500, 1500)
    #d=m.GetEpoch_MEG(1)
    # single parenthesised argument: prints identically on Python 2 and 3
    print(m.GetEpochList([504,760],'',[760]))



