#!/usr/bin/python3
# -----------------------------------------------------------------
# 01 - fitAndPlotFluxes.py
#
# Authors: B.Tarraf, M.Leguebe
# 
# Script used to fit, in several dimensions at once, our custom
# flux functions (which hopefully need far fewer parameters) to
# the flux functions of the MK model.
# The initial guesses (p0) given to curve_fit were obtained
# beforehand by performing several one-dimensional fits.
#
# -----------------------------------------------------------------

import os
import sys
import copy
import shutil
import numpy as np
import time
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib           import cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize       import curve_fit

# Import definitions of fluxes and some common constants

# Uncomment the following line and set the path 
# or use PYTHONPATH environment variable :
#sys.path.insert(0,'/path/to/pyCompMito/')

from pyCompMito.fluxFitting.commonParams import *
from pyCompMito.fluxFitting.MkFluxes     import *
from pyCompMito.fluxFitting.newFluxes    import *

# =================================================================
# Plot parameters

npts = 50  # number of sample points along each direction of the surfaces

# Render all text with LaTeX, in a serif font
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

# Fluxes to fit. fitAll = True fits every flux, whatever the
# individual flags below say.
fitAll = True
fitJHresp = False
fitJPdh = False
fitJHF1F0 = False
fitJHUni = False

# Plots to display. Same logic: plotAll = True shows every plot.
plotAll = True
plotJHresp = False
plotJPdh = False
plotJHF1F0 = False
plotJHUni = False

fs = 24                # reference font size for labels, ticks and titles
plotPoints = npts      # rcount/ccount of the colored error surface
plotPointsGrid = 20    # rcount/ccount of the light wireframe overlay

# =================================================================
# Definition of our new variables (potentials, in mV) from the previous ones
def DEResp(NADH):
    ''' Map an NADH level to the variable Delta E_resp (plotted in mV).

        Returns log(kResp*sqrt(NADH/(nadTot-NADH)))/FRT, where kResp,
        nadTot and FRT are module-level constants brought in by the
        star imports. Only numpy ufuncs are used, so NADH may be a
        scalar or an array. The expression diverges at NADH = 0 and
        NADH = nadTot.
    '''
    nadRest = nadTot - NADH   # presumably the oxidized (NAD+) pool -- TODO confirm
    return np.log(kResp * np.sqrt(NADH / nadRest)) / FRT
def DGp(atpM):
    ''' Map an ATP level atpM to the variable Delta G_p (plotted in mV).

        Returns log(kF1/piM*atpM/(0.8*aTot-atpM))/FRT, where kF1, piM,
        aTot and FRT are module-level constants brought in by the star
        imports. Works on scalars or numpy arrays. The expression
        diverges at atpM = 0 and atpM = 0.8*aTot.
    '''
    complement = 0.8*aTot - atpM   # presumably ADP in the matrix -- TODO confirm
    return np.log(kF1 / piM * atpM / complement) / FRT

# Sampling intervals for the state variables (npts points each).
# der and dgp are obtained by mapping concentration ranges through
# DEResp / DGp; the bounds stay away from the points where those
# expressions diverge.
nadhLo, nadhHi = 1.e-2, nadTot - 1.e-2   # NADH bounds (nmol/mgProt)
der = np.linspace(DEResp(nadhLo), DEResp(nadhHi), npts)   # Delta E_resp (mV)
dpsi = np.linspace(100, 200, npts)                         # mV
cam = np.linspace(0, 0.5, npts)          # unit unclear; plot label says muM -- TODO confirm
cac = np.linspace(0.0001, 4.e-3, npts)                     # nmol/mgProt
atpLo, atpHi = 1.e-2, 0.8*aTot - 1.e-6   # ATP bounds (unit not stated here)
dgp = np.linspace(DGp(atpLo), DGp(atpHi), npts)           # Delta G_p (mV)

# =================================================================
def plotFitError(x,y,toFit,fitted,xlabel,ylabel,zlabel,
                 cameraPos=None,xticks=None,yticks=None,zticks=None):

    ''' Plot the fitted surface, colored by the pointwise fit error.

        x, y            : 2D grids of the two input variables (as built
                          with np.meshgrid in this script)
        toFit, fitted   : reference and fitted values on that grid
        xlabel, ylabel, zlabel : axis labels (LaTeX strings)
        cameraPos       : (elev,azim), controls the 3D view
        xticks, yticks, zticks : optional custom tick positions

        The error shown is fitted - toFit, mapped onto the diverging 'bwr'
        colormap centered at 0. Uses the module-level settings fs,
        plotPoints and plotPointsGrid.
    '''

    fig = plt.figure(figsize=(10,6))
    ax  = fig.add_subplot(111,projection='3d',proj_type='ortho')
    if cameraPos is not None:
        ax.view_init(elev=cameraPos[0],azim=cameraPos[1])

    # Error is not normalized (it could be, by the max of toFit)
    err    = fitted - toFit
    maxErr = np.max(np.abs(err))
    if maxErr == 0:
        # Exact fit: with vmin == vmax, Normalize maps everything to 0
        # (the blue end). Widen the range so that zero error stays white.
        maxErr = 1.
    # Center the colorbar of the error at 0
    norm   = colors.Normalize(vmin=-maxErr,vmax=maxErr)
    # Map the error to [0,1] row by row. Iterating over err itself
    # (instead of the global npts) keeps this valid for any grid size.
    err2   = np.zeros(err.shape)
    for i, row in enumerate(err):
        err2[i] = norm(row)

    # Error-colored surface, light wireframe, then a thick wireframe with
    # rcount=ccount=1 (presumably only the border lines)
    ax.plot_surface(x,y,fitted,shade=False,ccount=plotPoints,rcount=plotPoints,
                    facecolors=cm.bwr(err2))
    ax.plot_wireframe(x,y,fitted,linewidth=0.2,color='#777777',
                      ccount=plotPointsGrid,rcount=plotPointsGrid)
    ax.plot_wireframe(x,y,fitted,linewidth=1,color='k',ccount=1,rcount=1)

    # Stand-alone mappable spanning [-maxErr, maxErr] for the colorbar
    m  = cm.ScalarMappable(cmap=cm.bwr)
    m.set_array(np.linspace(-maxErr,maxErr,100))
    # m is not attached to any Axes, so tell colorbar explicitly which
    # Axes to take space from (required by recent Matplotlib versions)
    cb = plt.colorbar(m,ax=ax,shrink=0.8)
    cb.ax.set_title(r'Fit error',fontsize=fs)
    ax.set_xlabel(xlabel,fontsize=fs,labelpad=fs)
    ax.set_ylabel(ylabel,fontsize=fs,labelpad=fs)
    ax.set_zlabel(zlabel,fontsize=fs,labelpad=fs/2.)
    # Ticks
    ax.tick_params(axis='both',which='major',labelsize=fs)
    cb.ax.tick_params(axis='both',which='major',labelsize=fs)
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    if zticks is not None:
        ax.set_zticks(zticks)
    fig.tight_layout()

if any((plotAll, plotJHresp, plotJPdh, plotJHF1F0, plotJHUni)):
    # At least one plot will be drawn: discard any open figures and
    # switch Matplotlib to interactive mode
    plt.close('all')
    plt.ion()

# =================================================================
# JHresp

if fitAll or fitJHresp:

    print("\nFitting flux JHresp\n")

    # Reference MK flux sampled on the (Delta E_resp, Delta psi) grid
    dd, pp = np.meshgrid(der, dpsi, indexing='ij')
    toFit  = jHresMk(dd, pp)

    # curve_fit takes flattened inputs and data; p0 is the initial guess
    p0JHresp   = [3600, 0.11, 132, 0.018, 1120, 1060]
    params,cov = curve_fit(jHrespFit, (dd.ravel(), pp.ravel()), toFit.ravel(),
                           p0=p0JHresp)
    print("Estimated parameters:", params)

    # Relative L2 error of the fit over the whole grid
    fitted = jHrespFit((dd, pp), *params)
    nErr   = np.sqrt(np.sum((toFit - fitted)**2))
    nSol   = np.sqrt(np.sum(toFit**2))
    print("Fitting error, L2: ", nErr/nSol)

    if plotAll or plotJHresp:
        plotFitError(dd, pp, toFit, fitted,
                     r'$\Delta E_{\mathrm{resp}}$ (mV)',
                     r'$\Delta \psi$ (mV)',
                     r'$J_{\mathrm{H,resp}}$',
                     cameraPos=(30., 135.),
                     xticks=[1050, 1100, 1150, 1200],
                     yticks=[100, 120, 140, 160, 180, 200])

# ==================================================================
# JPDH

if fitAll or fitJPdh:

    print("\nFitting flux JPDH\n")

    # Reference flux (jPdhBertram) sampled on the (Delta E_resp, cam) grid
    dd, cc = np.meshgrid(der, cam, indexing='ij')
    toFit  = jPdhBertram(dd, cc)

    # curve_fit takes flattened inputs and data; p0 is the initial guess
    p0JPdh     = [35, 0.4, 18.8, 1]
    params,cov = curve_fit(jPdhFit0, (dd.ravel(), cc.ravel()), toFit.ravel(),
                           p0=p0JPdh)
    print("Estimated parameters:", params)

    # Relative L2 error of the fit over the whole grid
    fitted = jPdhFit0((dd, cc), *params)
    nErr   = np.sqrt(np.sum((toFit - fitted)**2))
    nSol   = np.sqrt(np.sum(toFit**2))
    print("Fitting error, L2: ", nErr/nSol)

    if plotAll or plotJPdh:
        plotFitError(dd, cc, toFit, fitted,
                     r'$\Delta E_{\mathrm{resp}}$ (mV)',
                     r'$[\mathrm{Ca}^{2+}]_\mathrm{m}$ ($\mu$M)',
                     r'$J_{\mathrm{PDH}}$',
                     cameraPos=(30., -45))

# =================================================================
# JHF1F0

if fitAll or fitJHF1F0:

    print("\nFitting flux JHF1F0\n")

    # Reference MK flux sampled on the (Delta G_p, dpsi) grid
    dd, pp = np.meshgrid(dgp, dpsi, indexing='ij')
    toFit  = jHF1F0Mk(dd, pp)

    # curve_fit takes flattened inputs and data; p0 is the initial guess
    p0JHF1F0   = [2400, 0.01, 250, 0.05, 175]
    params,cov = curve_fit(jHF1F0Fit, (dd.ravel(), pp.ravel()), toFit.ravel(),
                           p0=p0JHF1F0)
    print("Estimated parameters:", params)

    # Relative L2 error of the fit over the whole grid
    fitted = jHF1F0Fit((dd, pp), *params)
    nErr   = np.sqrt(np.sum((toFit - fitted)**2))
    nSol   = np.sqrt(np.sum(toFit**2))
    print("Fitting error, L2: ", nErr/nSol)

    if plotAll or plotJHF1F0:
        # NOTE(review): the y axis samples dpsi (labeled Delta psi in the
        # JHresp plot) but is labeled Delta p here -- confirm the intended symbol
        plotFitError(dd, pp, toFit, fitted,
                     r'$\Delta G_{\mathrm{p}}$ (mV)',
                     r'$\Delta p$ (mV)',
                     r'$J_{\mathrm{F1F0}}$',
                     cameraPos=(30., -45))


# =================================================================
# JHAnt
# For JHAnt, nothing needs to be done, as the MK expression is already quite simple

# =================================================================
# JHuni

if fitAll or fitJHUni:

    # Start with a blank line, like the other flux headers, so this
    # header does not run into the previous section's output
    print("\nFitting flux JHUni\n")

    # Reference MK flux sampled on the ([Ca]_c, dpsi) grid
    cc, pp = np.meshgrid(cac, dpsi, indexing='ij')
    toFit  = jHUniMk(cc, pp)

    # curve_fit takes flattened inputs and data; p0 is the initial guess
    p0JHUni    = [1000, -1.e5, -0.4, 30]
    params,cov = curve_fit(jHUniFit, (cc.ravel(), pp.ravel()), toFit.ravel(),
                           p0=p0JHUni)
    print("Estimated parameters:", params)

    # Relative L2 error of the fit over the whole grid
    fitted = jHUniFit((cc, pp), *params)
    nSol   = np.sqrt(np.sum(toFit**2))
    nErr   = np.sqrt(np.sum((toFit - fitted)**2))
    print("Fitting error, L2: ", nErr/nSol)

    if plotAll or plotJHUni:
        # NOTE(review): the y axis samples dpsi (labeled Delta psi in the
        # JHresp plot) but is labeled Delta p here -- confirm the intended symbol
        plotFitError(cc, pp, toFit, fitted,
                     r'$[\mathrm{Ca}]_\mathrm{c}$ (nmol/mg prot)',
                     r'$\Delta p$ (mV)',
                     r'$J_{\mathrm{uni}}$',
                     cameraPos=(30., 225))

