# =================================================================
# TCLD2019.py
# Classes for mitochondrial model of Tarraf et al. 2019.
# There are two versions of the model: one in which all parameters
# can be changed when creating the object from code, and one in which
# only the seven parameters highlighted by the Sobol analysis can be
# changed.
# =================================================================

import numpy as np
from ..misc        import *
from .mitoODEModels import *

class TCLD_2019_all_param(mitoODEModel):
  '''
      Mitochondrial ODE model of Tarraf et al. 2019 with all parameters exposed.

      The state vector holds, in this order (see the i_* indices):
      Delta E resp, Delta psi, ADP mito, ADP cyto, calcium mito, calcium cyto.
      Parameters can be passed at construction through the vector p
      (32 values); otherwise p0..p26 and ADP_add_dynamic are read from the
      configuration and C_m, f, f_i and gamma take fixed default values.
  '''

  def __init__(self,config,p=None):
    '''
        Read model parameters from configuration file (or from p).
        Define state variable size and indices.

        config : configuration object, called as config(key, default_string);
                 its result is passed to evalFloat or split(), and
                 writeOutputHeader also reads config.fileName_.
        p      : optional sequence of 32 values, unpacked as
                 p0..p26, C_m, f, f_i, gamma, ADP_add_dynamic.
    '''

    self.config  = config
    self.nvars   = 6
    self.dtype   = float

    self.i_DEresp = 0
    self.i_Dp     = 1
    self.i_ADPm   = 2
    self.i_ADPc   = 3
    self.i_Cam    = 4
    self.i_Cac    = 5
    # -----------------------------------------------------
    # Model INPUTS
    # Amounts and times of the ADP additions (space-separated lists).
    # evolStateVar uses at most the first two of each; missing ones are
    # treated as "no addition".

    self.ADP_add_targets  = [evalFloat(s) for s in config("ADP_add_targets","1" ).split()]
    self.ADP_add_times    = [evalFloat(s) for s in config("ADP_add_times"  ,"10").split()]

    self.nparams = 32

    # If the model is initialized with a specific set of parameters
    if p is not None:
      # The last attribute must be ADP_add_dynamic: that is the name
      # evolStateVar reads.
      # NOTE(review): the module-level SALibTcldFull dictionary ends with
      # "ADP dyn." followed by gamma, whereas p is unpacked here as gamma
      # followed by the ADP dynamic -- confirm the order callers build p in.
      (self.p0 ,self.p1 ,self.p2 ,self.p3 ,self.p4 ,self.p5 ,self.p6 ,self.p7 ,self.p8 ,self.p9 ,
       self.p10,self.p11,self.p12,self.p13,self.p14,self.p15,self.p16,self.p17,self.p18,self.p19,
       self.p20,self.p21,self.p22,self.p23,self.p24,self.p25,self.p26,
       self.C_m,self.f,self.f_i,self.gamma,self.ADP_add_dynamic) = p

    # Otherwise, read it from parameter file, or use default values
    else:
      self.p0    = evalFloat(config("p0" ,"3.22224816e+02  "))
      self.p1    = evalFloat(config("p1" ,"1.09381423e-01  "))
      self.p2    = evalFloat(config("p2" ,"1.33029756e+02 "))
      self.p3    = evalFloat(config("p3" ,"1.50748254e-02 "))
      self.p4    = evalFloat(config("p4" ,"1.11938086e+03 "))
      self.p5    = evalFloat(config("p5" ,"1.08287599e+03 "))
      self.p6    = evalFloat(config("p6" ,"35.30339213    "))
      self.p7    = evalFloat(config("p7" ,"0.40381467     "))
      self.p8    = evalFloat(config("p8" ,"2.3e4 "))
      self.p9    = evalFloat(config("p9" ,"1"))
      self.p10   = evalFloat(config("p10" ,"9.79135446e+02"))
      self.p11   = evalFloat(config("p11","1.17193389e-02 "))
      self.p12   = evalFloat(config("p12","2.58422097e+02 "))
      self.p13   = evalFloat(config("p13","-5.59766429e-02"))
      self.p14   = evalFloat(config("p14","1.76167939e+02 "))
      self.p15   = evalFloat(config("p15" ,"-3.61317806e-05 "))
      self.p16   = evalFloat(config("p16","1000"))
      self.p17   = evalFloat(config("p17","0.8"))
      self.p18   = evalFloat(config("p18","615600 "))
      self.p19   = evalFloat(config("p19","0.5  "))
      self.p20   = evalFloat(config("p20" ,"0.00274"))
      self.p21   = evalFloat(config("p21","0.05245"))
      self.p22   = evalFloat(config("p22","22.764"))
      self.p23   = evalFloat(config("p23","0.003 "))
      self.p24   = evalFloat(config("p24","1.006"))
      self.p25   = evalFloat(config("p25","8.4e1"))
      self.p26   = evalFloat(config("p26","0.625  "))
      self.C_m   = 1.45e-3       # nmol/mg prot . (mV)^-1  (MK1997)
      self.f     = 3e-4          # fraction of free matrix Ca unitless(MK1997)
      self.f_i   = 1e-2          # fraction of free cyto Ca unitless  (MK1998a)
      self.gamma = 1.25*3.9/53.2 # From nmol/mg prot to mM cyto volume
      self.ADP_add_dynamic = evalFloat(config("ADP_add_dynamic","0.05"))

    # --------------------------------------------------
    # Fixed parameters

    R = 8.315          # Unit is J mol**(-1) K**(-1)
    T = 310.16         # Unit is Kelvin
    F = 96.480         # Faraday constant in kC mol**(-1)
    self.FRT = F/(R*T) # Unit is kC/J = mV**(-1)

    self.nadTot = 8       # 8 nmol/mg-protein see MK97 page C720
    self.aTot   = 12      # Unit in nmol/ mg prot
    self.kResp  = 1.35e18 # Unitless (MK97)
    self.kF1    = 1.71e6  # Unit in mM (??)
    self.piM    = 20      # Unit in mM (Cortassa 2003)
    self.piC    = 1.46    # Unit in mM of cyto volume (Invented as Pim)

    self.a_tot_cyto = 2                   # mM in cyto volume  (Mk98a)
    self.kF1_cyto   = self.kF1*3.9/53.2 # (Invented)3.9/53.2 is the mito/cyto ratio MK97 in mM cyto volume

  # ===============================================================
  # Variable definitions from concentrations

  def DEResp(self,NADH):
    '''Delta E resp corresponding to a matrix NADH level (nadTot is the NAD pool).'''
    return np.log(self.kResp*np.sqrt(NADH/(self.nadTot-NADH)))/self.FRT

  def DGpm(self,adpm):
    '''Matrix phosphorylation potential from matrix ADP (inverse of adp_m).'''
    return np.log((self.kF1/self.piM)*(self.aTot-adpm)/adpm)/self.FRT

  def DGpc(self,adpC):
    '''Cytosolic phosphorylation potential from cytosolic ADP (inverse of adp_c).'''
    return np.log((self.kF1_cyto*(self.a_tot_cyto-adpC))/(adpC*self.piC))/self.FRT

  def adp_c(self,Dgpc):
    '''Cytosolic ADP from the cytosolic phosphorylation potential.'''
    return (self.kF1_cyto*self.a_tot_cyto)/(self.piC*np.exp(self.FRT*Dgpc)+self.kF1_cyto)

  def adp_m(self,Dgpm):
    '''Matrix ADP from the matrix phosphorylation potential.'''
    return (self.kF1*self.aTot)/(self.piM*np.exp(self.FRT*Dgpm)+self.kF1)

  # ===============================================================

  def initializeStateVar(self,solver,data=None):
    '''
        Read from config or given through variable data.

        Returns an array of shape (solver.nstencil, nvars); only row 0 is
        filled from the configuration. If data is given it is returned as is.
    '''

    Y0 = np.zeros((solver.nstencil,self.nvars),dtype=self.dtype)

    if data is None:
      NADH0 = evalFloat(self.config("nadh_init","0.48"))
      ADPM0 = evalFloat(self.config("adpm_init","1e-5"))
      ADPC0 = evalFloat(self.config("adpc_init","1e-5"))

      # Delta E resp is the state variable, so convert the NADH level
      Y0[0,self.i_DEresp] = self.DEResp(NADH0)
      Y0[0,self.i_Dp    ] = evalFloat(self.config("dp_init" ,"153"))
      Y0[0,self.i_ADPm  ] = ADPM0
      Y0[0,self.i_ADPc  ] = ADPC0
      Y0[0,self.i_Cam   ] = evalFloat(self.config("cam_init","8.e-5"))
      Y0[0,self.i_Cac   ] = evalFloat(self.config("cac_init","0"))

    else:
      Y0 = data

    if (solver.nstencil!=1):
      # do something for multi steps methods
      pass

    return Y0

  # ===============================================================
  # Differential equations (see section 2.2.1)

  def _adpAdditionSchedule(self):
    '''
        Return (ADP1, ADP2, tADP1, tADP2) for J_ADP_source1.

        Only the first two configured additions are used. A missing amount is
        padded with 0 and a missing time with +inf, so an absent addition
        never contributes (with a single addition configured, as by default,
        only that one is applied).
    '''
    targets = list(self.ADP_add_targets[:2]) + [0.]*2
    times   = list(self.ADP_add_times[:2])   + [np.inf]*2
    return targets[0], targets[1], times[0], times[1]

  def evolStateVar(self,t,y):
    '''
        Evolution function y'=f(y,t) (Eq 2.18)

        t : current time
        y : state vector ordered as the i_* indices
        Returns dy/dt as an array of length y.shape[-1].
    '''

    res = np.zeros((y.shape[-1],),dtype=self.dtype)

    DE   = y[self.i_DEresp]
    Dp   = y[self.i_Dp    ]
    ADPm = y[self.i_ADPm  ]
    ADPc = y[self.i_ADPc  ]
    Cam  = y[self.i_Cam   ]
    Cac  = y[self.i_Cac   ]

    DGm = self.DGpm(ADPm)
    DGc = self.DGpc(ADPc)

    JHres = self.jHrespFit((DE,Dp),self.p0,self.p1,self.p2,self.p3,self.p4,self.p5)
    JF1F0 = self.jHF1F0Fit((DGm,Dp),self.p10,self.p11,self.p12,self.p13,self.p14,self.p15)
    JANT  = self.jANT((DGm,DGc,Dp),self.p16,self.p17,self.p18,self.p19)
    Juni  = self.jHUniFit((Cac,Dp),self.p24,self.p25,self.p26)
    Jnaca = self.JNaCa(Cam,self.p22,self.p23)
    Jpdh  = self.jPdhFit0((DE,Cam),self.p6,self.p7,self.p8,self.p9)
    Jleak = self.jleak(Dp,self.p20,self.p21)

    # Padding keeps this valid when fewer than two additions are configured
    ADP1,ADP2,tADP1,tADP2 = self._adpAdditionSchedule()
    Jsource_ADP = self.J_ADP_source1(t,ADP1,ADP2,tADP1,tADP2,self.ADP_add_dynamic)

    # The phi2 and phi3 functions were finally dropped, and we solve in ADP(m,c),
    # since ADP concentrations can be 0 in the cyto at the beginning of some simulations
    res[self.i_DEresp] = self.phi1(DE) * (Jpdh - JHres)
    res[self.i_Dp    ] = (12 * JHres - 3 * JF1F0 - Jleak - JANT - 2 * Juni) / self.C_m
    res[self.i_ADPm  ] = (-JF1F0+JANT)
    res[self.i_ADPc  ] = (self.gamma * (-JANT)  + Jsource_ADP)
    res[self.i_Cam   ] = self.f * (Juni - Jnaca )
    res[self.i_Cac   ] = - self.f_i * self.gamma * 10**3 * (Juni - Jnaca)

    return res

  # ===============================================================
  # Derivatives of variables changing functions

  def phi1(self,Deres):
    '''Factor converting d(NADH)/dt into d(Delta E resp)/dt.'''
    alpha = (np.exp(self.FRT*Deres)/self.kResp)**2
    return 0.5/self.FRT*(1.+alpha)**2/(alpha*self.nadTot)

  def phi2(self,Dgpm):
    '''Change-of-variable factor for the matrix phosphorylation potential (unused).'''
    beta = np.exp(self.FRT*Dgpm)*self.piM/self.kF1
    return -(1+beta)**2/(self.FRT*self.aTot*beta)

  def phi3(self,Dgpc):
    '''Change-of-variable factor for the cytosolic phosphorylation potential (unused).'''
    beta = np.exp(self.FRT*Dgpc)*self.piC/self.kF1_cyto
    return -(1+beta)**2/(self.FRT*self.a_tot_cyto*beta)

  # ===============================================================
  # Fluxes (fitted versions, see section 2.2.2)

  # The order of arguments was different than in the misc.py file, so I keep these functions here
  def Sigmoid(self,x,vmin,vmax,k,thr):
    '''tanh sigmoid going from vmin to vmax, with slope k and threshold thr.'''
    return vmin+(vmax-vmin)*0.5*(1+np.tanh(k*(x-thr)))

  def jHrespFit(self,x,p0,p1,p2,p3,p4,p5):
    '''Fitted respiration flux; x = (Delta E resp, Delta psi).'''
    deltaEResp,deltaP = x
    return self.Sigmoid(deltaEResp,0,self.Sigmoid(deltaP,p0,0,p1,p2),p3,self.Sigmoid(deltaP,p4,p5,p1,p2))

  def jPdhFit0(self,x,p6,p7,p8,p9):
    '''Fitted PDH flux; x = (Delta E resp, matrix calcium).'''
    deltaEResp,cam = x
    return p6*(1-p7*np.exp(-p8*cam)) / ((np.exp(self.FRT*deltaEResp-np.log(self.kResp)))**2+p9)

  def jHF1F0Fit(self,x,p10,p11,p12,p13,p14,p15):
    '''Fitted F1F0 ATP synthase flux; x = (matrix Delta Gp, Delta psi).'''
    deltaGp,deltaP = x
    return self.Sigmoid(deltaGp,p10,p15,p11,p12)*self.Sigmoid(deltaP,1,p15,p13,p14)

  def jANT(self,x,p16,p17,p18,p19):
    '''ANT exchange flux; x = (matrix Delta Gp, cyto Delta Gp, Delta psi).'''
    Dgpm,Dgpc,Dpsi = x
    E = np.exp(self.FRT*Dgpm)
    return p16*(E-p17*np.exp(self.FRT*(Dgpc-Dpsi)))/(1+(p17/p18)*np.exp(self.FRT*(Dgpc-p19*Dpsi)))/(E+p18)

  def jHUniFit(self,x,p24,p25,p26):
    '''Fitted uniporter flux, clipped at 0; x = (cyto calcium, Delta psi).'''
    caC,deltaP = x
    return np.maximum(p24*(deltaP-p25)*(caC-p26),0.)

  def jleak(self,x,p20,p21):
    '''Proton leak flux as a function of Delta psi.'''
    deltaP = x
    return p20*np.exp(p21*deltaP)

  def JNaCa(self,x,p22,p23):
    '''Na/Ca exchanger flux (Michaelis-Menten form) as a function of matrix calcium.'''
    cam = x
    return p22*cam/(p23+cam)

  def J_ADP_source1(self,t,ADP1,ADP2,tADP1,tADP2,tau):
    '''
        External ADP source term for two ADP additions.

        An addition of amount ADPi at time tADPi contributes
        ADPi*exp(-(t-tADPi)/tau)/tau once active: the first for t >= tADP1,
        the second for t > tADP2 (and t >= tADP1).

        t may be a scalar (Python or NumPy) or an array-like of times; for
        array-like input a float array of the same shape is returned.
    '''
    def sourceAt(s):
      if s<tADP1:
        return 0
      elif s<=tADP2:
        return ADP1*(np.exp(-(s-tADP1)/tau))/tau
      else:
        return (ADP1*(np.exp(-(s-tADP1)/tau)) + ADP2*(np.exp(-(s-tADP2)/tau)))/tau

    # np.ndim is 0 for Python numbers AND NumPy scalars / 0-d arrays (e.g. an
    # np.float64 time from a solver), which cannot be indexed like an array.
    if np.ndim(t)==0:
      return sourceAt(t)

    res = np.zeros(np.shape(t))
    for i, ti in enumerate(np.ravel(t)):
      res.flat[i] = sourceAt(ti)
    return res

  # ===============================================================
  # Output stuff

  def writeOutputHeader(self,outFile):
    '''
        Prints info in the already opened output file
    '''

    outFile.write("# Output of model of TCLD 2019 (corrected) with input parameters from %s\n" % self.config.fileName_)
    outFile.write("# Columns are\n")
    outFile.write("# 0. Time (s).\n")
    outFile.write("# 1. Delta E resp (mV)\n")
    outFile.write("# 2. Delta psi (mV)\n")
    outFile.write("# 3. ADP mito (XXX FIXME unit XXX)\n")
    outFile.write("# 4. ADP cyto (XXX FIXME unit XXX)\n")
    outFile.write("# 5. Calcium mito (XXX FIXME unit XXX)\n")
    outFile.write("# 6. Calcium cyto (XXX FIXME unit XXX)\n")
    outFile.write("#\n")

  def getOutputStateVar(self,t,y):
    '''
         Returns line to write in ascii output file: the time followed by the
         nvars state variables, each followed by a space, then a newline.
    '''

    values = "".join("{:1.12e} ".format(y[i]) for i in range(self.nvars))
    return "%1.12f "%(t) + values + "\n"

# =================================================================

class TCLD_2019_7_param(TCLD_2019_all_param):
  '''
      Variant of TCLD_2019_all_param where only the seven parameters found
      influential by the Sobol analysis (p2, p5, p11, p12, p14, p21, gamma)
      can be supplied; every other parameter keeps its base-class value.
  '''

  def __init__(self,config,p=None):

    # Build the full model first, with every parameter at its configured or
    # default value. p is deliberately NOT forwarded: here it only carries the
    # seven influential parameters, not the 32 the base class expects.
    super(TCLD_2019_7_param,self).__init__(config,p=None)

    if p is None:
      # (config key / attribute name, default string), read in this order.
      # NOTE(review): the p11 default here (1.16496340e-02) differs from the
      # base-class default and from the module-level means7Param value
      # (1.17193389e-02) -- confirm which one is intended.
      overrides = (("p2" ,"1.33029756e+02 "),
                   ("p5" ,"1.08287599e+03 "),
                   ("p11","1.16496340e-02"),
                   ("p12","2.58422097e+02 "),
                   ("p14","1.76167939e+02   "),
                   ("p21","0.05245  "))
      for key, default in overrides:
        setattr(self, key, evalFloat(config(key, default)))
      self.gamma = 1.25*3.9/53.2      # From nmol/mg prot to mM cyto volume
    else:
      # Same order as the SALibTcld7Params names: p2, p5, p11, p12, p14, p21, gamma
      (self.p2, self.p5, self.p11, self.p12,
       self.p14, self.p21, self.gamma) = p

# =================================================================
# Models written in the SALib formalism

def swapNegatives(l):
    '''
        Reorder each [lower, upper] pair of l in place so that lower <= upper.

        Bounds built as [0.9*mean, 1.1*mean] come out reversed when the mean
        is negative, while SALib expects lower < upper. The test is on the
        ordering of the pair (not just on the sign of its first entry), so an
        already ordered pair such as [-2, 1] is left untouched and a reversed
        non-negative pair is fixed too.

        l : list of mutable two-element sequences, modified in place.
        Returns None.
    '''
    for b in l:
        if b[0] > b[1]:
            b[0], b[1] = b[1], b[0]

# Complete model with all parameters.
# The last parameter, gamma, has no default value from the literature,
# so it is sampled over a wider, explicitly given interval instead of
# the +/-10% interval used for every other parameter.
# NOTE(review): the default gamma of TCLD_2019_all_param (1.25*3.9/53.2)
# lies outside that interval, and this ordering ends with "ADP dyn." then
# gamma while the TCLD_2019_all_param constructor unpacks p as gamma then
# ADP dynamic -- confirm both.
meansFull = np.asarray([
    3.22224816e+02, 1.09381423e-01, 1.33029756e+02, 1.50748254e-02,   # p0  .. p3
    1.11938086e+03, 1.08287599e+03, 35.303392     , 0.40381467    ,   # p4  .. p7
    2.3e4         , 1.            , 9.79135446e+02, 1.17193389e-02,   # p8  .. p11
    2.58422097e+02,-5.59766429e-02, 1.76167939e+02,-3.61317806e-05,   # p12 .. p15
    1000          , 0.8           , 615600        , 0.5           ,   # p16 .. p19
    0.00274       , 0.05245       , 22.764        , 0.003         ,   # p20 .. p23
    1.006         , 8.4e1         , 0.625         , 1.45e-3       ,   # p24 .. p26, C_m
    3e-4          , 1e-2          , 0.05          ])                  # f, f_i, ADP dynamic

# +/-10% sampling interval around each mean, as [lower, upper] pairs
boundsFull = [[0.9*m, 1.1*m] for m in meansFull.tolist()]
# Negative means (p13, p15) give reversed pairs; fix them in place
swapNegatives(boundsFull)

# SALib problem dictionary (gamma interval appended last)
SALibTcldFull = {
    'num_vars': 32,
    'bounds'  : boundsFull + [[0.14*3.9/53.2,0.35*3.9/53.2]],
    'names'   : [r"$p_0$"   ,r"$p_1$"   ,r"$p_2$"   ,r"$p_3$"   ,r"$p_4$"   ,
                 r"$p_5$"   ,r"$p_6$"   ,r"$p_7$"   ,r"$p_8$"   ,r"$p_9$"   ,
                 r"$p_{10}$",r"$p_{11}$",r"$p_{12}$",r"$p_{13}$",r"$p_{14}$",
                 r"$p_{15}$",r"$p_{16}$",r"$p_{17}$",r"$p_{18}$",r"$p_{19}$",
                 r"$p_{20}$",r"$p_{21}$",r"$p_{22}$",r"$p_{23}$",r"$p_{24}$",
                 r"$p_{25}$",r"$p_{26}$",
                 r"$C_{\mathrm{m}}$",r"$f$",r"$f_i$",r"ADP dyn.",r"$\gamma$"],
}

# -----------------------------------------------------
# Model with some parameters excluded from flux analysis:
# p0, p1, p3, p8, p10, p13, p15, p20 and p24 are dropped (9 of 32).
meansReduced = np.asarray([
    1.33029756e+02, 1.11938086e+03, 1.08287599e+03, 35.30339213   ,   # p2,  p4,  p5,  p6
    0.40381467    , 1.            , 1.17193389e-02, 2.58422097e+02,   # p7,  p9,  p11, p12
    1.76167939e+02, 1000          , 0.8           , 615600        ,   # p14, p16, p17, p18
    0.5           , 0.05245       , 22.764        , 0.003         ,   # p19, p21, p22, p23
    8.4e1         , 0.625         , 1.45e-3       , 3e-4          ,   # p25, p26, C_m, f
    1e-2          , 0.05])                                            # f_i, ADP dynamic

# +/-10% interval around each mean (all means are positive, no reordering needed)
boundsReduced = [[0.9*m, 1.1*m] for m in meansReduced.tolist()]

TcldReduced = {
    'num_vars': 32-9,
    'bounds'  : boundsReduced + [[0.14*3.9/53.2,0.35*3.9/53.2]],
    'names'   : [r"$p_2$"   ,r"$p_4$"   ,r"$p_5$"   ,r"$p_6$"   ,
                 r"$p_7$"   ,r"$p_9$"   ,r"$p_{11}$",r"$p_{12}$",
                 r"$p_{14}$",r"$p_{16}$",r"$p_{17}$",r"$p_{18}$",
                 r"$p_{19}$",r"$p_{21}$",r"$p_{22}$",r"$p_{23}$",
                 r"$p_{25}$",r"$p_{26}$",r"$C_{\mathrm{m}}$",
                 r"$f$",r"$f_i$",r"ADP dyn.",r"$\gamma$"],
}

# -----------------------------------------------------
# Model with seven parameters only (the Sobol-influential ones):
# p2, p5, p11, p12, p14, p21, plus gamma appended with its own interval.
means7Param = np.asarray([
    1.33029756e+02, 1.08287599e+03, 1.17193389e-02,   # p2,  p5,  p11
    2.58422097e+02, 1.76167939e+02, 0.05245])         # p12, p14, p21

# +/-10% interval around each mean
bounds7Param = [[0.9*m, 1.1*m] for m in means7Param.tolist()]

SALibTcld7Params = {
    'num_vars': 7,
    'bounds'  : bounds7Param + [[0.14*3.9/53.2,0.35*3.9/53.2]],
    'names'   : [r"$p_2$",r"$p_5$",r"$p_{11}$",r"$p_{12}$",
                 r"$p_{14}$",r"$p_{21}$",r"$\gamma$"],
}

