#!/usr/bin/python3
# -----------------------------------------------------------------
# costFunction.py
#
# Authors: B.Tarraf, M.Leguebe
#
# Evaluation of the distance between an experimental data set
# and a run of the ODE solver with our model
#
# -----------------------------------------------------------------

import os
import sys
import copy
import shutil
import numpy as np
import time
import SALib

from scipy.optimize import curve_fit

from .misc        import *
from .ODESolvers  import *
from .experiments import *

# ----------------------------------------
# debug: make numpy raise FloatingPointError instead of only warning on
# ANY floating-point error -- division by zero, overflow, underflow and
# invalid operations such as 0/0 (all='raise' covers all four).
# This runs as a side effect of importing this module and changes numpy's
# error-handling state for the importing thread, so other numpy code run
# afterwards is affected as well.
np.seterr(all='raise')

# =================================================================

def runSingleTCLD(p,manip):

    ''' Performs the simulation for a given set of parameters.
        Output is the respiration rate at times before and
        after additions of ADP.

        p     : iterable of numeric model parameters; passed to
                TCLD_2019_7_param and used to name the result file
        manip : experiment description; manip.paramFile,
                manip.startTime and manip.endTime are read here

        Output: (finished, newTimes, J)
          finished : True when the run and its post-processing succeeded
          newTimes : times from manip.startTime to manip.endTime
                     (inclusive) with a step of 1e-5
          J        : flux returned by model.jHrespFit, linearly
                     interpolated on newTimes
          If the run or the post-processing fails, (False, 0, 0) is
          returned instead. Errors raised while reading the config or
          building the solver/model are NOT caught here.
    '''

    config = myConfigParser(manip.paramFile)
    outDir = config('OutDir','./results')
    mkdir_p(outDir)

    solver   = getODESolver(config)
    model    = TCLD_2019_7_param(config,p)
    Y0       = model.initializeStateVar(solver)

    d = None
    try:
        # Name the result file after the parameter values so runs with
        # different parameters do not share a file. The "_" separator
        # prevents collisions between parameter sets whose digits would
        # otherwise concatenate to the same string
        # (e.g. [0.0001,0.0023] and [0.0012,0.0003] both gave "123").
        ID = "_".join("{:d}".format(int(param*1e4)) for param in p)
        d  = os.path.join(outDir, "res_" + ID + ".dat")

        # Run the simulation; the solver writes its output to file d
        solver.run(model,Y0,d,quiet=True)
        Y = np.loadtxt(d,comments='#')

        # Column 0 holds the time stamps; state columns are offset by one
        # (hence the +1 on the state indices below)
        simTimes = Y[:,0]
        newTimes = np.arange(manip.startTime,manip.endTime+1e-5,1e-5)
        iStart   = np.argmin(abs(simTimes-manip.startTime))
        J        = model.jHrespFit((Y[iStart:,model.i_DEresp+1],Y[iStart:,model.i_Dp+1]),
                                   model.p0,model.p1,model.p2,model.p3,model.p4,model.p5)
        # Interpolate the flux J to match experimental times
        J        = np.interp(newTimes,simTimes[iStart:],J)
        finished = True

    except Exception:
        # If the run fails for any ordinary reason, return dummy results.
        # Catching Exception (not a bare except) still lets
        # KeyboardInterrupt / SystemExit stop a long parameter sweep.
        J        = 0
        newTimes = 0
        finished = False

    finally:
        # Best-effort removal of the result file, on success AND failure,
        # so failed runs do not accumulate files in outDir. A failed
        # deletion must not turn a successful run into a failure.
        if d is not None:
            try:
                os.remove(d)
            except OSError:
                pass

    return finished,newTimes,J

def line(x,a,b):
    ''' Affine model y = a*x + b, used as the fit function for
        the linear regressions.

        x : abscissa (scalar or numpy array)
        a : slope
        b : intercept
    '''
    slopePart = a*x
    return slopePart + b

def intersection(a1,b1,a2,b2):
    ''' Abscissa at which the lines y = a1*x + b1 and y = a2*x + b2 cross.

        Parallel lines (a1 == a2) are not handled: the division by zero
        raises ZeroDivisionError for plain Python numbers, or
        FloatingPointError for numpy scalars when np.seterr(all='raise')
        is active.
    '''
    rise = b2 - b1
    run  = a1 - a2
    return rise/run

def normalize(a):
    ''' Min-max rescaling of the magnitudes of a onto [0, 1].

        a : non-empty array-like of numbers

        Output: (|a| - min|a|) / (max|a| - min|a|) as a float array of
                the same shape as a. When all magnitudes are equal the
                range is zero; an array of zeros is returned instead of
                dividing by zero (which would give NaN, or raise
                FloatingPointError under np.seterr(all='raise')).
    '''
    # Compute |a| once instead of four times
    mag  = np.abs(np.asarray(a))
    low  = np.min(mag)
    span = np.max(mag) - low
    if span == 0:
        # Degenerate (constant) signal: nothing to rescale
        return np.zeros(mag.shape)
    return (mag - low)/span

def manipExtractor(simTimes,manip):

    ''' Extracts average respiration rates from an experimental data set
        by linear fits over the intervals listed in manip.bounds.

        simTimes : simulation time stamps (numpy array)
        manip    : experiment description; manip.dataFile (loaded with
                   np.load) and manip.bounds (sequence of [start, end]
                   times) are read here

        Output: (rates, iBounds)
          rates   : one rate per interval, |2 * slope| of a linear fit
                    (per the original description: states sub, 3, 4
                    and 3 again)
          iBounds : for each interval, [i0, i1], the indices of the
                    simulation time stamps closest to its start and end.
                    Fits use the half-open slice [i0:i1], so sample i1
                    itself is excluded.
    '''

    # Value assigned to the highest measured level after rescaling
    # (nmol / mg prot according to the original note: "420uM O2 with
    # 0.4mgprot/mL mitos").
    # NOTE(review): 420/0.4 = 1050 while 525 = 210/0.4; confirm whether
    # "420" refers to O atoms or whether the constant should differ.
    maxO2 = 525

    expData = np.load(manip.dataFile)

    def closestIndex(t):
        # Index of the simulation time stamp nearest to t
        return np.argmin(abs(simTimes - t))

    iBounds = [[closestIndex(interval[0]), closestIndex(interval[1])]
               for interval in manip.bounds]

    # Row 0 of the data is used as the time axis (np.interp expects it
    # increasing) and row 1 as the measured signal (presumably O2).
    # The resampled signal is rescaled so that its min maps to 0 and
    # its max to maxO2.
    resampled = np.interp(simTimes, expData[0,:], expData[1,:])
    signal    = normalize(resampled)*maxO2

    rates = []
    for i0, i1 in iBounds:
        # Linear regression on the interval; only the slope is kept.
        # Presumably each interval needs at least two samples for curve_fit.
        coeffs, _ = curve_fit(line, simTimes[i0:i1], signal[i0:i1])
        slope = coeffs[0]
        # Factor 2 accounts for the O2 stoichiometry
        rates.append(abs(2*slope))

    return rates,iBounds

def costFunc(p,model,manip):

    ''' Distance between a simulation and an experimental data set.

        p     : parameters forwarded to model (they differ from the ones
                in the parameter file)
        model : callable model(p, manip) -> (ok, times, flux), such as
                runSingleTCLD
        manip : experiment description; the first manip.nStates fit
                intervals returned by manipExtractor are compared

        The distance is
            1/25 * sum_{k>=1} ( J_k/J_0 - Jexp_k/Jexp_0 )**2
          + sum_k  (share of samples in interval k with |dJ| < eps_k)
        where J_k is the mean simulated flux on interval k and Jexp_k
        the experimental rate fitted on the same interval. Only ratios
        to the first state are compared, so the absolute flux scale does
        not matter. The 1/25 weight was chosen by comparing the sizes of
        both terms over previous runs.

        Output: the distance, or -1 if the simulation failed.
        NOTE(review): valid distances are normally >= 0, so a minimizer
        that does not special-case -1 would prefer failed runs; confirm
        callers filter it.
        NOTE(review): the second term increases with the share of flat
        flux samples; confirm this sign is intended for a minimized cost.
        Only 4 tolerances exist, so manip.nStates > 4 raises IndexError;
        a zero mean flux in the first state divides by zero.
    '''

    ok,simTime,JoSim = model(p,manip)
    if not ok:
        return -1

    # Variation of the flux between consecutive samples
    slopes = np.diff(JoSim)
    # Per-state tolerance: a sample counts as "flat" when |dJ| < eps
    epsilons = [1.e-6,2.e-5,1.e-6,2.e-5]

    # Experimental rates and the index bounds of each fit interval
    expJs,bounds = manipExtractor(simTime,manip)

    valueTerms     = [0]   # the reference state 0 contributes nothing
    linearityTerms = []
    reference      = None
    for k in range(manip.nStates):
        lo, hi   = bounds[k][0], bounds[k][1]
        meanFlux = np.mean(JoSim[lo:hi])
        if k == 0:
            reference = meanFlux
        else:
            valueTerms.append(((meanFlux/reference)-(expJs[k]/expJs[0]))**2)

        nFlat = np.count_nonzero(np.abs(slopes[lo:hi]) < epsilons[k])
        linearityTerms.append(nFlat/(hi - lo))

    weight = 1./25.
    return weight*np.sum(np.asarray(valueTerms)) + np.sum(np.asarray(linearityTerms))

