import os
import os.path as op

import nibabel as nib
import numpy as np
import pylab as pl


def read_patch_and_curv(fname, curv_fname):
    """Read an ASCII flat patch together with a binarized curvature map.

    Parameters
    ----------
    fname : str
        Path to an ASCII patch file (e.g. ``lh.cortex.flat.patch.3d.asc``).
        The parser expects one header line, then a ``numvert numquad`` line,
        then two lines per vertex: an index line and a coordinate line. In the
        index line, the absolute value of the first token minus one is used as
        the vertex index. In the coordinate line, the first two tokens are used
        as x and y. After the vertices come two lines per face: a face index
        line, which must parse as an int, and a line of vertex indices.
    curv_fname : str
        Path to a curvature file, read with
        ``nibabel.freesurfer.read_morph_data``.

    Returns
    -------
    vertx_full_pos : ndarray, shape (len(curv), 2)
        Flat-map (x, y) coordinates for every vertex of the curvature map.
        Vertices that do not appear in the patch stay at (0, 0).
    quads : ndarray of int
        One row of vertex indices per face. Values are parsed as floats and
        then converted to int.
    curv : ndarray of float64
        0.65 where the raw curvature is negative and 0.35 elsewhere. These are
        grey levels for a background layer.
    """
    raw_curv = nib.freesurfer.read_morph_data(curv_fname)
    # Two grey levels: darker/lighter split on the sign of the curvature.
    curv = np.where(raw_curv < 0, 0.65, 0.35)

    with open(fname, 'r') as fid:
        fid.readline()  # header line, not used
        numvert, numquad = [int(tok) for tok in fid.readline().split()]
        print("numvert=%d, numquad=%d" % (numvert, numquad))

        vertx_pos = np.zeros((numvert, 2))
        vertx_id = np.zeros(numvert, dtype=int)
        for idx in range(numvert):
            # The first token may be negative in the file; only its
            # magnitude identifies the vertex.
            vertx_id[idx] = abs(int(fid.readline().split()[0]))
            vertx_pos[idx] = [float(tok) for tok in fid.readline().split()[:2]]

        face_rows = []
        for _ in range(numquad):
            # Face index line: value unused, but int() keeps the sanity check
            # that the reader is still in sync with the file layout.
            int(fid.readline())
            face_rows.append([float(tok) for tok in fid.readline().split()])

    vertx_full_pos = np.zeros((len(curv), 2))
    # Index values in the patch are treated as 1-based; shift to 0-based.
    vertx_full_pos[vertx_id - 1] = vertx_pos

    quads = np.array(face_rows, dtype=int)
    return vertx_full_pos, quads, curv


def plot_flat_map(data_fname, subject, hemi, vmin=1, vmax=50, sign_flip=False,
                  subjects_dir=None):
    """Plot thresholded per-vertex data on a flattened cortical patch.

    The function reads ``<subjects_dir>/<subject>/surf/<hemi>.cortex.flat.patch.3d.asc``
    and ``<subjects_dir>/<subject>/surf/<hemi>.curv``. It draws the binarized
    curvature in grey as a background. It then overlays the values from
    ``data_fname`` with the jet colormap, on faces whose vertices are all
    non-zero after thresholding. The function sets the global
    ``axes.edgecolor`` and ``axes.facecolor`` rcParams and finishes with
    ``pl.show()``.

    Parameters
    ----------
    data_fname : str
        Image file loadable by ``nibabel.load``. Its flattened values are used
        as one value per surface vertex. The number of values must equal the
        number of curvature values.
    subject : str
        Subject directory name under ``subjects_dir``.
    hemi : str
        Hemisphere prefix used in the surface file names (e.g. ``'lh'``).
    vmin : float
        Values with ``abs(value) < vmin`` are set to 0. Also used as the lower
        colour limit of the overlay.
    vmax : float
        Upper colour limit of the overlay.
    sign_flip : bool
        If True, negate the data before thresholding.
    subjects_dir : str | None
        Subjects directory. If None, ``os.environ['SUBJECTS_DIR']`` is used.

    Raises
    ------
    KeyError
        If ``subjects_dir`` is None and ``SUBJECTS_DIR`` is not set.
    ValueError
        If the data size differs from the number of surface vertices.
    """
    if subjects_dir is None:
        subjects_dir = os.environ['SUBJECTS_DIR']
    surf_dir = op.join(subjects_dir, subject, 'surf')
    curv_fname = op.join(surf_dir, '%s.curv' % hemi)
    patch_fname = op.join(surf_dir, '%s.cortex.flat.patch.3d.asc' % hemi)

    vertx_full_pos, quads, curv = read_patch_and_curv(patch_fname, curv_fname)

    # img.get_data() was removed in nibabel 5.0. np.array(img.dataobj) yields
    # the same values as a fresh, writable copy, so the in-place edits below
    # never touch cached or memory-mapped image data.
    data = np.array(nib.load(data_fname).dataobj).ravel()
    if data.size != curv.size:
        raise ValueError('%s has %d values but the %s surface has %d vertices'
                         % (data_fname, data.size, hemi, curv.size))
    if sign_flip:
        data = -data
    data[np.abs(data) < vmin] = 0.
    # A face is "active" when every one of its vertices carries a non-zero value.
    tri_mask = np.all(data[quads], axis=1)

    pl.rcParams['axes.edgecolor'] = 'k'
    pl.rcParams['axes.facecolor'] = 'k'
    fig = pl.figure()
    fig.patch.set_color('black')
    pl.gca().set_aspect('equal')
    pl.xticks(())
    pl.yticks(())
    x, y = vertx_full_pos[:, 0], vertx_full_pos[:, 1]
    # Background layer: grey curvature, hidden on active faces.
    pl.tripcolor(x, y, quads, curv, cmap=pl.cm.gray,
                 vmin=0, vmax=1, mask=tri_mask)
    # Foreground layer: thresholded data, drawn only on active faces.
    pl.tripcolor(x, y, quads, data, mask=~tri_mask,
                 cmap=pl.cm.jet, vmin=vmin, vmax=vmax)
    # The colorbar attaches to the most recent mappable, i.e. the data layer.
    color_bar = pl.colorbar()
    # White tick labels so they stay readable on the black figure background.
    pl.setp(color_bar.ax.get_yticklabels(), color='w')
    pl.show()
