#!/usr/bin/env python

import wx
import sys
import vtk
from vtk.wx.wxVTKRenderWindowInteractor import wxVTKRenderWindowInteractor
from dv3dPlaneWidgets import *
import time
from common_functions import align_all_planes


class OrthoViews(wx.Frame):
    """Frame with three orthogonal slice views of the first loaded volume.

    Layout (2x2 wx.GridBagSizer)::

        widget1 | widget3
        --------+-----------------
        widget2 | properties panel

    Each view owns a vtkRenderer whose camera uses parallel projection and
    looks along one coordinate axis: widget1/ortho_ren1 along x,
    widget2/ortho_ren2 along z, widget3/ortho_ren3 along y.  Three plane
    widgets are built from the first loaded volume and bound to widget1,
    widget3 and widget2 respectively (radiological-convention layout).
    Mouse input to the views is disabled until the "Enable plane
    interaction" checkbox is ticked.

    ``frame_parent`` is the main window.  This class reads its
    ``my_loaded_volumes``, ``slice_pos`` and ``ListOfObjects`` attributes and
    writes its ``SYNC_ortho`` flag (0 or 1).

    Ordinary close requests are vetoed; the frame is destroyed only when a
    close cannot be vetoed.
    """

    # Distance from the data centre at which each camera is placed before
    # ResetCamera().  Only the direction matters: ResetCamera() keeps the
    # view direction but recomputes distance and focal point.
    _CAMERA_OFFSET = 1000.0

    def __init__(self, frame_parent):
        wx.Frame.__init__(self, None, wx.ID_ANY, 'DV3D - Orthogonal Views')

        # Kept so the main window can update this frame when required, and
        # so this frame can toggle frame_parent.SYNC_ortho.
        self.frame_parent = frame_parent

        # A single panel as the frame's child makes the frame look correct
        # on all platforms.
        self.panel = wx.Panel(self, wx.ID_ANY)

        self._build_properties_panel()
        self._build_views()
        bag_sizer = self._build_layout()
        self._init_cameras()
        self._reset_and_render_all()

        # Interaction stays off until the user enables it (OnEnable_planes).
        for widget in (self.widget1, self.widget2, self.widget3):
            widget.Disable()

        self.panel.SetSizer(bag_sizer)

        # SetSizeHints(minW, minH, maxW, maxH)
        self.SetSizeHints(400, 400, 750, 750)
        bag_sizer.Fit(self)

        self.Bind(wx.EVT_CLOSE, self.OnCloseWindow)

        # Start with auto-sync off, matching the unchecked checkbox.
        self.frame_parent.SYNC_ortho = 0

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    def _build_properties_panel(self):
        """Create the controls shown in the bottom-right properties panel."""
        self.properties_panel = wx.Panel(self.panel, wx.ID_ANY)
        parent = self.properties_panel

        self.Sync_with_main_window = wx.CheckBox(parent, -1, 'Auto update? (slower)')
        self.Sync_with_main_window.SetValue(False)
        self.Sync_with_main_window.Bind(wx.EVT_CHECKBOX, self.OnSync_with_main_window,
                                        id=-1)

        self.L_R_flip = wx.CheckBox(parent, -1, 'L_R_flip', pos=(0, 90))
        self.L_R_flip.SetValue(False)
        self.L_R_flip.Bind(wx.EVT_CHECKBOX, self.OnL_R_flip, id=-1)

        self.enable_planes = wx.CheckBox(parent, -1, 'Enable plane interaction', pos=(0, 60))
        self.enable_planes.SetValue(False)
        self.enable_planes.Bind(wx.EVT_CHECKBOX, self.OnEnable_planes, id=-1)

        self.Manual_sync_with_main_window = wx.Button(parent, -1, 'Manually update now',
                                                      pos=(0, 30))
        self.Manual_sync_with_main_window.Bind(wx.EVT_BUTTON,
                                               self.OnManual_sync_with_main_window,
                                               id=-1)
        self.Manual_sync_with_main_window.Enable(True)

    def _build_views(self):
        """Create the three VTK views with their planes, renderers and styles."""
        self.widget1 = wxVTKRenderWindowInteractor(self.panel, -1)
        self.widget2 = wxVTKRenderWindowInteractor(self.panel, -1)
        self.widget3 = wxVTKRenderWindowInteractor(self.panel, -1)
        widgets = (self.widget1, self.widget2, self.widget3)

        self.picker = vtk.vtkCellPicker()

        # create_new_planewidget presumably comes from the dv3dPlaneWidgets
        # star import.
        volume = self.frame_parent.my_loaded_volumes[0][0].volume_data.GetOutput()
        lookup_table = self.frame_parent.ListOfObjects[0].GetLookupTable()
        self.ortho_planes = []
        for axis in range(3):
            # Every plane is created against widget1; each one is re-bound to
            # its own interactor just below.
            plane = create_new_planewidget(self.widget1, volume,
                                           self.frame_parent.slice_pos[axis],
                                           axis, self.picker, 0, self.frame_parent)
            plane.SetLookupTable(lookup_table)
            self.ortho_planes.append(plane)

        self.plane1, self.plane2, self.plane3 = self.ortho_planes

        self.plane1.SetInteractor(self.widget1)
        # NB: plane2/plane3 are deliberately swapped onto widget3/widget2 for
        # the radiological-convention layout.
        self.plane2.SetInteractor(self.widget3)
        self.plane3.SetInteractor(self.widget2)

        self.ortho_ren1 = vtk.vtkRenderer()
        self.ortho_ren2 = vtk.vtkRenderer()
        self.ortho_ren3 = vtk.vtkRenderer()
        renderers = (self.ortho_ren1, self.ortho_ren2, self.ortho_ren3)
        for renderer in renderers:
            renderer.GetActiveCamera().ParallelProjectionOn()

        for widget in widgets:
            widget.Enable()

        for widget, renderer in zip(widgets, renderers):
            render_window = widget.GetRenderWindow()
            render_window.AddRenderer(renderer)
            render_window.SetSize(300, 300)

        self.interactor_style1 = vtk.vtkInteractorStyleTrackballCamera()
        self.interactor_style2 = vtk.vtkInteractorStyleTrackballCamera()
        self.interactor_style3 = vtk.vtkInteractorStyleTrackballCamera()
        styles = (self.interactor_style1, self.interactor_style2, self.interactor_style3)
        for widget, style in zip(widgets, styles):
            widget.SetInteractorStyle(style)

        # Planes are switched on only after their interactors have renderers.
        for plane in self.ortho_planes:
            plane.On()

        for widget in widgets:
            widget.AddObserver("ExitEvent", lambda obj, evt, frame=self: frame.Close())

    def _build_layout(self):
        """Return a 2x2 GridBagSizer holding the three views and the controls."""
        bag_sizer = wx.GridBagSizer(hgap=5, vgap=5)
        # NOTE: border=5 has no effect here because no side flag
        # (wx.ALL, wx.TOP, ...) is given; kept as-is to preserve the layout.
        for widget, cell in ((self.widget1, (0, 0)),
                             (self.widget2, (1, 0)),
                             (self.widget3, (0, 1))):
            bag_sizer.Add(widget, pos=cell, flag=wx.EXPAND, border=5)
        bag_sizer.Add(self.properties_panel, pos=(1, 1))

        bag_sizer.AddGrowableCol(0)
        bag_sizer.AddGrowableCol(1)
        bag_sizer.AddGrowableRow(0)
        bag_sizer.AddGrowableRow(1)
        return bag_sizer

    def _init_cameras(self):
        """Compute the data centre and aim each camera along its axis."""
        # Bounds are indexed as (min, max) pairs per axis, i.e. VTK-style
        # (xmin, xmax, ymin, ymax, zmin, zmax).
        bounds = self.frame_parent.my_loaded_volumes[0][0].data_bounds
        self.midx = 0.5 * (bounds[0] + bounds[1])
        self.midy = 0.5 * (bounds[2] + bounds[3])
        self.midz = 0.5 * (bounds[4] + bounds[5])

        offset = self._CAMERA_OFFSET
        self._aim_camera(self.ortho_ren1, (offset, 0.0, 0.0), (0, 0, 1))
        self._aim_camera(self.ortho_ren2, (0.0, 0.0, -offset), (0, 1, 0))
        self._aim_camera(self.ortho_ren3, (0.0, offset, 0.0), (0, 0, 1))

    def _aim_camera(self, renderer, offset, view_up):
        """Place renderer's camera at centre + offset, looking at the centre.

        The position is relative to the data centre (midx, midy, midz), so
        the view direction is -offset regardless of where the data lies.
        """
        camera = renderer.GetActiveCamera()
        centre = (self.midx, self.midy, self.midz)
        camera.SetPosition(*[c + o for c, o in zip(centre, offset)])
        camera.SetFocalPoint(*centre)
        camera.SetViewUp(*view_up)

    def _reset_and_render_all(self):
        """ResetCamera() each renderer, then render the widget that shows it."""
        for renderer, widget in ((self.ortho_ren1, self.widget1),
                                 (self.ortho_ren2, self.widget2),
                                 (self.ortho_ren3, self.widget3)):
            # Reset before rendering so the first frame uses the fitted camera.
            renderer.ResetCamera()
            widget.Render()

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------
    def OnL_R_flip(self, event):
        """View widget2 (z axis) and widget3 (y axis) from the opposite side.

        Viewing from the opposite side with the same view-up mirrors those
        two images left/right.  Unchecked: ren2 looks from -z and ren3 from
        +y.  Checked: ren2 looks from +z and ren3 from -y.
        """
        sign = -1.0 if self.L_R_flip.GetValue() else 1.0
        offset = sign * self._CAMERA_OFFSET

        self._aim_camera(self.ortho_ren2, (0.0, 0.0, -offset), (0, 1, 0))
        self.ortho_ren2.ResetCamera()
        self._aim_camera(self.ortho_ren3, (0.0, offset, 0.0), (0, 0, 1))
        self.ortho_ren3.ResetCamera()

        for widget in (self.widget1, self.widget2, self.widget3):
            widget.Render()

    def OnEnable_planes(self, event):
        """Enable or disable mouse input to all three views per the checkbox."""
        enabled = self.enable_planes.GetValue()
        for widget in (self.widget1, self.widget2, self.widget3):
            widget.Enable(enabled)

    def OnSync_with_main_window(self, event):
        """Toggle frame_parent.SYNC_ortho and grey out the manual button.

        SYNC_ortho is presumably polled by the main window to decide whether
        to push updates here automatically.
        """
        auto_update = self.Sync_with_main_window.GetValue()
        self.frame_parent.SYNC_ortho = 1 if auto_update else 0
        # Manual updates are only offered while auto-update is off.
        self.Manual_sync_with_main_window.Enable(not auto_update)

    def OnManual_sync_with_main_window(self, event):
        """Re-fit every camera to its view's contents and redraw all views."""
        self._reset_and_render_all()

    def OnCloseWindow(self, event):
        """Refuse ordinary close requests; destroy only on a forced close.

        The frame is meant to live as long as the main application window, so
        user/ExitEvent close attempts are vetoed.  When the close cannot be
        vetoed (e.g. Close(force=True) or application shutdown), the frame
        must be destroyed.
        """
        if event.CanVeto():
            event.Veto()
        else:
            self.Destroy()