#!/usr/bin/python
#-----------------------------------------------------------------------------
# Copyright 2006-2007, Andre D Gouws and York NeuroImaging Centre
#
# Name :          ynicDV3D - ynic Data Viewer 3D
# Description :   multi-modal imaging visualisation toolkit
# Author :        Andre D Gouws
# Created :       2006-October-06
# Last Update:    2007-March-05
# Notes :         requires the following additional files:
#                   
#                   /help/ynicDV3D_wx_interactor_help.py
#                   /help/ynicDV3D_wx_main_help.py
#                   processing.py
#                   wxProcessing.py
#                   tYNI_imginfo.py
#                   tYNI_imports.py
#                   nifti_reader.py
#                   runLogo.py
#                   ynicDV3Dlogo.jpg
#                   func.lut
#                   tYNI.lut
#                   ynicDV3D_wxdialogs.py
#
# History:        BETA 0.5 (07.Oct.2006) - code to be reviewed by M.H.
# Dependencies:   Python 2.4.1 or later, VTK 5.0 or later,
#                  FSL(FMRIB, Oxford) 3.2 or later                                     
#
# Usage :        TODO:  functional data:
#                           - allow 2d to load pos/neg only / both
#                           - ?f/F button functionality depending on 3d / merged func data
#                       Fix 'unload' / exiting of python threads
#                       Help file contents
#
#-----------------------------------------------------------------------------


import os
import os.path
import subprocess
import sys
import threading

def get_path(path):
    """Return *path* joined onto the directory that contains this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)

# Check arguments: a structural volume must be named on the command line.
# Exit with a non-zero status so shells/scripts can detect the failure.
if len(sys.argv) < 2:
    print('A structural nifti (.nii.gz) or analyze (.hdr) volume is required')
    sys.exit(1)


#############################################################
# --------- Startup window ------------

class StartupThread(threading.Thread):
    """Background thread that shows the start-up logo window (runLogo.py)."""

    def run(self):
        # Pass an argument list (no shell) so an installation path that
        # contains spaces is still handed to pythonw as a single argument.
        try:
            subprocess.call(['pythonw', get_path('runLogo.py')])
        except OSError:
            # e.g. no pythonw on the PATH; the logo is cosmetic, so report it
            # instead of letting the thread die with a traceback.
            print('Could not launch the start-up window (is pythonw on the PATH?)')


StartupThread().start()


#############################################################

#############################################################
## --------- Main Help window ------------
#
#
#class MainHelpThread( threading.Thread ):
#    def run(self):
#        os.system('pythonw ./ynicDV3D_wx_main_help.py')
#


#############################################################
# --------- Interactor Help window ------------

class InteractorHelpThread(threading.Thread):
    """Background thread that opens the interactor help window."""

    def run(self):
        # Argument list, no shell: install paths containing spaces still work.
        try:
            subprocess.call(['pythonw', get_path('help/ynicDV3D_wx_interactor_help.py')])
        except OSError:
            print('Could not launch the interactor help window (is pythonw on the PATH?)')

#############################################################
# --------- Progress tracking window ------------

class ProgressTrackingThread(threading.Thread):
    """Background thread that runs the progress-tracking window (wxProcessing.py)."""

    def run(self):
        # Argument list, no shell: install paths containing spaces still work.
        try:
            subprocess.call(['pythonw', get_path('wxProcessing.py')])
        except OSError:
            print('Could not launch the progress window (is pythonw on the PATH?)')

#############################################################
 
# Start-up banner, written as a single print of newline-joined lines.
print('\n'.join([
    '\nynicDV3D - York NeuroImaging Centre Data Viewer - 3D (c) 2006-2007.',
    '\n ... starting ynicDV3d ...',
    '\nBuilt on Python 2.4.1 and VTK 5.0',
    '\nThe main window will launch shortly ... Once it has started click on the main window to launch the menu.\n\n\n',
]))

import vtk
from vtk.util.colors import tomato, banana
from struct import unpack, calcsize #from surface
from numpy import *
import scipy as Sci
import scipy.linalg
import sys
import time
from tYNI_imginfo import *
from nifti_reader import *
import tYNI_imports as tYNI
from subprocess import Popen, PIPE
import processing

#globals
# Module-level state shared (via ``global``) by the loaders, key handlers and
# menu code of this script.

# 3-D functional iso-surfaces (built by create_3d_functional_data)
surf_num = 0            # current position in the surface stack (len(surf_max) after a load)
surf_min = []           # actors for the surfaces at fractions of the minimum value
surf_max = []           # actors for the surfaces at fractions of the maximum value

# Loaded / visible flags
fiber_data_loaded = 0
fibers_visible = 0
fibers_pickable = 0
func_data_loaded = 0
# Which surface sets the 'X' key reveals: 1 = both, 2 = surf_max only, 3 = surf_min only
func_data_visible = 0
cortex_data_loaded = 0
ortho_planes_on = 0     # set to 1 by ortho_updates()
profileList = []
plotData = []
lineData = []
instr_visible = 1

# Initial slice indices of the X / Y / Z plane widgets
x_slice_pos = 88
y_slice_pos = 128
z_slice_pos = 128
current_mpr_orient = 0

# Fibre data
fiber_group = 0
fiber_group_counter = []
counter = 0
a_fibers = []

# Lookup tables, colour bars and the merged 2-D functional volume.
# original_LU_table is replaced by the plane widgets' own LUT during setup.
original_LU_table = vtk.vtkLookupTable()
pos_colour_bar_generated = 0
neg_colour_bar_generated = 0
pos_scalarBar = []
neg_scalarBar = []
functional_2d_volume = []
functionalLut = []
mpr_showing_2d = 0      # set to 1 once a 2-D functional merge is displayed

#for dipole data
rows = []
dipole_file_loaded = 0
dipole_rows = []
dipole_transform = 0
dipole_min_time = 0
dipole_max_time = 0
dipole_current_time = 0
dipole_mat = []
cortex_data = 0
origActor = vtk.vtkLODActor()
selectActor = vtk.vtkLODActor()
planes = 0
boxWidget = 0           # replaced by a vtkBoxWidget during window setup
boxWidget2 = 0


######## -- SYSTEM ROUTINES -- #############################################################
# --------- File handling ------------

def my_file_dialog(processing_type, extra_var1, extra_var2):
    """Run the wx dialog helper (ynicDV3D_wxdialogs.py) and return its answer.

    The three arguments are passed to the helper as command-line arguments
    (converted with str(), as the old '%s' formatting did);
    *processing_type* selects which dialog the helper shows.

    Returns the helper's standard output with surrounding whitespace
    stripped; '' if it printed nothing or could not be started.
    """
    args = ['pythonw', get_path('ynicDV3D_wxdialogs.py'),
            str(processing_type), str(extra_var1), str(extra_var2)]
    try:
        # Argument list, no shell: paths containing spaces stay one argument.
        p = Popen(args, stdout=PIPE)
    except OSError:
        # Callers treat '' as "cancelled", which is the safest fallback.
        print('Could not launch the dialog helper (is pythonw on the PATH?)')
        return ''
    # communicate() reads stdout AND waits for exit; calling wait() first
    # can deadlock once the child fills the pipe buffer.
    so = p.communicate()[0]
    return so.strip()

######## -- READERS -- #############################################################
# --------- Import routines for Nifti and Analyze------------
# Maximum / minimum voxel value of the most recently read volume
# (overwritten by nifti2vtkImageData and process_functional_input_for_2d).
x = 0
y = 0


def nifti2vtkImageData(filename):
    """Read a NIfTI (.nii.gz) volume and return a vtkImageImport holding a copy.

    Supported voxel types: nifti_reader.DT_SIGNED_SHORT, DT_FLOAT and
    DT_SIGNED_INT.

    Side effects: sets the module globals ``x`` / ``y`` to the maximum /
    minimum voxel value and ``data`` to the array returned by the reader.

    Raises:
        ValueError: if the file's datatype is not one of the supported ones
            (this used to surface as an UnboundLocalError).
    """
    global x, y, data

    ar = nifti_reader(filename)
    data = ar.get_image_data()

    if ar.dtype == nifti_reader.DT_SIGNED_SHORT:
        bytes_per_sample = 2
        data_type = vtk.VTK_SHORT
    elif ar.dtype == nifti_reader.DT_FLOAT:
        bytes_per_sample = 4
        data_type = vtk.VTK_FLOAT
    elif ar.dtype == nifti_reader.DT_SIGNED_INT:
        bytes_per_sample = 4
        data_type = vtk.VTK_INT
    else:
        raise ValueError('Unsupported NIfTI datatype %s in %s' % (ar.dtype, filename))

    x = data.max()
    y = data.min()

    vol = vtk.vtkImageImport()
    # Total element count: len() only counts the first axis, which would
    # under-size the copy if the reader returned a multi-dimensional array.
    vol.CopyImportVoidPointer(data, data.size * bytes_per_sample)
    vol.SetDataScalarType(data_type)
    vol.SetNumberOfScalarComponents(1)

    # Spatial dimensions only (as analyze2vtkImageData does with dim[0:3]).
    nx, ny, nz = list(ar.dim)[:3]
    extent = vol.GetDataExtent()
    vol.SetDataExtent(extent[0], extent[0] + nx - 1,
                      extent[2], extent[2] + ny - 1,
                      extent[4], extent[4] + nz - 1)
    vol.SetWholeExtent(extent[0], extent[0] + nx - 1,
                       extent[2], extent[2] + ny - 1,
                       extent[4], extent[4] + nz - 1)

    # The x axis is mirrored: negative x spacing, with the origin moved to
    # fov[0] * pixdim[0] along x (presumably fov[0] is the x size in voxels,
    # so the volume keeps its x range -- TODO confirm against nifti_reader).
    vol.SetDataSpacing(-ar.pixdim[0], ar.pixdim[1], ar.pixdim[2])
    vol.SetDataOrigin(ar.fov[0] * ar.pixdim[0], 0, 0)
    return vol


def _prompt_overlay_threshold(dialog_id, extreme, positive):
    """Ask the user (via my_file_dialog) for one display threshold of the overlay.

    *extreme* is the most extreme overlay value on this side (the maximum for
    the positive side, the minimum for the negative side). It is passed to
    the dialog and returned unchanged when the user answers 'skip'.
    Otherwise the typed number is returned with its sign forced to match the
    side: positive when *positive* is true, negative otherwise.
    """
    answer = my_file_dialog(dialog_id, extreme, 0)
    # An empty reply (dialog closed) is treated like 'skip' instead of
    # crashing in float('').
    if answer in ('skip', ''):
        return extreme
    magnitude = abs(float(answer))
    if positive:
        return magnitude
    return -magnitude


def process_functional_input_for_2d(structural_vol, functional_vol):
    """Merge a functional overlay into a structural volume for 2-D display.

    Both files are read with nifti_reader; their voxel arrays are combined
    element-wise, so they are presumably on the same grid (TODO confirm).
    The result is a single float32 volume in which:

      * voxels without a displayed overlay hold the structural intensity,
        scaled so the structural maximum maps to 199;
      * voxels above the positive threshold hold values in 200..225;
      * voxels below the negative threshold hold values in 225..250.

    These ranges match the 250-entry func.lut table used by
    create_2d_functional_data. When a side of the overlay is non-empty,
    the user is asked for its threshold (dialog 12 for positive, 13 for
    negative).

    Side effects: sets the module globals ``x``/``overlay_max`` (overlay
    maximum), ``y``/``overlay_min`` (overlay minimum) and ``maxval`` /
    ``minval`` (the thresholds nearest zero).

    Returns:
        vtkImageImport holding the merged data as VTK_FLOAT, with extent,
        spacing and origin taken from the functional header. (The buffer
        used to be tagged with the functional file's own type, which garbled
        short/int inputs because the merged array is always float32.)
    """
    global x, y, overlay_max, overlay_min, maxval, minval

    grey_data = float32(nifti_reader(structural_vol).get_image_data())

    ar_func = nifti_reader(functional_vol)
    func_data = float32(ar_func.get_image_data())

    x = func_data.max()
    y = func_data.min()
    overlay_max = x
    overlay_min = y

    # maxval / minval are the thresholds closest to zero on each side, e.g.
    # minval = -3 while overlay_min = -7.
    maxval = overlay_max
    if overlay_max > 0:
        maxval = _prompt_overlay_threshold(12, overlay_max, True)
    minval = overlay_min
    if overlay_min < 0:
        minval = _prompt_overlay_threshold(13, overlay_min, False)

    # pos_ones is 1 where the overlay exceeds maxval; pos_colour maps those
    # voxels into 200..225. A side with no data contributes nothing. This
    # avoids 0/0 NaNs, and an all-positive overlay no longer produces a
    # negative mask of -1 everywhere (which blanked the structural image).
    if overlay_max > 0:
        pos_values = clip(func_data, 0, overlay_max)
        pos_ones = ceil(clip(pos_values - maxval, 0, 1))
        pos_colour = ((pos_values / overlay_max) * 25 + 200) * pos_ones
    else:
        pos_ones = zeros(func_data.shape, dtype=float32)
        pos_colour = zeros(func_data.shape, dtype=float32)

    # neg_ones is -1 where the overlay is below minval; after the
    # subtraction below those voxels land in 225..250.
    if overlay_min < 0:
        neg_values = clip(func_data, overlay_min, 0)
        neg_ones = floor(clip(neg_values - minval, -1, 0))
        neg_colour = ((neg_values / -overlay_min) * 25 + 250) * neg_ones
    else:
        neg_ones = zeros(func_data.shape, dtype=float32)
        neg_colour = zeros(func_data.shape, dtype=float32)

    # 1 where no overlay is drawn: only there does the structural image show.
    grey_mask = 1 - (pos_ones - neg_ones)
    grey_max = grey_data.max()
    if grey_max == 0:
        grey_max = 1  # all-zero structural volume: avoid 0/0 = NaN everywhere
    grey = (grey_data / grey_max) * 199 * grey_mask

    data = ascontiguousarray(grey + pos_colour - neg_colour, dtype=float32)

    vol = vtk.vtkImageImport()
    vol.CopyImportVoidPointer(data, data.nbytes)
    vol.SetDataScalarType(vtk.VTK_FLOAT)
    vol.SetNumberOfScalarComponents(1)

    # Spatial dimensions only (x, y, z).
    nx, ny, nz = list(ar_func.dim)[:3]
    extent = vol.GetDataExtent()
    vol.SetDataExtent(extent[0], extent[0] + nx - 1,
                      extent[2], extent[2] + ny - 1,
                      extent[4], extent[4] + nz - 1)
    vol.SetWholeExtent(extent[0], extent[0] + nx - 1,
                       extent[2], extent[2] + ny - 1,
                       extent[4], extent[4] + nz - 1)

    # Same mirrored-x geometry as nifti2vtkImageData.
    vol.SetDataSpacing(-ar_func.pixdim[0], ar_func.pixdim[1], ar_func.pixdim[2])
    vol.SetDataOrigin(ar_func.fov[0] * ar_func.pixdim[0], 0, 0)
    return vol


def analyze2vtkImageData(filename):
    """Read an Analyze (.hdr/.img) volume and return a vtkImageImport holding a copy.

    *filename* may be given with or without a .hdr / .img ending. The header
    is parsed with AnalyzeReader. The voxels are read from
    ``<filename>.img`` with numpy.fromfile in native byte order
    (NOTE(review): files with the other byte order are not swapped here --
    confirm whether that matters for your data).

    Supported Analyze datatype codes: 2 (unsigned char), 4 (signed short),
    8 (signed int), 16 (float).

    Raises:
        ValueError: for any other datatype code (this used to surface as an
            UnboundLocalError).
    """
    # remove any file ending
    if filename[-4:] in ('.hdr', '.img'):
        filename = filename[:-4]

    ar = AnalyzeReader(filename)
    spacing = ar.pixdim

    # Analyze datatype code -> (fromfile dtype, bytes per sample, VTK type).
    # Code 8 is a 4-byte int ('i'); it used to be sized as 8 bytes per sample,
    # which told VTK to copy twice the array's length.
    formats = {
        2: ('B', 1, vtk.VTK_UNSIGNED_CHAR),
        4: (short, 2, vtk.VTK_SHORT),
        8: ('i', 4, vtk.VTK_INT),
        16: ('f', 4, vtk.VTK_FLOAT),
    }
    code = ar.dtype[0]
    if code not in formats:
        raise ValueError('Unsupported Analyze datatype %s in %s.hdr' % (code, filename))
    sample_dtype, bytes_per_sample, data_type = formats[code]
    data = fromfile(filename + '.img', sample_dtype)

    vol = vtk.vtkImageImport()
    vol.CopyImportVoidPointer(data, data.size * bytes_per_sample)
    vol.SetDataScalarType(data_type)
    vol.SetNumberOfScalarComponents(1)

    nx, ny, nz = list(ar.dim[0:3])
    extent = vol.GetDataExtent()
    vol.SetDataExtent(extent[0], extent[0] + nx - 1,
                      extent[2], extent[2] + ny - 1,
                      extent[4], extent[4] + nz - 1)
    vol.SetWholeExtent(extent[0], extent[0] + nx - 1,
                       extent[2], extent[2] + ny - 1,
                       extent[4], extent[4] + nz - 1)

    vol.SetDataSpacing(spacing[0], spacing[1], spacing[2])

    # With a negative x voxel size, the origin is moved to -FOV[0] along x.
    if ar.pixdim[0] < 0:
        vol.SetDataOrigin(-ar.GetFOV()[0], 0, 0)

    return vol


########  END Readers #############################################################





######## -- ADD FUNCTIONAL DATA -- #############################################################
# --------- Routine ------------
    

def load_functional_data():
    """Ask for a functional .nii.gz file, then load it as a 2-D merge or 3-D surfaces.

    The chosen path is stored in the global ``data_to_load``. The user is then
    asked (dialog 19) for 'm' (merge into the structural planes) or '3' (3-D
    iso-surfaces); any other non-empty answer re-asks. An empty answer at
    either prompt cancels the load.
    """
    global data_to_load

    data_to_load = my_file_dialog(1, 0, 0)
    if data_to_load == '':
        print('Load cancelled')
        return
    if data_to_load[-7:] != '.nii.gz':
        print('\nSorry! All loaded functional data currently needs to be in NIFTI_GZ format. - analyze will be supported soon! You can use avwchfiletype to convert to nifti first ...\n')
        return

    while True:
        mode = my_file_dialog(19, 0, 0)
        if mode == '':
            print('Load cancelled')
            return
        if mode == '3':
            create_3d_functional_data()
            return
        if mode == 'm':
            create_2d_functional_data(data_to_load)
            return
        print('\nPlease type m if you want to merge the functional data with the current structural or 3 if you want to load the functional data as 3d blobs ...')


#create an intermediate fsl merged overlay file for stats and structure and load this in with a predefined LUT
def _make_overlay_scalar_bar(lut_rows, first_row, value_range, title, y_position):
    """Build a 25-colour vertical scalar bar for one side of a merged overlay.

    lut_rows    -- the (250, 4) RGBA rows read from func.lut
    first_row   -- index of the first of the 25 consecutive rows to show
    value_range -- (low, high) data values the bar is labelled with
    title       -- bar title
    y_position  -- normalized-viewport y of the bar's position coordinate
    """
    component_lut = vtk.vtkWindowLevelLookupTable()
    component_lut.SetNumberOfTableValues(25)
    component_lut.SetRange(value_range[0], value_range[1])
    for i in range(25):
        rgba = lut_rows[first_row + i]
        component_lut.SetTableValue(i, rgba[0], rgba[1], rgba[2], rgba[3])

    bar = vtk.vtkScalarBarActor()
    bar.SetLookupTable(component_lut)
    bar.SetTitle(title)
    bar.GetTitleTextProperty().SetFontSize(24)
    bar.GetPositionCoordinate().SetCoordinateSystemToNormalizedViewport()
    bar.GetPositionCoordinate().SetValue(0.05, y_position)
    bar.SetWidth(0.06)
    bar.SetHeight(0.4)
    bar.SetNumberOfLabels(5)
    bar.SetOrientationToVertical()
    return bar


#create an intermediate fsl merged overlay file for stats and structure and load this in with a predefined LUT
def create_2d_functional_data(overlay_filename):
    """Show *overlay_filename* merged into the structural volume on the MPR planes.

    The structural volume is sys.argv[1]. The merge is done by
    process_functional_input_for_2d; the result is shown on all three plane
    widgets through the 250-entry func.lut table (window 250, level 125).
    A 'Positive' and/or 'Negative' colour bar is added for each overlay side
    whose threshold differs from its extreme. Bars from a previous merge are
    removed first.

    Sets the globals volx, functional_2d_volume, functionalLut,
    pos/neg_scalarBar, pos/neg_colour_bar_generated and mpr_showing_2d.
    """
    global pos_colour_bar_generated, neg_colour_bar_generated, pos_scalarBar, neg_scalarBar, mpr_showing_2d, functional_2d_volume, functionalLut, volx

    volx = process_functional_input_for_2d(sys.argv[1], overlay_filename)

    # Drop colour bars left over from a previous merge, and clear their flags
    # so other code (e.g. ortho_updates) stops touching the removed actors.
    removed_bar = 0
    if pos_colour_bar_generated == 1:
        ren.RemoveActor2D(pos_scalarBar)
        pos_colour_bar_generated = 0
        removed_bar = 1
    if neg_colour_bar_generated == 1:
        ren.RemoveActor2D(neg_scalarBar)
        neg_colour_bar_generated = 0
        removed_bar = 1
    if removed_bar:
        renWin.Render()

    processing.set_start_processing()
    ProgressTrackingThread().start()
    try:
        functional_2d_volume = volx.GetOutput()

        # func.lut is read once and reused for the planes and both bars.
        lut_rows = reshape(fromfile(get_path('func.lut'), dtype=float, count=-1, sep=','), (250, 4))
        functionalLut = vtk.vtkWindowLevelLookupTable()
        functionalLut.SetNumberOfTableValues(250)
        for i, rgba in enumerate(lut_rows):
            functionalLut.SetTableValue(i, rgba[0], rgba[1], rgba[2], rgba[3])

        for widget in (planeWidgetX, planeWidgetY, planeWidgetZ):
            widget.SetUserControlledLookupTable(1)
            widget.SetLookupTable(functionalLut)
            widget.SetWindowLevel(250, 125)
            widget.SetInput(functional_2d_volume)

        # Now prepare the colour bars. Positive voxels were mapped to table
        # rows 200..224 and negative ones to 225..249. Row 225 is the most
        # negative value, which matches the bar's low end (overlay_min).
        if (overlay_max > 0) and (overlay_max != maxval):
            pos_scalarBar = _make_overlay_scalar_bar(lut_rows, 200, (maxval, overlay_max), 'Positive', 0.55)
            ren.AddActor2D(pos_scalarBar)
            pos_colour_bar_generated = 1

        if (overlay_min < 0) and (overlay_min != minval):
            neg_scalarBar = _make_overlay_scalar_bar(lut_rows, 225, (overlay_min, minval), 'Negative', 0.05)
            ren.AddActor2D(neg_scalarBar)
            neg_colour_bar_generated = 1

        renWin.Render()
        resetPlanes()
        mpr_showing_2d = 1
    finally:
        # Always close the progress window, even if something above failed.
        processing.set_fin_processing()

    print(' ... Completed 2d - merge ...\n')


#create the functional data as multiple isosurfaces
def _ask_bin_count(dialog_id):
    """Ask for a bin count with my_file_dialog(dialog_id, 0, 0).

    Returns the count as an int. Returns None, after printing why, when the
    user cancelled, asked for zero bins, or typed something that is not a
    positive whole number.
    """
    answer = my_file_dialog(dialog_id, 0, 0)
    if answer == '':
        print('Load cancelled')
        return None
    if answer == '0':
        print('you asked for zero bins - Load cancelled')
        return None
    try:
        count = int(answer)
    except ValueError:
        print('%r is not a whole number of bins - Load cancelled' % (answer,))
        return None
    if count < 1:
        print('the number of bins must be at least 1 - Load cancelled')
        return None
    return count


def _make_iso_surface_actor(image, value, colour, opacity):
    """Return a hidden vtkActor showing the iso-surface of *image* at *value*."""
    extractor = vtk.vtkContourFilter()
    extractor.SetInput(image)
    extractor.SetValue(0, value)
    extractor.ComputeGradientsOn()
    extractor.UseScalarTreeOn()

    normals = vtk.vtkPolyDataNormals()
    normals.SetInput(extractor.GetOutput())
    normals.SetFeatureAngle(60.0)

    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInput(normals.GetOutput())
    mapper.ScalarVisibilityOff()
    mapper.Update()

    actor = vtk.vtkActor()
    actor.GetProperty().SetLineWidth(0.005)
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(colour)
    actor.GetProperty().SetOpacity(opacity)
    actor.VisibilityOff()
    return actor


def _beamformer_key_callback(obj, event):
    """KeyPressEvent handler for the iso-surface stack.

    'X' decrements the global surf_num (not below 0). It then shows
    surface[surf_num] from the sets selected by func_data_visible
    (1 = both, 2 = surf_max, 3 = surf_min). 'Y' increments surf_num (not
    above len(surf_max)) and hides surface[surf_num - 1] of both sets. The
    surf_min set is only used when rMin != 0.
    """
    global surf_num
    key = iact.GetKeySym()

    if key == 'X':
        surf_num -= 1
        if surf_num < 0:
            surf_num = 0
        if func_data_visible == 1:
            surf_max[surf_num].VisibilityOn()
            if rMin != 0:
                surf_min[surf_num].VisibilityOn()
        elif func_data_visible == 2:
            surf_max[surf_num].VisibilityOn()
        elif func_data_visible == 3:
            if rMin != 0:
                surf_min[surf_num].VisibilityOn()

    if key == 'Y':
        surf_num += 1
        if surf_num == len(surf_max) + 1:
            surf_num = len(surf_max)
        surf_max[surf_num - 1].VisibilityOff()
        if rMin != 0:
            surf_min[surf_num - 1].VisibilityOff()

    renWin.Render()


#create the functional data as multiple isosurfaces
def create_3d_functional_data():
    """Load the functional volume named by ``data_to_load`` as stacked iso-surfaces.

    Asks for the number of bins (dialog 18, capped at 10) and how many of
    them to display (dialog 17, at most bins - 1). For each displayed bin i
    it builds a surface at rMax * i * 0.1 (and at rMin * i * 0.1 unless
    rMin == 0), with opacity 0.1 * i. Only the highest level is made visible.
    Surfaces and the key observer from a previous load are replaced, not
    stacked.

    Sets the globals surf_min, surf_max, surf_num, func_data_loaded, rMin, rMax.
    """
    global surf_min, surf_max, surf_num, func_data_loaded, rMin, rMax

    num_bins = _ask_bin_count(18)
    if num_bins is None:
        return
    bins_to_display = _ask_bin_count(17)
    if bins_to_display is None:
        return
    if num_bins > 10:
        num_bins = 10
        print('Number of bins currently restricted to maximum 10. Proceeding with 10 ...')
    # Bin 0 (threshold 0) is never drawn, so at most num_bins - 1 can be shown.
    if bins_to_display >= num_bins:
        bins_to_display = num_bins - 1
    if bins_to_display < 1:
        print('At least 2 bins are needed to show a contour - Load cancelled')
        return

    if not data_to_load.endswith('.nii.gz'):
        print('\nA nifti (.nii.gz) volume is required')
        return
    volr = nifti2vtkImageData(data_to_load)

    print('\n ... Loading functional data as multiple 3d isosurfaces ...\n')

    # Colour ramps indexed by bin: red -> yellow (max side), blue -> cyan (min side).
    pos_max_colors = [(1, 0, 0), (1, 0.1, 0), (1, 0.2, 0), (1, 0.3, 0), (1, 0.4, 0),
                      (1, 0.5, 0), (1, 0.6, 0), (1, 0.8, 0), (1, 0.9, 0), (1, 1, 0)]
    pos_min_colors = [(0, 0, 1), (0, 0.1, 1), (0, 0.2, 1), (0, 0.3, 1), (0, 0.4, 1),
                      (0, 0.5, 1), (0, 0.6, 1), (0, 0.8, 1), (0, 0.9, 1), (0, 1, 1)]

    processing.set_start_processing()
    ProgressTrackingThread().start()
    try:
        vr = volr.GetOutput()
        vr.Update()
        rMin, rMax = vr.GetScalarRange()
        func_data_loaded = 1

        # Take surfaces from a previous load out of the scene before replacing them.
        for actor in surf_min + surf_max:
            ren.RemoveActor(actor)

        surf_min = []
        surf_max = []
        for i in range(num_bins - bins_to_display, num_bins):
            print('Extracting contours %d/%d' % (i + 1, num_bins))
            surf_max.append(_make_iso_surface_actor(vr, rMax * (i * 0.1), pos_max_colors[i], 0.1 * i))
            if rMin != 0:
                surf_min.append(_make_iso_surface_actor(vr, rMin * (i * 0.1), pos_min_colors[i], 0.1 * i))

        # --- Add the Beamformer Actors
        for actor in surf_min:
            ren.AddActor(actor)
        for actor in surf_max:
            ren.AddActor(actor)
        surf_num = len(surf_max)
        surf_max[surf_num - 1].VisibilityOn()
        if rMin != 0:
            surf_min[surf_num - 1].VisibilityOn()
    finally:
        processing.set_fin_processing()

    # Each reload used to add another observer, so one key press moved
    # several levels; replace the previous observer instead.
    old_tag = getattr(create_3d_functional_data, '_key_observer_tag', None)
    if old_tag is not None:
        iact.RemoveObserver(old_tag)
    create_3d_functional_data._key_observer_tag = iact.AddObserver("KeyPressEvent", _beamformer_key_callback)

    # Previously unreachable (it followed the return statement).
    print(' ... Completed ...')
    

########  END Add functional data #############################################################




######## -- ORTHO VIEWs -- #############################################################
# --------- Routine for outputting orthogonal views to a third window ------------
 
def _grab_render_window(render_window):
    """Render *render_window* and return a snapshot of it as vtkImageData."""
    render_window.Render()
    grabber = vtk.vtkWindowToImageFilter()
    grabber.SetInput(render_window)
    grabber.Update()
    return grabber.GetOutput()


def ortho_updates():
    """Render three orthogonal snapshots of the main scene and show them in ren3.

    The main window is temporarily shrunk to 300x300, switched to parallel
    projection and rendered from three fixed camera positions. Each render
    is captured with vtkWindowToImageFilter. In each view only the texture
    of the plane facing the camera is shown. When plane outlines are on,
    that plane's outline is hidden and the other two are kept. Four
    orientation letters are overlaid, and the colour bars are hidden in the
    first two views and widened in the third.

    Afterwards the window size and the camera's position, view-up, view
    angle and projection mode are put back to their previous values (the
    old code always forced a 30-degree perspective view). The snapshots
    are attached to act_x / act_y / act_z in ren3.

    Sets the global ``ortho_planes_on`` to 1.
    """
    global ortho_planes_on

    # Orientation letters; the index tuples passed to set_markers refer to these.
    marker_txt = ["L", "R", "I", "S", "A", "P"]
    # Display positions in the temporary 300x300 window: left, right, top, bottom edge.
    marker_pos = [(15, 145), (285, 145), (148, 285), (148, 15)]

    marker_list = []
    for px, py in marker_pos:
        atext = vtk.vtkTextActor()
        atext.SetDisplayPosition(px, py)
        tprop = atext.GetTextProperty()
        tprop.SetFontSize(14)
        tprop.SetFontFamilyToArial()
        tprop.SetColor(1, 1, 0)
        tprop.BoldOn()
        marker_list.append(atext)
    for atext in marker_list:
        ren.AddActor(atext)

    def set_markers(text_indices, colour):
        # Give the four markers new letters (indices into marker_txt) and a colour.
        for atext, index in zip(marker_list, text_indices):
            atext.SetInput(marker_txt[index])
            atext.GetTextProperty().SetColor(colour[0], colour[1], colour[2])

    def set_plane_outlines(x_opacity, y_opacity, z_opacity):
        planeWidgetX.GetPlaneProperty().SetOpacity(x_opacity)
        planeWidgetY.GetPlaneProperty().SetOpacity(y_opacity)
        planeWidgetZ.GetPlaneProperty().SetOpacity(z_opacity)

    # Remember the main window state so it can be restored afterwards.
    camera = ren.GetActiveCamera()
    saved_parallel = camera.GetParallelProjection()
    saved_position = camera.GetPosition()
    saved_view_up = camera.GetViewUp()
    saved_view_angle = camera.GetViewAngle()
    saved_size = renWin.GetSize()

    # Parallel rather than perspective projection for the snapshots.
    camera.SetParallelProjection(1)
    planeWidgetX.GetTexturePlaneProperty().SetOpacity(0)
    planeWidgetY.GetTexturePlaneProperty().SetOpacity(0)
    renWin.SetSize(300, 300)

    # Outlines count as "on" when the X plane's outline is fully opaque.
    edges_on = planeWidgetX.GetPlaneProperty().GetOpacity() == 1

    ## FIRST VIEW: camera on the -z side, y up (Z plane textured)
    camera.SetPosition(90, 145, -717)
    camera.SetViewUp(0, 1, 0)
    camera.SetViewAngle(20)
    if pos_colour_bar_generated == 1:
        pos_scalarBar.VisibilityOff()
    if neg_colour_bar_generated == 1:
        neg_scalarBar.VisibilityOff()
    if edges_on:
        set_plane_outlines(1.0, 1.0, 0.0)
    set_markers((1, 0, 4, 5), (0, 0, 1))
    x1 = _grab_render_window(renWin)

    ## SECOND VIEW: camera on the +x side, z up (X plane textured)
    planeWidgetZ.GetTexturePlaneProperty().SetOpacity(0)
    planeWidgetX.GetTexturePlaneProperty().SetOpacity(1)
    camera.SetPosition(1000, 145, 149)
    camera.SetViewUp(0, 0, 1)
    if edges_on:
        set_plane_outlines(0.0, 1.0, 1.0)
    set_markers((5, 4, 3, 2), (1, 0, 0))
    y1 = _grab_render_window(renWin)

    ## THIRD VIEW: camera on the +y side, z up (Y plane textured), with wider colour bars
    if pos_colour_bar_generated == 1:
        pos_scalarBar.VisibilityOn()
        pos_scalarBar.SetWidth(0.14)
        pos_scalarBar.GetPositionCoordinate().SetValue(0.01, 0.55)
    if neg_colour_bar_generated == 1:
        neg_scalarBar.VisibilityOn()
        neg_scalarBar.SetWidth(0.14)
        neg_scalarBar.GetPositionCoordinate().SetValue(0.01, 0.05)
    planeWidgetX.GetTexturePlaneProperty().SetOpacity(0)
    planeWidgetY.GetTexturePlaneProperty().SetOpacity(1)
    camera.SetPosition(90, 1000, 154)
    camera.SetViewUp(0, 0, 1)
    if edges_on:
        set_plane_outlines(1.0, 0.0, 1.0)
    set_markers((1, 0, 3, 2), (1, 1, 0))
    z1 = _grab_render_window(renWin)

    # RESET - restore the original window and re-render it
    for atext in marker_list:
        ren.RemoveActor(atext)
    if pos_colour_bar_generated == 1:
        pos_scalarBar.VisibilityOn()
        pos_scalarBar.SetWidth(0.06)
        pos_scalarBar.GetPositionCoordinate().SetValue(0.05, 0.55)
    if neg_colour_bar_generated == 1:
        neg_scalarBar.VisibilityOn()
        neg_scalarBar.SetWidth(0.06)
        neg_scalarBar.GetPositionCoordinate().SetValue(0.05, 0.05)

    camera.SetParallelProjection(saved_parallel)
    renWin.SetSize(saved_size[0], saved_size[1])
    for widget in (planeWidgetX, planeWidgetY, planeWidgetZ):
        widget.GetTexturePlaneProperty().SetOpacity(1)
    camera.SetPosition(saved_position)
    camera.SetViewUp(saved_view_up)
    camera.SetViewAngle(saved_view_angle)
    renWin.Render()

    def make_image_mapper(image):
        mapper = vtk.vtkImageMapper()
        mapper.SetInput(image)
        mapper.SetColorLevel(128)
        mapper.SetColorWindow(255)
        return mapper

    # Stack the snapshots in the third window: x view on top, z view at the bottom.
    act_x.SetMapper(make_image_mapper(x1))
    act_x.SetPosition(0, 600)
    act_y.SetMapper(make_image_mapper(y1))
    act_y.SetPosition(0, 300)
    act_z.SetMapper(make_image_mapper(z1))

    ren3.AddActor(act_x)
    ren3.AddActor(act_y)
    ren3.AddActor(act_z)

    ortho_planes_on = 1
    if edges_on:
        set_plane_outlines(1.0, 1.0, 1.0)

    renWin.Render()
    return

########  END Ortho views #############################################################





######## -- MPR Volume -- #############################################################
# --------- Routine for displaying the 3d-MPR MRI data volume ------------


# The structural volume named on the command line; its reader is chosen by
# the file ending.
file_to_load = sys.argv[1]

if file_to_load.endswith('.hdr'):
    vol = analyze2vtkImageData(file_to_load)

elif file_to_load.endswith('.nii.gz'):
    vol = nifti2vtkImageData(file_to_load)

else:
    print('A structural nifti (.nii.gz) or analyze (.hdr) volume is required')
    # Non-zero status so shells/scripts can detect the failure.
    sys.exit(1)

vol.Update()


# Geometry for accurate placement of the ROI in slice coordinates: world
# bounds, voxel dimensions, and evenly spaced coordinates between the bounds
# along each axis (one entry per voxel).
d = vol.GetOutput()

bounds = d.GetBounds()
dims = d.GetDimensions()
# numpy's linspace: the old scipy.linspace alias is gone from current SciPy.
xspace = linspace(bounds[0], bounds[1], dims[0])
yspace = linspace(bounds[2], bounds[3], dims[1])
zspace = linspace(bounds[4], bounds[5], dims[2])

xMin, xMax, yMin, yMax, zMin, zMax = vol.GetOutput().GetWholeExtent()

# An outline is shown for context.
# NOTE(review): outlineActor is built here, but the line that adds it to the
# renderer (``#ren.AddActor(outlineActor)``) is commented out.
outline = vtk.vtkOutlineFilter()
outline.SetInput(vol.GetOutput())

outlineMapper = vtk.vtkPolyDataMapper()
outlineMapper.SetInput(outline.GetOutput())

outlineActor = vtk.vtkActor()
outlineActor.SetMapper(outlineMapper)

# The shared picker enables us to use 3 planes at one time
# and gets the picking order right
picker = vtk.vtkCellPicker()
picker.SetTolerance(0.005)


# The 3 image plane widgets are used to probe the dataset.
# Each one slices the structural volume along one axis and starts at the
# slice index in x_slice_pos / y_slice_pos / z_slice_pos. Key-press
# activation is disabled so the widgets are not toggled by the keyboard.
# X plane: red outline. NOTE(review): only this widget sets a handle size.
planeWidgetX = vtk.vtkImagePlaneWidget()
planeWidgetX.DisplayTextOn()
planeWidgetX.SetInput(vol.GetOutput())
planeWidgetX.SetPlaneOrientationToXAxes()
planeWidgetX.SetSliceIndex(x_slice_pos)
planeWidgetX.SetPicker(picker)
planeWidgetX.SetHandleSize(500)
prop1 = planeWidgetX.GetPlaneProperty()
prop1.SetColor(1, 0, 0)
planeWidgetX.SetKeyPressActivation(0)

# Y plane: yellow outline; shares the X widget's lookup table.
planeWidgetY = vtk.vtkImagePlaneWidget()
planeWidgetY.DisplayTextOn()
planeWidgetY.SetInput(vol.GetOutput())
planeWidgetY.SetPlaneOrientationToYAxes()
planeWidgetY.SetSliceIndex(y_slice_pos)
planeWidgetY.SetPicker(picker)
prop2 = planeWidgetY.GetPlaneProperty()
prop2.SetColor(1, 1, 0)
planeWidgetY.SetLookupTable(planeWidgetX.GetLookupTable())
planeWidgetY.SetKeyPressActivation(0)

# Keep a handle on the X widget's own lookup table (shared by all three
# planes), presumably so it can be restored after a functional overlay
# replaces it -- the restoring code is not visible here.
original_LU_table = planeWidgetX.GetLookupTable()


# Z plane: blue outline; shares the X widget's lookup table.
planeWidgetZ = vtk.vtkImagePlaneWidget()
planeWidgetZ.DisplayTextOn()
planeWidgetZ.SetInput(vol.GetOutput())
planeWidgetZ.SetPlaneOrientationToZAxes()
planeWidgetZ.SetSliceIndex(z_slice_pos)
planeWidgetZ.SetPicker(picker)
prop3 = planeWidgetZ.GetPlaneProperty()
prop3.SetColor(0, 0, 1)
planeWidgetZ.SetLookupTable(planeWidgetX.GetLookupTable())
planeWidgetZ.SetKeyPressActivation(0)

######## END MPR Volume -- #############################################################





######## -- SETUP MAIN WINDOW, ACTORS AND INTERACTORS -- ###############################

# Create the RenderWindow and Renderer
ren = vtk.vtkRenderer()
renWin = vtk.vtkRenderWindow()
renWin.AddRenderer(ren)

# A second renderer/window pair.
# NOTE(review): its use is not visible in this part of the file.
rent = vtk.vtkRenderer()
renWint = vtk.vtkRenderWindow()
renWint.AddRenderer(rent)


# Set the background color and size (adding the outline actor is disabled).
#ren.AddActor(outlineActor)
renWin.SetSize(900, 900)
ren.SetBackground(0.0, 0.0, 0.0)

# Picker for the interactor itself (the plane widgets share ``picker``).
picker1 = vtk.vtkCellPicker()    

# Set the interactor for the widgets: trackball-camera style navigation.
# Each plane widget is attached, given a fixed window/level (2536, 2687) and
# switched on.
# NOTE(review): the window/level values look tuned to a particular intensity
# range -- confirm.
iact = vtk.vtkRenderWindowInteractor()
iact.SetRenderWindow(renWin)
style = vtk.vtkInteractorStyleTrackballCamera()
iact.SetInteractorStyle(style) 
planeWidgetX.SetInteractor(iact)
iact.SetPicker(picker1)
planeWidgetX.SetWindowLevel(2536, 2687)


planeWidgetX.On()
planeWidgetY.SetInteractor(iact)
planeWidgetY.SetWindowLevel(2536, 2687)

planeWidgetY.On()
planeWidgetZ.SetInteractor(iact)
planeWidgetZ.SetWindowLevel(2536, 2687)

planeWidgetZ.On()

# Placeholder actors for the dipole display.
dipoleSphereActor = vtk.vtkActor()
dipoleLineActor = vtk.vtkActor()

# stuff for the boxwidget interactor: a yellow sphere of radius 15 acts as
# the ROI that the box widget moves.
roi = vtk.vtkSphereSource()
roi.SetRadius(15.0)
roi.SetCenter(88,250,250) 
roiMapper = vtk.vtkPolyDataMapper()
roiMapper.SetInput(roi.GetOutput())
roiActor = vtk.vtkActor()
roiActor.SetMapper(roiMapper)
roiActor.GetProperty().SetColor(1,1,0)
ren.AddActor(roiActor)

# Box widget placed around the ROI actor (1.25 x its bounds).
boxWidget = vtk.vtkBoxWidget()
boxWidget.SetInteractor(iact)
boxWidget.SetPlaceFactor(1.25)
boxWidget.SetProp3D(roiActor)
boxWidget.PlaceWidget()
boxWidget.SetHandleSize(0.003)

def myCallback2(widget, event_string):
    """Make the ROI actor follow the box widget that fired the event.

    Used as the "InteractionEvent" observer of the ROI box widget. On each
    drag/scale/rotate of the box, the widget's current transform is copied
    onto its prop as a user transform, so the ROI sphere moves with the box.

    Parameters
    ----------
    widget : vtkBoxWidget
        The observed widget. VTK passes the object the observer is attached
        to, so this is the box widget itself.
    event_string : str
        Name of the VTK event; not used.
    """
    t = vtk.vtkTransform()
    # Fill `t` with the box's transform relative to its PlaceWidget()
    # placement.
    widget.GetTransform(t)
    widget.GetProp3D().SetUserTransform(t)
    
# Keep the ROI sphere in sync with the box widget while it is manipulated.
boxWidget.AddObserver("InteractionEvent", myCallback2)
boxWidget.On()

# The ROI and its box start hidden; the menu's ROI toggle turns both back on.
# NOTE(review): On() immediately followed by Off() looks like a no-op.
# Confirm whether it is needed (e.g. to initialise widget state) before
# removing it.
roiActor.VisibilityOff()
boxWidget.Off()
renWin.Render()
# Place the main 3D window to the right of the 220-px-wide menu window, which
# myMenu positions at (0,0).
renWin.SetPosition(220,0)


########## END INITIAL SETUP of MAIN INTERACTORS, WINDOW ETC #####################################


######## -- SETUP MENU WINDOWS -- ###############################################################
def myMenu():
    global boxes, box_coords, labels, label_pos, button_coordlist, buttonlist, ren2, renWin2, iren2
    ######## -- Opens a second window with menu functionality -- ###############################
    
    #  objects ############
    #boxes
    boxes = []
    box_coords = []
    box_coords.append((9,51,115,130))
    box_coords.append((9,51,94.5,113.5))
    box_coords.append((9,51,62.5,92.5))
    box_coords.append((9,51,35.5,61.5))
    box_coords.append((9,51,14.5,34.5))
    
    #labels
    labels = []
    label_pos = []
    label_list = []
    
    label_pos.append((25,560))
    label_pos.append((65,535))
    label_pos.append((65,515))
    label_pos.append((25,485))
    label_pos.append((40,469))
    label_pos.append((40,455))
    label_pos.append((40,441))
    label_pos.append((40,427))
    label_pos.append((40,413))
    label_pos.append((25,385))
    label_pos.append((65,365))
    label_pos.append((65,345))
    label_pos.append((65,328))
    label_pos.append((65,310))
    label_pos.append((25,290))
    label_pos.append((120,469))
    
        #top,bot,fro,bac,L,R
    label_pos.append((135,455))
    label_pos.append((165,455))
    label_pos.append((135,441))
    label_pos.append((165,441))
    label_pos.append((135,427))
    label_pos.append((165,427))
    
    label_pos.append((25,175))
    label_pos.append((30,155))
    label_pos.append((65,135))
    label_pos.append((40,105))
    label_pos.append((40,85))
    label_pos.append((40,65))
    label_pos.append((40,45))
    label_pos.append((145,32))
    
    label_pos.append((160,560))
    label_pos.append((160,385))
    label_pos.append((160,175))
    label_pos.append((160,235))
    
    label_pos.append((25,235))
    label_pos.append((65,215))
    label_pos.append((65,195))
    label_pos.append((125,155))
    
    labels.append("CORTICAL SURFACE ( B )")
    labels.append("Visible On/Off ( v-v )")
    labels.append("Opacity Up/Down ( b-n )")
    labels.append("MRI PLANES")
    labels.append("Sagittal ( S )")
    labels.append("Axial ( A )")
    labels.append("Coronal ( C )")
    labels.append("Edges ( E )")
    labels.append("Reset ( L )")
    labels.append("FIBER DATA ( T )")
    labels.append("FIBERS: Visible On/Off ( g-G )")
    labels.append("PICKER: Visible On/Off ( h-h )")
    labels.append("Size Up/Down")
    labels.append("Individual fibers pickable")
    labels.append("Colours")
    labels.append("View from: ( H )")
    labels.append("sup")
    labels.append("inf")
    labels.append("ant")
    labels.append("pos")
    labels.append("lft")
    labels.append("rgt")
    labels.append("FUNCTIONAL DATA ( F )")
    labels.append("Cycle 3d ( f )")
    labels.append("Threshold Up/Down ( X/Y )")
    labels.append("ORTHO-VIEWS ON / UPDATE ( x )")
    labels.append("Ortho Views: save image ( j )")
    labels.append("Main Window: save image ( J )")
    labels.append("Load help ( z )")
    labels.append("QUIT (q)")
    
    labels.append("LOAD")
    labels.append("LOAD")
    labels.append("LOAD")
    labels.append("LOAD")
    
    labels.append("DIPOLE DATA ( D )")
    labels.append("Visible On/Off ( d-d )")
    labels.append("Timepoint ( ctrl <- / -> )")
    labels.append("Switch 2d ( M )")
    

    
    #draw the labels
    for j in range(len(label_pos)):
        atext = vtk.vtkTextActor()
        atext.SetDisplayPosition(label_pos[j][0], label_pos[j][1])
        atext.SetInput(labels[j])
        tprop = atext.GetTextProperty()
        tprop.SetFontSize(10)
        tprop.SetFontFamilyToArial()
        tprop.SetColor(0, 0, 0.0)
    
        label_list.append(atext)
    
    for j in range(len(label_pos)-5-3,len(label_pos)-3):
        label_list[j].GetTextProperty().SetColor(1, 0, 0)
    
    count = 0
    
    for u in range(len(box_coords)):
        count += 1
        xmin = box_coords[u][0]
        xmax = box_coords[u][1]
        ymin = box_coords[u][2]
        ymax = box_coords[u][3]
        
        polygonPoints = vtk.vtkPoints()
        polygonPoints.SetNumberOfPoints(4)
        polygonPoints.InsertPoint(0, xmin, ymin, 0)
        polygonPoints.InsertPoint(1, xmin, ymax, 0)
        polygonPoints.InsertPoint(2, xmax, ymax, 0)
        polygonPoints.InsertPoint(3, xmax, ymin, 0)
        aPolygon = vtk.vtkPolygon()
        aPolygon.GetPointIds().SetNumberOfIds(4)
        aPolygon.GetPointIds().SetId(0, 0)
        aPolygon.GetPointIds().SetId(1, 1)
        aPolygon.GetPointIds().SetId(2, 2)
        aPolygon.GetPointIds().SetId(3, 3)
        aPolygonGrid = vtk.vtkUnstructuredGrid()
        aPolygonGrid.Allocate(1, 1)
        aPolygonGrid.InsertNextCell(aPolygon.GetCellType(), aPolygon.GetPointIds())
        aPolygonGrid.SetPoints(polygonPoints)
        aPolygonMapper = vtk.vtkDataSetMapper()
        aPolygonMapper.SetInput(aPolygonGrid)
        aPolygonActor = vtk.vtkActor()
        aPolygonActor.SetMapper(aPolygonMapper)
        #aPolygonActor.AddPosition(6, 0, 2)
        aPolygonActor.GetProperty().SetDiffuseColor(1, 1, 1)
        aPolygonActor.SetPickable(0)
        boxes.append(aPolygonActor)
    
    #buttons
    
    button_coordlist=[]
        
        #cortex
    button_coordlist.append((15,123,0)) #toggle 1a
    button_coordlist.append((15,123,0)) #toggle 1b
    button_coordlist.append((13,119,0))
    button_coordlist.append((17,119,0))
    
        #3d mpr
    button_coordlist.append((13,109,0)) #toggle 3a
    button_coordlist.append((13,106,0)) #toggle 4a
    button_coordlist.append((13,103,0)) #toggle 5a
    button_coordlist.append((13,100,0)) #borders
    
        #fiber picker
    button_coordlist.append((15,87,0))
    button_coordlist.append((15,83,0))
    button_coordlist.append((13,79,0))
    button_coordlist.append((17,79,0))
    button_coordlist.append((15,75,0))
    
        #colours
    button_coordlist.append((13,68,0))
    button_coordlist.append((16,68,0))
    button_coordlist.append((19,68,0))
    button_coordlist.append((22,68,0))
    button_coordlist.append((25,68,0))
    button_coordlist.append((28,68,0))
    button_coordlist.append((31,68,0))
    button_coordlist.append((34,68,0))
    button_coordlist.append((37,68,0))
    button_coordlist.append((40,68,0))
    button_coordlist.append((13,65,0))
    button_coordlist.append((16,65,0))
    button_coordlist.append((19,65,0))
    button_coordlist.append((22,65,0))
    button_coordlist.append((25,65,0))
    button_coordlist.append((28,65,0))
    button_coordlist.append((31,65,0))
    button_coordlist.append((34,65,0))
    button_coordlist.append((37,65,0))
    button_coordlist.append((40,65,0))
        
        #beamformer/fmri data
    button_coordlist.append((11,43,0))
    button_coordlist.append((13,38,0))
    button_coordlist.append((17,38,0))
    
        #last panel
    button_coordlist.append((13,32,0))
    button_coordlist.append((13,28,0))
    button_coordlist.append((13,24,0))
    button_coordlist.append((13,20,0))
    
        #load buttons
    button_coordlist.append((48,128,0))
    button_coordlist.append((48,91,0))
    button_coordlist.append((48,46.5,0))
    
        #QUIT
    button_coordlist.append((48,17,0))
    
        #RESET MPR SLICES
    button_coordlist.append((13,97,0)) #reset
    
        #Viewing Angle
    button_coordlist.append((33,106,0))
    button_coordlist.append((47,106,0))
    button_coordlist.append((33,103,0))
    button_coordlist.append((47,103,0))
    button_coordlist.append((33,100,0))
    button_coordlist.append((47,100,0))
    
        #dipole load button
    button_coordlist.append((48,59,0))
    
        #dipole buttons
    button_coordlist.append((15,55,0))
    button_coordlist.append((13,51,0))
    button_coordlist.append((17,51,0))
    
    #functional switch between 2d and structural
    button_coordlist.append((31,43,0))

    
    
    button_colours = []
        #cortex
    button_colours.append((1,0,0))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((1,0,0))
    button_colours.append((0,0,1))
    
        #3d mpr
    button_colours.append((1,0,0))
    button_colours.append((1,0,0))
    button_colours.append((1,0,0))
    button_colours.append((1,0,0))
    
        #fiber picker
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((1,0,0))
    button_colours.append((0,0,1))
    button_colours.append((0.3,0.3,0.3))
    
        #colour buttons
    button_colours.append((0,0,0)) #black
    button_colours.append((0.2,0.2,0.2)) #dark gray
    button_colours.append((0.5,0.5,0.5)) #mid gray 
    button_colours.append((0.7,0.7,0.7)) #light gray
    button_colours.append((1,1,1)) #white
    button_colours.append((0,0,0.5)) #blu1
    button_colours.append((0,0,1)) #b
    button_colours.append((0.6,0.8,1)) #blu2
    button_colours.append((0,1,1)) #cyan
    button_colours.append((0.6,0,1)) #purp
    button_colours.append((0,1,0)) #g
    button_colours.append((0.5,1,0.5)) #light green
    button_colours.append((0.6,0.2,0)) #brown
    button_colours.append((1,0.5,0.9)) #pink
    button_colours.append((1,0,1)) #magenta
    button_colours.append((1,0,0)) #r
    button_colours.append((1,0.2,0)) #dark orange
    button_colours.append((1,0.5,0)) #orange
    button_colours.append((1,1,0)) #yellow
    button_colours.append((1,1,0.5)) #light yellow
    
        #functional data
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((1,0,0))
    button_colours.append((0,0,1))
    
        #last
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    
        #load
    button_colours.append((1,1,0))
    button_colours.append((1,1,0))
    button_colours.append((1,1,0))
    
        #quit
    button_colours.append((1,1,0))
    
        #reset MPR
    button_colours.append((1,1,0))
    
        #Viewing angle
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((0.3,0.3,0.3))
    
        #dipole load
    button_colours.append((1,1,0))
    
        #other dipole buttons
    button_colours.append((0.3,0.3,0.3))
    button_colours.append((1,0,0))
    button_colours.append((0,0,1))

    #functional switch between 2d and structural
    button_colours.append((0.3,0.3,0.3))

    buttonlist=[]
    for i in range(len(button_coordlist)):
        if i in (2,3,10,11,34,35,53,54): #list of arrow / cone actors
            button = vtk.vtkConeSource()
            button.SetCenter(button_coordlist[i])
            button.SetRadius(1.5)
            button.SetHeight(2.5)
            bMapper = vtk.vtkPolyDataMapper()
            bMapper.SetInput(button.GetOutput())
            b1 = vtk.vtkActor()
            b1.SetMapper(bMapper)
            b1.GetProperty().SetColor(button_colours[i])
            b1.SetPickable(1)
            b1.my_name = i+1
            buttonlist.append(b1)
            if i in (2,10,34): #up arrows
                button.SetDirection(0,1,0)
            elif i in (3,11,35): #down arrows
                button.SetDirection(0,-1,0)
            elif i in (53,999): #left arrows
                button.SetDirection(-1,0,0)
                
        else:
            button = vtk.vtkSphereSource()
            button.SetCenter(button_coordlist[i])
            button.SetRadius(1.2)
            bMapper = vtk.vtkPolyDataMapper()
            bMapper.SetInput(button.GetOutput())
            b1 = vtk.vtkActor()
            b1.SetMapper(bMapper)
            b1.GetProperty().SetColor(button_colours[i])
            b1.SetPickable(1)
            b1.my_name = i+1
            buttonlist.append(b1)
    
    ## INTERACTOR for menu window 
    
    pickerMenu = vtk.vtkPicker()
    pickerMenu.SetTolerance(0.001)
    
    ren2= vtk.vtkRenderer() 
    for j in range(len(boxes)):
        ren2.AddActor(boxes[j])
        
    for k in range(0,6):
        #TODO - decide if some boxes should be shaded to highlight apt functionality
        #i.e. the simple stuff the psychologists might want
        
        pass
        #boxes[k].GetProperty().SetColor(0,0,0.5)
        #to highlight functionality for Psych
        #boxes[8].GetProperty().SetColor(1,1,0.5)
        #boxes[11].GetProperty().SetColor(1,1,0.5)
        #boxes[12].GetProperty().SetColor(1,1,0.5)
        
    for j in range(len(buttonlist)):
        ren2.AddActor2D(buttonlist[j])
    
    buttonlist[0].VisibilityOff()
    
    for j in range(len(label_list)):
        ren2.AddActor(label_list[j])
    
    ren2.SetBackground( 0.0, 0.0, 0.0 )
    
    renWin2 = vtk.vtkRenderWindow()
    renWin2.AddRenderer( ren2 )
    renWin2.SetSize( 220, 600 )
    renWin2.SetPosition(0,0)
    
    style = vtk.vtkInteractorStyleFlight()
    iren2 = vtk.vtkRenderWindowInteractor()
    iren2.SetInteractorStyle(style)
    iren2.SetRenderWindow(renWin2)
    
    iren2.SetPicker(pickerMenu)
    
    
    ## FUNCTIONALITY OF MENU
    
    def ButtonEvent(obj, event):
        global ren3, renWin3, Rotating, Panning, Zooming, v, fibers_visible, func_data_loaded, func_data_visible, fibers_pickable, fiber_data_loaded, cortex_data_loaded, meshActor,lineData, plotData, profileList, ortho_planes_on, surf_num, iact, renWin2, current_mpr_orient, fiber_group, fiber_group_counter, x_const, y_const, z_const, dipole_rows, dipole_transform, dipole_min_time, dipole_max_time, dipole_current_time, dipoleSphereActor, line, lineMapper, Sphere, SphereMapper, dipoleLineActor, dipole_mat, cortex_data, rMin, first_dipole_point, vol, original_LU_table, functionalLut, functional_2d_volume, pos_scalarBar, pos_colour_bar_generated, resetPlanes, mpr_showing_2d
       
        
        #add extra funtionality options by enabling CTRL - selecting 
        if obj.GetControlKey():
            if event == "LeftButtonPressEvent":
                r = iren2.GetEventPosition()
                s = r[0],r[1],0
                pickerMenu.Pick(s, ren2)
                t = pickerMenu.GetActor()
                u = t.GetProperty()
                v = u.GetDiffuseColor()
                        
                if t.my_name in range(14,34):
                    for j in range(len(profileList)):
                        if profileList[j].GetVisibility() == 1:
                            profileList[j].GetProperty().SetDiffuseColor(v)
                             
        
        
        if event == "KeyPressEvent":
            key=iact.GetKeySym()
        
            #quit app
            if key == 'q':
                iact.TerminateApp()
        
        
        ##AUTO-RESET: since this window has objects which an also be interacted with in 3D
        # we need to have a wayto stop this window form doing rotating / zooming etc its 
        #objects .. or at least set the back very quickly if they are changed
        
        if event == "LeftButtonReleaseEvent":
            r = ren2.GetActiveCamera()
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)
               
        if event == "RightButtonPressEvent":
            r = ren2.GetActiveCamera()
    
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)   
        if event == "RightButtonReleaseEvent":
            r = ren2.GetActiveCamera()
    
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)   
        if event == "MiddleButtonPressEvent":
            r = ren2.GetActiveCamera()
    
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)   
        if event == "MiddleButtonReleaseEvent":
            r = ren2.GetActiveCamera()
    
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)   
        if event == "MouseMoveEvent":
            r = ren2.GetActiveCamera()
    
            r.SetPosition(30.0, 72.5, 235.26399390218796)  
            r.SetViewAngle(30) 
            r.SetViewUp(0,1,0)        
        #### end auto-reset
                
        
        #Normal interaction requires the user to single-click on a button     
        if event == "LeftButtonPressEvent":
            r = iren2.GetEventPosition()
            s = r[0],r[1],0
            pickerMenu.Pick(s, ren2)
            t = pickerMenu.GetActor()
            u = t.GetProperty()
            v = u.GetDiffuseColor()
                    
            if t.my_name == 1:
                meshActor.VisibilityOff()
                t.VisibilityOff()
                buttonlist[1].VisibilityOn()
                renWin.Render()
                
            elif t.my_name == 2:
                meshActor.VisibilityOn()
                t.VisibilityOff()
                buttonlist[0].VisibilityOn()
                renWin.Render()
                
            elif t.my_name == 3:
                e = meshActor.GetProperty()
                f = e.GetOpacity()
                if f < 1:
                    e.SetOpacity(f + 0.1)
                    renWin.Render()
                    
            elif t.my_name == 4:
                e = meshActor.GetProperty()
                f = e.GetOpacity()
                if f > 0:
                    e.SetOpacity(f - 0.1)
                    renWin.Render()
            
            elif t.my_name == 5:
                if planeWidgetX.GetTexturePlaneProperty().GetOpacity() == 0:
                    planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
                    t.GetProperty().SetColor(1,0,0)
                else:
                    planeWidgetX.GetTexturePlaneProperty().SetOpacity(0)
                    t.GetProperty().SetColor(0.3,0.3,0.3)
                    
            elif t.my_name == 6:
                if planeWidgetZ.GetTexturePlaneProperty().GetOpacity() == 0:
                    planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)
                    t.GetProperty().SetColor(1,0,0)
                else:
                    planeWidgetZ.GetTexturePlaneProperty().SetOpacity(0)
                    t.GetProperty().SetColor(0.3,0.3,0.3)
    
            elif t.my_name == 7:
                if planeWidgetY.GetTexturePlaneProperty().GetOpacity() == 0:
                    planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
                    t.GetProperty().SetColor(1,0,0)
                else:
                    planeWidgetY.GetTexturePlaneProperty().SetOpacity(0)
                    t.GetProperty().SetColor(0.3,0.3,0.3)
    
            elif t.my_name == 8:
                if planeWidgetY.GetPlaneProperty().GetOpacity() == 0:
                    planeWidgetX.GetPlaneProperty().SetOpacity(1.0)
                    planeWidgetY.GetPlaneProperty().SetOpacity(1.0)
                    planeWidgetZ.GetPlaneProperty().SetOpacity(1.0)
                    t.GetProperty().SetColor(1,0,0)
                else:
                    planeWidgetX.GetPlaneProperty().SetOpacity(0)
                    planeWidgetY.GetPlaneProperty().SetOpacity(0)
                    planeWidgetZ.GetPlaneProperty().SetOpacity(0)
                    t.GetProperty().SetColor(0.3,0.3,0.3)
    
            elif t.my_name == 9:
                if fibers_visible == 1:
                    for u in range(len(plotData)):
                        profileList[u].VisibilityOff()
                        fibers_visible = 0
                        t.GetProperty().SetColor(0.3,0.3,0.3)
                else:
                    for u in range(len(plotData)):
                        profileList[u].VisibilityOn()
                        fibers_visible = 1
                        t.GetProperty().SetColor(1,0,0)
    
            elif t.my_name == 10:
                if roiActor.GetVisibility() == 1:
                    roiActor.VisibilityOff()
                    boxWidget.Off()
                    renWin.Render()
                    t.GetProperty().SetColor(0.3,0.3,0.3)
                else:
                    roiActor.VisibilityOn()
                    boxWidget.On()
                    renWin.Render()
                    t.GetProperty().SetColor(1,0,0)                 
    
            elif t.my_name == 11:
                current_rad = roi.GetRadius()
                roi.SetRadius(current_rad + 1)
                renWin.Render()
            
            elif t.my_name == 12:
                current_rad = roi.GetRadius()
                if current_rad > 1:
                    roi.SetRadius(current_rad - 1)
                    renWin.Render()
                else:
                    print 'ROI is set to 1 .. smallest value possible = 1mm'           
    
            elif t.my_name == 13:
                if fibers_pickable == 1:
                    for u in range(len(plotData)):
                        profileList[u].SetPickable(0)
                        fibers_pickable = 0
                        t.GetProperty().SetColor(0.3,0.3,0.3)
                else:
                    for u in range(len(plotData)):
                        profileList[u].SetPickable(1)
                        fibers_pickable = 1
                        t.GetProperty().SetColor(1,0,0)
    
            elif t.my_name in range(14,34):
                x = []
                t = vtk.vtkTransform()
                boxWidget.GetTransform(t)
                boxWidget.GetProp3D().SetUserTransform(t)
                xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
                for j in range(0,len(plotData)):
                    for i in range(0,len(plotData[j])):
                        if plotData[j][i][0]+x_const >= xl:
                            if plotData[j][i][0]+x_const <= xu:
                                if plotData[j][i][1]+y_const >= yl:
                                    if plotData[j][i][1]+y_const <= yu:
                                        if plotData[j][i][2]+z_const >= zl:
                                            if plotData[j][i][2]+z_const <= zu:
                                                x.append(j)
                                                break
                                                
                print len(x), ' fibers changed'
                for j in x:
                    if  profileList[j].GetVisibility() == 1:
                        profileList[j].GetProperty().SetDiffuseColor(v)
            
            elif t.my_name == 34:
                if func_data_loaded == 0:
                    print 'no functional data loaded'
                else:
                    if func_data_visible == 0:
                        func_data_visible = 1
                        surf_num = 5
                        for i in range(0,len(surf_max)-1):
                            surf_max[i].VisibilityOff()
                            if rMin == 0:
                                pass
                            else:
                                surf_min[i].VisibilityOff()
                        for i in range(len(surf_max)-1, len(surf_max)):
                            surf_max[i].VisibilityOn()
                            if rMin == 0:
                                pass
                            else:
                                surf_min[i].VisibilityOn()
                    elif func_data_visible == 1:
                        for i in range(len(surf_max)-1, len(surf_max)):
                            surf_max[i].VisibilityOn()
                        for i in range(len(surf_max)):
                            if rMin == 0:
                                pass
                            else:
                                surf_min[i].VisibilityOff()
                            func_data_visible = 2
                    elif func_data_visible == 2:
                        for i in range(len(surf_max)):
                            surf_max[i].VisibilityOff()
                        for i in range(len(surf_max)-1, len(surf_max)):
                            if rMin == 0:
                                pass
                            else:
                                surf_min[i].VisibilityOn()
                            func_data_visible = 3
                    else:
                        for i in range(len(surf_max)):
                            surf_max[i].VisibilityOff()
                            if rMin == 0:
                                pass
                            else:
                                surf_min[i].VisibilityOff()
                            func_data_visible = 0
                    
            elif t.my_name == 35:
                surf_num -= 1
                if surf_num < 0:
                    surf_num = 0
                if func_data_visible == 0:
                    pass
                elif func_data_visible == 1:
                    if rMin == 0:
                        pass
                    else:
                        surf_min[surf_num].VisibilityOn()
                    surf_max[surf_num].VisibilityOn()
                elif func_data_visible == 2:
                    surf_max[surf_num].VisibilityOn()
                elif func_data_visible == 3:
                    if rMin == 0:
                        pass
                    else:
                        surf_min[surf_num].VisibilityOn()
    
            elif t.my_name == 36:
                surf_num += 1
                if surf_num == len(surf_max)+1:
                    surf_num = len(surf_max)
                if rMin == 0:
                    pass
                else:
                    surf_min[surf_num-1].VisibilityOff()
                surf_max[surf_num-1].VisibilityOff()
    
            elif t.my_name == 37:
                if ortho_planes_on == 0:
                    ortho_updates()
                    renWin3.AddRenderer(ren3)
                    renWin3.SetSize(300,900)
                    renWin3.SetPosition(1120,0)
                    style = vtk.vtkInteractorStyleTrackballCamera()
                    iren3.SetInteractorStyle(style) 
                    iren3.SetRenderWindow(renWin3)
                    iren3.Initialize()
                    ortho_planes_on = 1
                    iren3.AddObserver("KeyPressEvent", myCallBack)
                    iren3.Start()
                else:
                    ren3.RemoveActor(act_x)
                    ren3.RemoveActor(act_y)
                    ren3.RemoveActor(act_z)
                    
                    ortho_updates()
                    ren3.Render()
                    iren3.Render()        
                                            
            elif t.my_name == 38:
                if ortho_planes_on == 0:
                    ortho_updates()
                    renWin3.AddRenderer(ren3)
                    renWin3.SetSize(300,900)
                    renWin3.SetPosition(1120,0)
                    style = vtk.vtkInteractorStyleTrackballCamera()
                    iren3.SetInteractorStyle(style) 
                    iren3.SetRenderWindow(renWin3)
                    iren3.Initialize()
                    iren3.AddObserver("KeyPressEvent", myCallBack)
                    ortho_planes_on = 1
                    iren3.Start()    
                
                w2i = vtk.vtkWindowToImageFilter()
                writer = vtk.vtkTIFFWriter()
                w2i.SetInput(renWin3)
                w2i.Update()
                writer.SetInputConnection(w2i.GetOutputPort())
                save_name = my_file_dialog(4,0,0)
                if save_name == '':
                    print 'Save cancelled' 
                    return
                writer.SetFileName("%s.tif" %save_name)
                renWin3.Render()
                writer.Write()
                print 'saved ortho-window to %s.tif sucessfully' %save_name
    
            elif t.my_name == 39:
                w2i = vtk.vtkWindowToImageFilter()
                writer = vtk.vtkTIFFWriter()
                w2i.SetInput(renWin)
                w2i.Update()
                writer.SetInputConnection(w2i.GetOutputPort())
                save_name = my_file_dialog(4,0,0)
                if save_name == '':
                    print 'Save cancelled' 
                    return
                writer.SetFileName("%s.tif" %save_name)
                renWin.Render()
                writer.Write()
                print 'saved main window to %s.tif sucessfully' %save_name
                            
            elif t.my_name == 40:
                InteractorHelpThread().start()

    
            elif t.my_name == 41:
                if cortex_data_loaded == 0:
                    #fn = raw_input("Type the path of the .off cortex data file you would like to load: ")
                    fn = my_file_dialog(2,0,0)
                    if fn == '':
                        print 'Load cancelled' 
                        return
                    (meshActor, cortex_data) = tYNI.import_cortex_data(fn)
                    ren.AddActor(meshActor)
                    buttonlist[0].VisibilityOn()
                    renWin2.Render()
                    cortex_data_loaded = 1
                    print 'CORTEX LOADED'
                            
            elif t.my_name == 42:
                if fiber_data_loaded == 0:
                    fn = my_file_dialog(3,0,0)
                    if fn == '':
                        print 'Load cancelled' 
                        return
                    (profileList, lineData, plotData, x_const, y_const, z_const) = tYNI.import_fiber_data(fn)
                    fiber_data_loaded = 1
                    fiber_group = 1
                    print 'fiber_group %d' %fiber_group
                    for i in range(len(profileList)):
                        ren.AddActor(profileList[i])
                        profileList[i].my_fiber_group = fiber_group
                    print 'FIBER DATA READY (IN BACKGROUND)'
                    fiber_group_counter.append((0,len(profileList)))
                    planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
                    planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
                    planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)                    
                else:    
                    fn = my_file_dialog(3,0,0)
                    if fn == '':
                        print 'Load cancelled' 
                        return
                    fiber_group += 1
                    print 'fiber_group %d' %fiber_group
                    profileList2 = []
                    lineData2 = []
                    plotData2 = []
                    orig_fiber_count = len(profileList)
                    (profileList2, lineData2, plotData2, x_const, y_const, z_const) = tYNI.import_fiber_data(fn)
                    for j in range(len(profileList2)):
                        profileList.append(profileList2[j])
                    for j in range(len(lineData2)):
                        lineData.append(lineData2[j])
                    for j in range(len(plotData2)):
                        plotData.append(plotData2[j])
        
                    for i in range(orig_fiber_count,len(profileList)):
                        profileList[i].my_fiber_group = fiber_group
                        ren.AddActor(profileList[i])
                    fiber_data_loaded = 1
                    print 'NEW FIBER DATA READY (IN BACKGROUND)'
                    fiber_group_counter.append((orig_fiber_count,len(profileList)))
                    planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
                    planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
                    planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)
                                  
            elif t.my_name == 43:
                load_functional_data()
                func_data_loaded = 1
                func_data_visible = 1
    
            elif t.my_name == 44:
                iact.TerminateApp()
                
            elif t.my_name == 45:
                resetPlanes()
                ren.Render()
     
            elif t.my_name in range(46,52): # NB t.my_name is one indexed not zero-indexed!
                for i in range(45,51):
                    r = buttonlist[i].GetProperty()
                    r.SetColor(0.3,0.3,0.3)                
                t.GetProperty().SetColor(1,0,0)            
                if t.my_name == 46:
                    val1 = (90,145,1000)
                    val2 = (0,1,0)
                    current_mpr_orient = 1
                elif t.my_name == 47:
                    val1 = (90,145,-717)
                    val2 = (0,1,0)
                    current_mpr_orient = 2                
                elif t.my_name == 48:
                    val1 = (90,1000,154)
                    val2 = (0,0,1)
                    current_mpr_orient = 3              
                elif t.my_name == 49:
                    val1 = (90,-717,155)
                    val2 = (0,0,1)
                    current_mpr_orient = 4  
                elif t.my_name == 50:
                    val1 = (-775,145,155)
                    val2 = (0,0,1)
                    current_mpr_orient = 5  
                elif t.my_name == 51:
                    val1 = (1000,145,149)
                    val2 = (0,0,1)
                    current_mpr_orient = 0           
                ttt = ren.GetActiveCamera()
                ttt.SetPosition(val1)
                ttt.SetViewUp(val2)
                ttt.SetViewAngle(30)
             
            elif t.my_name == 52:       
                ####################### DIPOLE ACTOR #############
                #NB - import dipole data in cm 
                # divide location and magnitude elements by 100, append 1
                # use as a 1*4 vector
                # import spheres.txt transformation matrix  
                # trans_rot * location vector
                # first 41 lines are header
                dipole_file = my_file_dialog(5,0,0)
                if dipole_file == '':
                        print 'Load cancelled' 
                        return
                dipole_transform = my_file_dialog(6,0,0)
                if dipole_transform == '':
                        print 'Load cancelled' 
                        return
                first_dipole_point = my_file_dialog(20,0,0)
                if first_dipole_point == '':
                        print 'Load cancelled' 
                        return
                dipole_current_time = float(first_dipole_point)
                p = open(dipole_file)       
                d = p.readlines()
                p.close()
                rows = []
                data1 = []
                for c in d:
                    f = c.split()
                    rows.append(f)
                dipole_file_loaded = 1
                rows = rows[41:] # first 41 lines are header
                dipole_rows = array(rows).astype(float)
                dipole_min_time = dipole_rows[0][0]
                dipole_max_time = dipole_rows[len(rows)-1][0]
                dipole_data = dipole_rows[dipole_rows[:,0] >= float(first_dipole_point), :][0]            
                e = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
                coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
                coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
                trans_coord1 = dot(e, coord1).reshape(1,4)
                trans_coord2 = dot(e, coord2).reshape(1,4)
                #  ADD Sphere  #
                Sphere = vtk.vtkSphereSource()
                Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
                Sphere.SetRadius( dipole_data[7]/2 )
                SphereMapper = vtk.vtkPolyDataMapper()
                SphereMapper.SetInput(Sphere.GetOutput())
                dipoleSphereActor = vtk.vtkActor()
                dipoleSphereActor.SetMapper(SphereMapper)
                dipoleSphereActor.GetProperty().SetColor(1,0,0) 
                point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
                point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
                point2_rescale = add((subtract(point2,point1))/10,point1)            
                #  ADD Line  #
                line = vtk.vtkLineSource()
                line.SetPoint1(point1)
                line.SetPoint2(point2_rescale)            
                lineMapper = vtk.vtkPolyDataMapper()
                lineMapper.SetInput(line.GetOutput())
                dipoleLineActor = vtk.vtkActor()
                dipoleLineActor.SetMapper(lineMapper)
                dipoleLineActor.GetProperty().SetColor(1,0,0)
                dipoleLineActor.GetProperty().SetLineWidth(2)
                ren.AddActor(dipoleLineActor)
                ren.AddActor(dipoleSphereActor)
                ###End Dipole Actor#
           
           
            elif t.my_name == 53:
                if dipoleSphereActor.GetVisibility() == 0:
                    dipoleSphereActor.SetVisibility(1)
                    dipoleLineActor.SetVisibility(1)
                else:
                    dipoleSphereActor.SetVisibility(0)
                    dipoleLineActor.SetVisibility(0)                                       
    
            elif t.my_name == 55:
                dipole_data = dipole_rows[dipole_rows[:,0] >= float(dipole_current_time+1), :][0]
                dipole_current_time -=1
                dipole_mat = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
                coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
                coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
                trans_coord1 = dot(dipole_mat, coord1).reshape(1,4)
                trans_coord2 = dot(dipole_mat, coord2).reshape(1,4)
                Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
                Sphere.SetRadius( dipole_data[7]/2 )
                point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
                point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
                point2_rescale = add((subtract(point2,point1))/10,point1)
                line.SetPoint1(point1)
                line.SetPoint2(point2_rescale)
                print 'dipole at %d ms' %dipole_current_time            
                
            elif t.my_name == 54:
                dipole_data = dipole_rows[dipole_rows[:,0] >= float(dipole_current_time+1), :][0]
                dipole_current_time +=1
                dipole_mat = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
                coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
                coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
                trans_coord1 = dot(dipole_mat, coord1).reshape(1,4)
                trans_coord2 = dot(dipole_mat, coord2).reshape(1,4)     
                Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
                Sphere.SetRadius( dipole_data[7]/2 )
                
                point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
                point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
                point2_rescale = add((subtract(point2,point1))/10,point1)
                line.SetPoint1(point1)
                line.SetPoint2(point2_rescale)        
                print 'dipole at %d ms' %dipole_current_time              

            elif t.my_name == 56:
                if functional_2d_volume != []: #if a 2d_merge exists
                    if mpr_showing_2d == 1:
                        print '\nswitching to structural view'
                        planeWidgetX.SetInput(vol.GetOutput())
                        planeWidgetX.SetLookupTable(original_LU_table)
                        planeWidgetX.SetWindowLevel(2536, 2687)
                        planeWidgetY.SetInput(vol.GetOutput())
                        planeWidgetY.SetLookupTable(planeWidgetX.GetLookupTable())
                        planeWidgetY.SetWindowLevel(2536, 2687)
                        planeWidgetZ.SetInput(vol.GetOutput())
                        planeWidgetZ.SetLookupTable(planeWidgetX.GetLookupTable())
                        planeWidgetZ.SetWindowLevel(2536, 2687)
                        planeWidgetX.SetUserControlledLookupTable(0)
                        planeWidgetY.SetUserControlledLookupTable(0)
                        planeWidgetZ.SetUserControlledLookupTable(0)
                        if pos_colour_bar_generated == 1:
                            pos_scalarBar.VisibilityOff()
                        if neg_colour_bar_generated == 1:
                            neg_scalarBar.VisibilityOff()
                        resetPlanes()
                        mpr_showing_2d = 0
                    else:
                        print '\nswitching to 2d-merged functional view'
                        planeWidgetX.SetUserControlledLookupTable(1)
                        planeWidgetY.SetUserControlledLookupTable(1)
                        planeWidgetZ.SetUserControlledLookupTable(1)
                        
                        planeWidgetX.SetLookupTable(functionalLut)
                        planeWidgetY.SetLookupTable(functionalLut)
                        planeWidgetZ.SetLookupTable(functionalLut)
                        planeWidgetX.SetWindowLevel(250,125)
                        planeWidgetY.SetWindowLevel(250,125)
                        planeWidgetZ.SetWindowLevel(250,125)
                        
                        planeWidgetX.SetInput(functional_2d_volume)
                        planeWidgetY.SetInput(functional_2d_volume)
                        planeWidgetZ.SetInput(functional_2d_volume)
                        if pos_colour_bar_generated == 1:
                            pos_scalarBar.VisibilityOn()
                        if neg_colour_bar_generated == 1:
                            neg_scalarBar.VisibilityOn()
                        resetPlanes()
                        mpr_showing_2d = 1
                
                
                
        #now apply any changes made by the MENU
        renWin.Render()
    
    
            
            
            
    
    def myCallBackForMenu(obj, event):
        # Key-press observer for the menu window's interactor (iren2).
        # It fetches the pressed key symbol but does nothing with it yet,
        # so it is effectively a no-op placeholder.
        # obj, event: standard VTK observer arguments (unused here).
        key2=iren2.GetKeySym()
    
    style = vtk.vtkInteractorStyleTrackballCamera()
    iren2.SetInteractorStyle(style) 
    
    
    iren2.AddObserver("MouseMoveEvent", ButtonEvent)  
    iren2.AddObserver("LeftButtonPressEvent", ButtonEvent)
    iren2.AddObserver("LeftButtonReleaseEvent", ButtonEvent)
    iren2.AddObserver("RightButtonPressEvent", ButtonEvent)
    iren2.AddObserver("RightButtonReleaseEvent", ButtonEvent)
    iren2.AddObserver("MiddleButtonPressEvent", ButtonEvent)
    iren2.AddObserver("MiddleButtonReleaseEvent", ButtonEvent)
    iren2.AddObserver("MouseMoveEvent", ButtonEvent)
    iren2.AddObserver("KeyPressEvent", ButtonEvent)

    iren2.Initialize()
    iren2.AddObserver("KeyPressEvent",myCallBackForMenu)
    iact.ReInitialize()
    iact.RemoveObserver(1)
    iren2.Start()
    
######## -- END MENU WINDOW -- #########################################################




##### -- KEY BINDINGS FOR MAIN WINDOW -- ####################################################
######## -- Setup main window functionality -- ###############################################################

# Globals shared with the main-window callbacks defined below.
# NOTE(review): ttt is declared global in myCallBack; its purpose is not
# visible in this section -- confirm before relying on it.
ttt = []
# Renderer / window / interactor for the separate ortho-plane window that the
# 'j' key handler in myCallBack assembles on demand.
ren3, renWin3, iren3 = (vtk.vtkRenderer(),
                        vtk.vtkRenderWindow(),
                        vtk.vtkRenderWindowInteractor())
act_x, act_y, act_z = vtk.vtkActor2D(), vtk.vtkActor2D(), vtk.vtkActor2D()
# Placeholder actor; rebound to the real mesh when cortex data is loaded.
meshActor = vtk.vtkActor()
# Flag: 0 until myCallBack_loadMenuAtStart has built the menu window.
menu_initialised = 0


#when the user first clicks on the window, the menu window is generated and loaded
def myCallBack_loadMenuAtStart(obj, event):
    """Build the menu window on the first left-button press, exactly once.

    obj   -- the VTK object that fired the event (unused).
    event -- the VTK event name string.

    Any event other than "LeftButtonPressEvent", or any press after the menu
    already exists, is ignored.
    """
    # Only menu_initialised is rebound here, so it is the only name that
    # needs a global declaration.
    global menu_initialised

    if event != "LeftButtonPressEvent" or menu_initialised != 0:
        return
    # Set the flag before building so re-entrant clicks do not rebuild it.
    menu_initialised = 1
    myMenu()


#reset the planewidgets to initial view if required at any time
def resetPlanes():
    """Return the three plane widgets to their default view.

    Each widget is first snapped back to its own axis (X, Y, Z). It is then
    moved to the slice index stored in x_slice_pos / y_slice_pos /
    z_slice_pos respectively.
    """
    # Orientation first for all three widgets...
    for set_default_orientation in (planeWidgetX.SetPlaneOrientationToXAxes,
                                    planeWidgetY.SetPlaneOrientationToYAxes,
                                    planeWidgetZ.SetPlaneOrientationToZAxes):
        set_default_orientation()
    # ...then the stored slice positions, in the same X, Y, Z order.
    for widget, slice_index in ((planeWidgetX, x_slice_pos),
                                (planeWidgetY, y_slice_pos),
                                (planeWidgetZ, z_slice_pos)):
        widget.SetSliceIndex(slice_index)

# Declared global in myCallBack; nothing in the visible callback code reads
# or writes it (NOTE(review): possibly a leftover -- confirm before removing).
stuff = list()

#Main window callback routine            
def myCallBack(obj, event):
    global renWin, iact, ren, currentActor, currentActor2, surf_num, ttt, a, b, c, d, e, f, g, ortho_planes_on, ren3, renWin3, iren3, act_x, act_y, act_z, fiber_data_loaded, profileList, plotData, lineData, cortex_data_loaded, meshActor, menu_initialised, func_data_visible, func_data_loaded, surf_num, surf_max, surf_min, current_mpr_orient, fiber_group, fiber_group_counter, x_const, y_const, z_const, counter, a_fibers, rows, dipole_file_loaded, dipole_rows, dipole_transform, dipole_min_time, dipole_max_time, dipole_current_time, dipoleSphereActor, dipoleLineActor, dipole_mat, Sphere, SphereMapper, line, lineMapper, cortex_data, origActor, selectActor, planes, boxWidget, boxWidget2, buttonlist, renWin2, clipper, brain, triangles, vertices, polyData, first_dipole_point,x,y,stuff,functionalLut, vol, original_LU_table, functionalLut, functional_2d_volume, pos_scalarBar, pos_colour_bar_generated, resetPlanes, mpr_showing_2d

    
    key=iact.GetKeySym()
    
    #add extra funtionality options by enabling CTRL - selecting
    if obj.GetControlKey():
        
        if key == 'Up': #place holder
            my_file_dialog(1,0,0)
        
        if key == 'Down': #place holder
            pass
        
        if key == 'Left': # move dipole back to previous time-point
            dipole_data = dipole_rows[dipole_rows[:,0] >= float(dipole_current_time+1), :][0]
            dipole_current_time -=1
            dipole_mat = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
            coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
            coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
            trans_coord1 = dot(dipole_mat, coord1).reshape(1,4)
            trans_coord2 = dot(dipole_mat, coord2).reshape(1,4)
            Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
            Sphere.SetRadius( dipole_data[7]/2 )
            point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
            point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
            point2_rescale = add((subtract(point2,point1))/10,point1)
            line.SetPoint1(point1)
            line.SetPoint2(point2_rescale)         
            print 'dipole at %d ms' %dipole_current_time
        
        if key == 'Right': # move dipole back to next time-point
            dipole_data = dipole_rows[dipole_rows[:,0] >= float(dipole_current_time+1), :][0]
            dipole_current_time +=1
            dipole_mat = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
            coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
            coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
            trans_coord1 = dot(dipole_mat, coord1).reshape(1,4)
            trans_coord2 = dot(dipole_mat, coord2).reshape(1,4)
            Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
            Sphere.SetRadius( dipole_data[7]/2 )
            point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
            point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
            point2_rescale = add((subtract(point2,point1))/10,point1)
            line.SetPoint1(point1)
            line.SetPoint2(point2_rescale)
            print 'dipole at %d ms' %dipole_current_time
                        
        if key == '1': #interact with first fiber group loaded
            if fiber_group < 1:
                print 'no fiber group 1 - load some fiber data!'
            
            elif profileList[fiber_group_counter[0][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[0][0]].GetVisibility()
                for j in range (fiber_group_counter[0][0], fiber_group_counter[0][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[0][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[0][0]].GetVisibility()

                for k in range (fiber_group_counter[0][0], fiber_group_counter[0][1]):
                    profileList[k].VisibilityOff()
    
        if key == '2': #interact with 2nd fiber group loaded
            if fiber_group < 2:
                print 'no fiber group 2 - load some fiber data!'
            
            elif profileList[fiber_group_counter[1][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[1][0]].GetVisibility()
                for j in range (fiber_group_counter[1][0], fiber_group_counter[1][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[1][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[1][0]].GetVisibility()

                for k in range (fiber_group_counter[1][0], fiber_group_counter[1][1]):
                    profileList[k].VisibilityOff()
                    
        if key == '3': #interact with 3rd fiber group loaded
            if fiber_group < 3:
                print 'no fiber group 3 - load some fiber data!'
            
            elif profileList[fiber_group_counter[2][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[2][0]].GetVisibility()
                for j in range (fiber_group_counter[2][0], fiber_group_counter[2][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[2][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[2][0]].GetVisibility()

                for k in range (fiber_group_counter[2][0], fiber_group_counter[2][1]):
                    profileList[k].VisibilityOff()
    
        if key == '4': #interact with 4th fiber group loaded
            if fiber_group < 4:
                print 'no fiber group 4 - load some fiber data!'
            
            elif profileList[fiber_group_counter[3][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[3][0]].GetVisibility()
                for j in range (fiber_group_counter[3][0], fiber_group_counter[3][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[3][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[3][0]].GetVisibility()

                for k in range (fiber_group_counter[3][0], fiber_group_counter[3][1]):
                    profileList[k].VisibilityOff()  

        if key == '5': #interact with 5th fiber group loaded
            if fiber_group < 5:
                print 'no fiber group 5 - load some fiber data!'
            
            elif profileList[fiber_group_counter[4][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[4][0]].GetVisibility()
                for j in range (fiber_group_counter[4][0], fiber_group_counter[4][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[4][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[4][0]].GetVisibility()

                for k in range (fiber_group_counter[4][0], fiber_group_counter[4][1]):
                    profileList[k].VisibilityOff()
    
        if key == '6': #interact with 6th fiber group loaded
            if fiber_group < 6:
                print 'no fiber group 6 - load some fiber data!'
            
            elif profileList[fiber_group_counter[5][0]].GetVisibility() == 0:
                print profileList[fiber_group_counter[5][0]].GetVisibility()
                for j in range (fiber_group_counter[5][0], fiber_group_counter[5][1]):
                    profileList[j].VisibilityOn()
                        
            elif profileList[fiber_group_counter[5][0]].GetVisibility() == 1:
                print profileList[fiber_group_counter[5][0]].GetVisibility()

                for k in range (fiber_group_counter[5][0], fiber_group_counter[5][1]):
                    profileList[k].VisibilityOff()
        
        
        if key == '7': #place holder
            pass
            
            
        if key == '8': #place holder
            pass    
        
        
        if key == 'h': #place holder
            #os.spawnlp(os.P_NOWAIT,'ynicDV3D_wx_main_help.py')
            InteractorHelpThread().start()  
        
        
        #TODO - make ctrl a b c d e f etc all fiber sub-groups which can be
            #   1.) made created from current visible selection / touching roi
            #   2.) if exists .. make visible / invisible
            #   3.) ? - allow append of current set with SHIFT-CTRL?
        if key == 'a': #add fibers on the screen to a stored group
            for j in range(len(profileList)):
                if profileList[j].GetVisibility() == 1:
                    a_fibers.append(j)
        
        
        if key == 'b': #view the stored fiber group
            for j in a_fibers:
                profileList[j].VisibilityOn()
                
                
    elif key == 'q': #quit app
        iact.TerminateApp()
        
    elif key == 'z': #toggle instructions on / off
        InteractorHelpThread().start()
    
    elif key == 'f':  #toggle func data visibility
        if func_data_loaded == 0:
            print 'no functional data loaded'
        else:
            if func_data_visible == 0:
                func_data_visible = 1
                surf_num = 5
                for i in range(0,len(surf_max)-1):
                    surf_max[i].VisibilityOff()
                    if rMin == 0:
                        pass
                    else:
                        surf_min[i].VisibilityOff()
                for i in range(len(surf_max)-1, len(surf_max)):
                    surf_max[i].VisibilityOn()
                    if rMin == 0:
                        pass
                    else:
                        surf_min[i].VisibilityOn()
            elif func_data_visible == 1:
                for i in range(len(surf_max)-1, len(surf_max)):
                    surf_max[i].VisibilityOn()
                for i in range(len(surf_max)):
                    if rMin == 0:
                        pass
                    else:
                        surf_min[i].VisibilityOff()
                    func_data_visible = 2
            elif func_data_visible == 2:
                for i in range(len(surf_max)):
                    surf_max[i].VisibilityOff()
                for i in range(len(surf_max)-1, len(surf_max)):
                    if rMin == 0:
                        pass
                    else:
                        surf_min[i].VisibilityOn()
                    func_data_visible = 3
            else:
                for i in range(len(surf_max)):
                    surf_max[i].VisibilityOff()
                    if rMin == 0:
                        pass
                    else:
                        surf_min[i].VisibilityOff()
                    func_data_visible = 0
    
    
    elif key == 'M':
        if functional_2d_volume != []: #if a 2d_merge exists
            if mpr_showing_2d == 1:
                print '\nswitching to structural view'
                planeWidgetX.SetInput(vol.GetOutput())
                planeWidgetX.SetLookupTable(original_LU_table)
                planeWidgetX.SetWindowLevel(2536, 2687)
                planeWidgetY.SetInput(vol.GetOutput())
                planeWidgetY.SetLookupTable(planeWidgetX.GetLookupTable())
                planeWidgetY.SetWindowLevel(2536, 2687)
                planeWidgetZ.SetInput(vol.GetOutput())
                planeWidgetZ.SetLookupTable(planeWidgetX.GetLookupTable())
                planeWidgetZ.SetWindowLevel(2536, 2687)
                planeWidgetX.SetUserControlledLookupTable(0)
                planeWidgetY.SetUserControlledLookupTable(0)
                planeWidgetZ.SetUserControlledLookupTable(0)
                if pos_colour_bar_generated == 1:
                    pos_scalarBar.VisibilityOff()
                if neg_colour_bar_generated == 1:
                    neg_scalarBar.VisibilityOff()
                resetPlanes()
                mpr_showing_2d = 0
            else:
                print '\nswitching to 2d-merged functional view'
                planeWidgetX.SetUserControlledLookupTable(1)
                planeWidgetY.SetUserControlledLookupTable(1)
                planeWidgetZ.SetUserControlledLookupTable(1)
                
                planeWidgetX.SetLookupTable(functionalLut)
                planeWidgetY.SetLookupTable(functionalLut)
                planeWidgetZ.SetLookupTable(functionalLut)
                planeWidgetX.SetWindowLevel(250,125)
                planeWidgetY.SetWindowLevel(250,125)
                planeWidgetZ.SetWindowLevel(250,125)
                
                planeWidgetX.SetInput(functional_2d_volume)
                planeWidgetY.SetInput(functional_2d_volume)
                planeWidgetZ.SetInput(functional_2d_volume)
                if pos_colour_bar_generated == 1:
                    pos_scalarBar.VisibilityOn()
                if neg_colour_bar_generated == 1:
                    neg_scalarBar.VisibilityOn()
                resetPlanes()
                mpr_showing_2d = 1
    
    
    elif key == 'S': #toggle sagittal plane on / off
        if planeWidgetX.GetTexturePlaneProperty().GetOpacity() == 0:
            planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
        else:
            planeWidgetX.GetTexturePlaneProperty().SetOpacity(0)
    
    elif key == 'C': #toggle coronal plane on / off
        if planeWidgetY.GetTexturePlaneProperty().GetOpacity() == 0:
            planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
        else:
            planeWidgetY.GetTexturePlaneProperty().SetOpacity(0)

    elif key == 'A': #toggle axial plane on / off
        if planeWidgetZ.GetTexturePlaneProperty().GetOpacity() == 0:
            planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)
        else:
            planeWidgetZ.GetTexturePlaneProperty().SetOpacity(0)

    elif key == 'j': #save ortho window to tiff file
        if ortho_planes_on == 0:
            ortho_updates()
            renWin3.AddRenderer(ren3)
            renWin3.SetSize(300,900)
            renWin3.SetPosition(1120,0)
            style = vtk.vtkInteractorStyleTrackballCamera()
            iren3.SetInteractorStyle(style) 
            iren3.SetRenderWindow(renWin3)
            iren3.Initialize()
            iren3.AddObserver("KeyPressEvent", myCallBack)
            ortho_planes_on = 1
            iren3.Start() 
        
        w2i = vtk.vtkWindowToImageFilter()
        writer = vtk.vtkTIFFWriter()
        w2i.SetInput(renWin3)
        w2i.Update()
        writer.SetInputConnection(w2i.GetOutputPort())
        save_name = my_file_dialog(4,0,0)
        if save_name == '':
            print 'Load cancelled' 
            return
        writer.SetFileName("%s.tif" %save_name)
        renWin3.Render()
        writer.Write()
        print 'saved ortho-window to %s.tif sucessfully' %save_name       
    
    elif key == 'J': #save main window to tiff file
        w2i = vtk.vtkWindowToImageFilter()
        writer = vtk.vtkTIFFWriter()
        w2i.SetInput(renWin)
        w2i.Update()
        writer.SetInputConnection(w2i.GetOutputPort())
        save_name = my_file_dialog(4,0,0)
        if save_name == '':
            print 'Load cancelled' 
            return
        writer.SetFileName("%s.tif" %save_name)
        writer.Write()
        print 'saved main window to %s.tif sucessfully' %save_name


    elif key == 'T':
        ## Fiber Data -- ##
        #Import data to structure
        if fiber_data_loaded == 0:
            fn = my_file_dialog(3,0,0)
            if fn == '':
                print 'Load cancelled' 
                return
            (profileList, lineData, plotData, x_const, y_const, z_const) = tYNI.import_fiber_data(fn)
            fiber_data_loaded = 1
            fiber_group = 1
            print 'fiber_group %d' %fiber_group
            for i in range(len(profileList)):
                ren.AddActor(profileList[i])
                profileList[i].my_fiber_group = fiber_group
            print 'FIBER DATA READY (IN BACKGROUND)'
            fiber_group_counter.append((0,len(profileList)))                
            planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
            planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
            planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)
        else:    
            fn = my_file_dialog(3,0,0)
            if fn == '':
                print 'Load cancelled' 
                return
            fiber_group += 1
            print 'fiber_group %d' %fiber_group
            profileList2 = []
            lineData2 = []
            plotData2 = []
            orig_fiber_count = len(profileList)
            (profileList2, lineData2, plotData2, x_const, y_const, z_const) = tYNI.import_fiber_data(fn)
            for j in range(len(profileList2)):
                profileList.append(profileList2[j])
            for j in range(len(lineData2)):
                lineData.append(lineData2[j])
            for j in range(len(plotData2)):
                plotData.append(plotData2[j])
    
            for i in range(orig_fiber_count,len(profileList)):
                profileList[i].my_fiber_group = fiber_group
                ren.AddActor(profileList[i])
            fiber_data_loaded = 1
            print 'NEW FIBER DATA READY (IN BACKGROUND)'
            fiber_group_counter.append((orig_fiber_count,len(profileList)))                   
            planeWidgetX.GetTexturePlaneProperty().SetOpacity(1.0)
            planeWidgetY.GetTexturePlaneProperty().SetOpacity(1.0)
            planeWidgetZ.GetTexturePlaneProperty().SetOpacity(1.0)
            ## END FIBER DATA 


    elif key == 'B':
    ## Cortex Data -- ##
    #Import data to structure
        if cortex_data_loaded == 0:
            fn = my_file_dialog(2,0,0)
            if fn == '':
                print 'Load cancelled' 
                return
            (meshActor, polyData) = tYNI.import_cortex_data(fn)
            buttonlist[0].VisibilityOn()
            renWin2.Render()
            ren.AddActor(meshActor)
            cortex_data_loaded = 1
                        
            planes = vtk.vtkPlanes()
            clipper = vtk.vtkClipPolyData()
            clipper.SetInput(polyData)
            clipper.SetClipFunction(planes)
            clipper.InsideOutOn()
            selectMapper = vtk.vtkPolyDataMapper()
            selectMapper.SetInput(clipper.GetOutput())
            selectActor = vtk.vtkLODActor()
            selectActor.SetMapper(selectMapper)
            selectActor.GetProperty().SetColor(0.6, 0.6, 0.6)
            
            #set some properties for the inner face
            #TODO figure out why inside and outside are the wrong way around?
            property1 = vtk.vtkProperty()
            property1.SetColor(0.2, 0.9, 0.2)
            #property1.SetDiffuse(0.7)
            #property1.SetSpecular(0.4)
            #property1.SetSpecularPower(20)
            
            selectActor.SetBackfaceProperty(property1)
            selectActor.VisibilityOff()
            selectActor.SetScale(1.01, 1.01, 1.01)
            
            # The SetInteractor method is how 3D widgets are associated with the
            # render window interactor.  Internally, SetInteractor sets up a bunch
            # of callbacks using the Command/Observer mechanism (AddObserver()).
            boxWidget2 = vtk.vtkBoxWidget()
            boxWidget2.SetInteractor(iact)
            boxWidget2.SetKeyPressActivationValue('V')
            boxWidget2.SetPlaceFactor(1.25)
            
            ren.AddActor(origActor)
            ren.AddActor(selectActor)
            
            # This callback function does the actual work: updates the vtkPlanes
            # implicit function.  This in turn causes the pipeline to update.
            def SelectPolygons(object, event):
                """Update the box-clipped selection of the cortex mesh.

                Registered below as the "EndInteractionEvent" observer of
                boxWidget2, so it runs each time an interaction with the box
                widget ends.

                object -- the widget that fired the event (boxWidget2).
                event  -- the VTK event name; not used here.
                """
                # object will be the boxWidget
                # NOTE(review): `global` binds the MODULE-level names. This only
                # reaches the vtkPlanes / selectActor / meshActor created in the
                # enclosing 'B' branch if the enclosing key-press callback also
                # declares them global; otherwise these names may be undefined or
                # refer to other objects -- confirm against that callback's header.
                global meshActor,selectActor, planes
                # Copy the widget's current bounding planes into `planes`, the
                # clip function of the clipper whose output feeds selectActor.
                object.GetPlanes(planes)
                # Show the clipped selection and hide the full cortex mesh.
                selectActor.VisibilityOn()
                meshActor.VisibilityOff()
             
            # Place the interactor initially. The input to a 3D widget is used to
            # initially position and scale the widget. The "EndInteractionEvent" is
            # observed which invokes the SelectPolygons callback.
            boxWidget2.SetInput(polyData)
            boxWidget2.PlaceWidget()
            boxWidget2.AddObserver("EndInteractionEvent", SelectPolygons)
            
            print 'CORTEX LOADED'
        else:
            print 'cortex data already loaded'
            ## END CORTEX DATA


    elif key == 'x':
    ##CREATE ORTHOGONAL IMAGES  ##
        if ortho_planes_on == 0: #if not already loaded
            ortho_updates()
            renWin3.AddRenderer(ren3)
            renWin3.SetSize(300,900)
            renWin3.SetPosition(1120,0)
            style = vtk.vtkInteractorStyleTrackballCamera()
            iren3.SetInteractorStyle(style) 
            iren3.SetRenderWindow(renWin3)
            iren3.Initialize()
            ortho_planes_on = 1
            iren3.AddObserver("KeyPressEvent", myCallBack)
            iren3.Start()
        else:
            ren3.RemoveActor(act_x)
            ren3.RemoveActor(act_y)
            ren3.RemoveActor(act_z)

            ortho_updates()
            ren3.Render()
            iren3.Render()
            ##  END ORTHOGONAL IMAGES 
    
    
    elif key == 'L': #reset MPR planes to original views
        resetPlanes()
        ren.Render()


    elif key == 's': #seed from beamformer roi            
        r = iact.GetEventPosition()
        s = r[0],r[1],0
        picker1.Pick(s, ren)
        t = picker1.GetActor()
        u = t.GetBounds()
        xl, xu, yl, yu, zl, zu = t.GetBounds()
        x = []
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in x:
            profileList[j].VisibilityOn()
            
    elif key == 'G': # All fibers off
        for u in range(len(plotData)):
            profileList[u].VisibilityOff()

    elif key == 'g': # All fibers visible
        for u in range(len(plotData)):
            profileList[u].VisibilityOn()                     

    elif key == 'b': # increase cortex opacity
        props = meshActor.GetProperty()
        op = props.GetOpacity()
        if op < 1.0:
            meshActor.GetProperty().SetOpacity(0.1+op)
    
    elif key == 'n': # decrease cortex opacity
        props = meshActor.GetProperty()
        op = props.GetOpacity()
        if op > 0.1:
            meshActor.GetProperty().SetOpacity(op-0.1)
    
    elif key =='v': #toggle cortex visibility on / off
        vis = meshActor.GetVisibility()
        if vis == 1:
            meshActor.VisibilityOff()
        else:
            meshActor.VisibilityOn()
            props = meshActor.GetProperty()
            if props.GetOpacity() == 0:
                meshActor.GetProperty().SetOpacity(0.1)

    elif key == 'h': #swith on ROI
        if roiActor.GetVisibility() == 0:
            roiActor.VisibilityOn()
            boxWidget.On()
        else:
            roiActor.VisibilityOff()
            boxWidget.Off()
    
    
    elif key == 'o': #place planes at pick point
        pick = iact.GetPicker()
        r = iact.GetEventPosition()
        s = r[0],r[1],0
        pick.Pick(s, ren)
        ppos = pick.GetPickPosition()  
        planeWidgetX.SetSlicePosition(ppos[0])
        planeWidgetY.SetSlicePosition(ppos[1])
        planeWidgetZ.SetSlicePosition(ppos[2])
        
    elif key == 't': #place ROI at pick point
        x = planeWidgetX.GetCurrentCursorPosition()
        y = planeWidgetY.GetCurrentCursorPosition()
        z = planeWidgetZ.GetCurrentCursorPosition()
        pick = iact.GetPicker()
        r = iact.GetEventPosition()
        s = r[0],r[1],0
        pick.Pick(s, ren)
        ppos = pick.GetPickPosition()
        roi.SetRadius(8)
        roi.SetCenter(ppos[0],ppos[1],ppos[2])
        renWin.Render()
        boxWidget.SetProp3D(roiActor)
        boxWidget.PlaceWidget()

    elif key == 'k': #remove non-roi fibers  
        x = []
        t = vtk.vtkTransform()
        boxWidget.GetTransform(t)
        boxWidget.GetProp3D().SetUserTransform(t)
        xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in [fib for fib in range(len(plotData)) if fib not in x]:
            profileList[j].VisibilityOff()
   
    elif key == 'K': #remove ROI - picked fibers 
        x = []
        t = vtk.vtkTransform()
        boxWidget.GetTransform(t)
        boxWidget.GetProp3D().SetUserTransform(t)
        xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in [fib for fib in range(len(plotData)) if fib in x]:
            profileList[j].VisibilityOff()

    elif key == 'R': #change colours - red
        x = []
        t = vtk.vtkTransform()
        boxWidget.GetTransform(t)
        boxWidget.GetProp3D().SetUserTransform(t)
        xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in x:
            profileList[j].GetProperty().SetDiffuseColor(1,0,0)
            
    elif key == 'G': #change colours - green
        x = []
        t = vtk.vtkTransform()
        boxWidget.GetTransform(t)
        boxWidget.GetProp3D().SetUserTransform(t)
        xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in x:
            profileList[j].GetProperty().SetDiffuseColor(0,1,0)

    elif key == 'B': #change colours - blue
        x = []
        t = vtk.vtkTransform()
        boxWidget.GetTransform(t)
        boxWidget.GetProp3D().SetUserTransform(t)
        xl, xu, yl, yu, zl, zu = boxWidget.GetProp3D().GetBounds()
        for j in range(0,len(plotData)):
            for i in range(0,len(plotData[j])):
                if plotData[j][i][0]+x_const >= xl:
                    if plotData[j][i][0]+x_const <= xu:
                        if plotData[j][i][1]+y_const >= yl:
                            if plotData[j][i][1]+y_const <= yu:
                                if plotData[j][i][2]+z_const >= zl:
                                    if plotData[j][i][2]+z_const <= zu:
                                        x.append(j)
                                        break
        for j in x:
            profileList[j].GetProperty().SetDiffuseColor(0,0,1)


    elif key =='d' : #toggle dipole actor on / off
        if dipoleSphereActor.GetVisibility() == 0:
            dipoleSphereActor.SetVisibility(1)
            dipoleLineActor.SetVisibility(1)
        else:
            dipoleSphereActor.SetVisibility(0)
            dipoleLineActor.SetVisibility(0) 
             
    elif key =='6': #make picked object grey
        # TODO - make this take any colour
        currentActor = picker1.GetActor()
        props = currentActor.GetProperty()
        props.SetDiffuseColor(0.8,0.8,0.8)
        
    elif key =='7': #make picked object red
        currentActor = picker1.GetActor()
        props = currentActor.GetProperty()
        props.SetDiffuseColor(1,0,0)
        
    elif key =='8': #make picked object green
        currentActor = picker1.GetActor()
        props = currentActor.GetProperty()
        props.SetDiffuseColor(0,1,0)
        
    elif key =='9': #make picked object blue
        currentActor = picker1.GetActor()
        props = currentActor.GetProperty()
        props.SetDiffuseColor(0,0,1)
    
    elif key == 'F': #call load functional data routine
        load_functional_data()
        func_data_loaded = 1
        func_data_visible = 1     

    elif key == 'H': # cycle views Top, Bottom, Front, Back, Left, Right
        ttt = ren.GetActiveCamera()
        if current_mpr_orient == 0:
            ttt.SetPosition(90,145,1000)
            ttt.SetViewUp(0,1,0)
            ttt.SetViewAngle(30)
            current_mpr_orient = 1
        elif current_mpr_orient == 1:
            ttt.SetPosition(90,145,-717)
            ttt.SetViewUp(0,1,0)
            ttt.SetViewAngle(30)
            current_mpr_orient = 2
        elif current_mpr_orient == 2:
            ttt.SetPosition(90,1000,154)
            ttt.SetViewUp(0,0,1)
            ttt.SetViewAngle(30)
            current_mpr_orient = 3
        elif current_mpr_orient == 3:
            ttt.SetPosition(90,-717,155)
            ttt.SetViewUp(0,0,1)
            ttt.SetViewAngle(30)
            current_mpr_orient = 4
        elif current_mpr_orient == 4:
            ttt.SetPosition(-775,145,155)
            ttt.SetViewUp(0,0,1)
            ttt.SetViewAngle(30)
            current_mpr_orient = 5
        elif current_mpr_orient == 5:
            ttt.SetPosition(1000,145,149)
            ttt.SetViewUp(0,0,1)
            ttt.SetViewAngle(30)
            current_mpr_orient = 0           
    
    elif key == 'E': 
        if planeWidgetY.GetPlaneProperty().GetOpacity() == 0:
            planeWidgetX.GetPlaneProperty().SetOpacity(1.0)
            planeWidgetY.GetPlaneProperty().SetOpacity(1.0)
            planeWidgetZ.GetPlaneProperty().SetOpacity(1.0)
            #t.GetProperty().SetColor(1,0,0)
        else:
            planeWidgetX.GetPlaneProperty().SetOpacity(0)
            planeWidgetY.GetPlaneProperty().SetOpacity(0)
            planeWidgetZ.GetPlaneProperty().SetOpacity(0)
            #t.GetProperty().SetColor(0.3,0.3,0.3)
    
    elif key == 'D':
        ####################### DIPOLE ACTOR #############
        #NB - import dipole data in cm 
        # divide location and magnitude elements by 100, append 1
        # use as a 1*4 vector
        # import spheres.txt transformation matrix  
        # trans_rot * location vector
        # first 41 lines are header
        dipole_file = my_file_dialog(5,0,0)
        if dipole_file == '':
                print 'Load cancelled' 
                return
        dipole_transform = my_file_dialog(6,0,0)
        if dipole_transform == '':
                print 'Load cancelled' 
                return
        first_dipole_point = my_file_dialog(20,0,0)
        if first_dipole_point == '':
                print 'Load cancelled' 
                return
        dipole_current_time = float(first_dipole_point)
        p = open(dipole_file)       
        d = p.readlines()
        p.close()
        rows = []
        data1 = []
        for c in d:
            f = c.split()
            rows.append(f)
        dipole_file_loaded = 1
        rows = rows[41:] # first 41 lines are header
        dipole_rows = array(rows).astype(float)
        dipole_min_time = dipole_rows[0][0]
        dipole_max_time = dipole_rows[len(rows)-1][0]
        dipole_data = dipole_rows[dipole_rows[:,0] >= float(first_dipole_point), :][0]
        e = reshape(fromfile(dipole_transform, sep=' ') , (4,4))
        coord1 = array([dipole_data[1]/100, dipole_data[2]/100, dipole_data[3]/100,1]).reshape(4,1)
        coord2 = array([dipole_data[4]/100, dipole_data[5]/100, dipole_data[6]/100,1]).reshape(4,1)         
        trans_coord1 = dot(e, coord1).reshape(1,4)
        trans_coord2 = dot(e, coord2).reshape(1,4)
        #  ADD Sphere  #
        Sphere = vtk.vtkSphereSource()
        Sphere.SetCenter( trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2])
        Sphere.SetRadius( dipole_data[7]/2 )
        SphereMapper = vtk.vtkPolyDataMapper()
        SphereMapper.SetInput(Sphere.GetOutput())
        dipoleSphereActor = vtk.vtkActor()
        dipoleSphereActor.SetMapper(SphereMapper)
        dipoleSphereActor.GetProperty().SetColor(1,0,0)
        point1 = [trans_coord1[0][0], trans_coord1[0][1], trans_coord1[0][2]]
        point2 = [trans_coord2[0][0]+trans_coord1[0][0], trans_coord2[0][1]+trans_coord1[0][1], trans_coord2[0][2]+trans_coord1[0][2]]
        point2_rescale = add((subtract(point2,point1))/10,point1)
         #  ADD Line  #
        line = vtk.vtkLineSource()
        line.SetPoint1(point1)
        line.SetPoint2(point2_rescale)
        lineMapper = vtk.vtkPolyDataMapper()
        lineMapper.SetInput(line.GetOutput())
        dipoleLineActor = vtk.vtkActor()
        dipoleLineActor.SetMapper(lineMapper)
        dipoleLineActor.GetProperty().SetColor(1,0,0)
        dipoleLineActor.GetProperty().SetLineWidth(2)
        ren.AddActor(dipoleLineActor)
        ren.AddActor(dipoleSphereActor)
        ## END DIPOLE ACTOR ##
    
    #Apply the changes by re-rendering the main window
    renWin.Render()
    
    
    
##################### END KEY PRESS EVENTS FOR MAIN WINDOW #######################################################

# Start-up of the main 3D window. All observers must be registered before
# iact.Start(), because Start() enters the blocking event loop.
# NOTE(review): InteractorHelpThread is defined elsewhere in this file;
# presumably it launches the interactor help window (cf.
# help/ynicDV3D_wx_interactor_help.py in the header notes) -- confirm.
InteractorHelpThread().start()
# Initialise the interactor and draw the main window once.
iact.Initialize()
renWin.Render()
# Route key presses in the main window to the key-press dispatcher
# (myCallBack, the handler whose body ends just above).
iact.AddObserver("KeyPressEvent", myCallBack)
# NOTE(review): judging by its name, myCallBack_loadMenuAtStart presumably
# presents the data-loading menu on a left click -- confirm.
iact.AddObserver("LeftButtonPressEvent", myCallBack_loadMenuAtStart)
# Enter the VTK interactor event loop; this call does not return while the
# loop is running.
iact.Start()
