from numpy import *
import scipy.io

class MatReader:
    """Reader for MEG data stored in a MATLAB .mat file.

    The file must contain three variables:

    * ``megdata`` -- the data matrix; each epoch is a block of its rows.
    * ``epochs``  -- an N x 4 matrix with one row per epoch.  Column 0 holds
      the 1-based (MATLAB) start sample and column 1 the inclusive end
      sample.  Column 2 holds the group code and column 3 the trigger code.
    * ``sr``      -- the sample rate, a single value.

    Epoch numbers taken and returned by the methods are 1-based, for
    consistency with the MegReader class and the epoch-rejection code.
    """

    def __init__(self, filename):
        """Load *filename* and validate its contents.

        Raises ValueError if a required variable is missing, if ``sr`` is not
        a single value, or if ``epochs`` is not an N x 4 matrix.  (On Python 2,
        ValueError is a StandardError subclass, so existing
        ``except StandardError`` handlers still catch it.)
        """
        self.dat = scipy.io.loadmat(filename)
        for name in ('megdata', 'epochs', 'sr'):
            if name not in self.dat:
                raise ValueError("ERROR: File %s does not have a '%s' matrix"
                                 % (filename, name))

        # loadmat returns scalars as 1x1 (2-D) arrays by default, so accept
        # any single-element value rather than only 0-d or (1,) shapes.
        sr = asarray(self.dat['sr'])
        if sr.size != 1:
            raise ValueError("ERROR: The sr matrix in %s must be 1x1" % filename)
        self.sample_rate = float(sr.ravel()[0])

        # Check that the epochs matrix is num_epochs x 4.  Test ndim first so
        # a 1-D array gives this error rather than an IndexError.
        epochs = self.dat['epochs']
        if epochs.ndim != 2 or epochs.shape[1] != 4:
            raise ValueError("ERROR: The epochs matrix in %s must be Nx4" % filename)

        self.numepochs = epochs.shape[0]

    def CheckEpochLengths(self):
        """Return the common epoch length, or None if the epochs differ.

        The length reported is ``end - start`` as stored in the epochs
        matrix.  Because the end sample is inclusive, GetEpoch_MEG returns
        one more row than this value.  Returns None when there are no epochs.
        """
        epochs = self.dat['epochs']
        if epochs.shape[0] == 0:
            return None
        lengths = epochs[:, 1] - epochs[:, 0]
        if lengths.min() == lengths.max():
            return lengths.min()
        return None

    def GetEpochList(self, trig_list, group_list, reject_list=None):
        """Return the sorted 1-based numbers of the matching epochs.

        An epoch is selected when it meets all three conditions:

        * its group code (column 2 of the epochs matrix) is in *group_list*;
        * its trigger code (column 3) is in *trig_list*;
        * its number is not in *reject_list*.

        An empty (or None) *group_list* or *trig_list* is a wildcard that
        matches every epoch.  *reject_list* holds 1-based epoch numbers to
        exclude, for example from artifact rejection.
        """
        epochs = self.dat['epochs']
        all_epochs = range(1, self.numepochs + 1)
        selected = set(all_epochs)

        # Intersect the group-code and trigger-code matches.
        for column, codes in ((2, group_list), (3, trig_list)):
            if codes is None or len(codes) == 0:
                continue  # wildcard: this column does not restrict anything
            selected &= {j for j in all_epochs if epochs[j - 1, column] in codes}

        # Remove those in the artifact rejection list.
        if reject_list is not None:
            selected.difference_update(reject_list)

        return sorted(selected)

    def GetEpoch_MEG(self, epochNumber):
        """Return the megdata rows of 1-based epoch *epochNumber*.

        The epochs matrix stores MATLAB sample indices, which are 1-based
        with an inclusive end.  So the start is shifted down by one and the
        end is used unchanged as Python's exclusive slice bound.

        Raises IndexError if epochNumber is outside 1..numepochs.  Without
        this check, 0 would silently select the last epoch through negative
        indexing.
        """
        if not 1 <= epochNumber <= self.numepochs:
            raise IndexError("Epoch number %s out of range 1..%d"
                             % (epochNumber, self.numepochs))
        row = self.dat['epochs'][epochNumber - 1]
        # Values loaded from MATLAB are typically doubles, and numpy rejects
        # float slice bounds, so convert to int explicitly.
        sp = int(row[0]) - 1
        ep = int(row[1])

        return self.dat['megdata'][sp:ep, :]

