# =================================================================
# ODESolvers_.py (M. Leguebe)
# =================================================================

import os
import sys
import copy
import shutil
import numpy as np
import time

# import personal routines and classes
sys.path.insert(0,'./')
from .misc          import *
from .mitoModels    import *
from .parallelTools import *

def getODESolver(config):
  '''
      Build the ODE solver selected by the "scheme" entry of config.

      config is called as config(key,default); the scheme name is
      case-insensitive and defaults to "Fwd_Euler".
      Available schemes: FWD_EULER, PREDCORR, RK4, RKF45.

      Raises ValueError (an Exception subclass) for an unknown scheme.
  '''

  solvers = {"FWD_EULER" : FwdEuler,
             "PREDCORR"  : PredCorr,
             "RK4"       : RK4,
             "RKF45"     : RKF45}

  solvertype = config("scheme","Fwd_Euler").upper()
  if solvertype not in solvers:
    raise ValueError("Unknown solver type '%s'. Available types are\n" % solvertype +\
                     "FWD_Euler, PredCorr, RK4, RKF45")

  return solvers[solvertype](config)

# =================================================================
# Generic class for common methods
# =================================================================

class ODESolverTypes:
  ''' Integer tags describing a solver: scheme kind and time-stepping mode. '''

  # scheme kind
  Explicit, Implicit = 0, 1

  # time-stepping mode
  FixedTS, VariableTS = 10, 11

class ODESolver(object):
  '''
      Base class holding the logic shared by the ODE solvers of this file.

      Child classes set self.type, self.timestepping and self.nstencil,
      call initTime(config) in their constructor and implement
      nextStep(y,model), which returns the state at self.t+self.dt from
      the stencil of previous states y.
  '''

  def initTime(self,config):
    '''
        Read the time-discretisation parameters from config.

        config is called as config(key,default) and its result is parsed
        with evalFloat/evalInt. Adaptive time-stepping parameters are only
        read when self.timestepping is VariableTS, so self.timestepping
        must be set before this method is called.
    '''

    self.dt       = evalFloat(config("dt","1"))
    self.t        = evalFloat(config("tmin","0"))
    self.tmin     = evalFloat(config("tmin","0"))
    self.tmax     = evalFloat(config("tmax","100"))
    self.tOutStep = evalInt  (config("tOutStep","1")) 

    if (self.timestepping == ODESolverTypes.VariableTS):
      self.tol      = evalFloat(config("TStol"     ,"0.01" )) # error tolerance 
      self.minTS    = evalFloat(config("TSmin"     ,"0.001")) # absolute minimum time step
      self.maxTS    = evalFloat(config("TSmax"     ,"50   ")) # absolute maximum time step
      self.ratioMin = evalFloat(config("TSratiomin","0.8"  )) # min ratio between consecutive time steps
      self.ratioMax = evalFloat(config("TSratiomax","10"   )) # max ratio between consecutive time steps
      # Normalise the ratios so that ratioMin <= 1 <= ratioMax
      if (self.ratioMin>self.ratioMax):
        self.ratioMax,self.ratioMin = self.ratioMin,self.ratioMax
      if (self.ratioMin>1):
        self.ratioMin = 1./self.ratioMin
      if (self.ratioMax<1):
        self.ratioMax = 1./self.ratioMax
      self.TSniterMax = 100   # maximum number of attempts for one adaptive step

  # =============================================================

  def getExplButcherEvol(self,y,model,butchA,butchB,butchC):
    ''' 
        Result from Runge Kutta Explicit step.
        This routine is written here as many Runge-Kutta methods
        can be used.
        Butcher tables are defined in child class (eg RK4, RKF45).

        Arguments
          y      : state at time self.t
          model  : mitoODEModel (its evolStateVar(t,y) is used) or a
                   callable f(t,y) returning dy/dt
          butchA : stage time fractions (one per stage)
          butchB : stage coefficients; only entries below the diagonal
                   are read
          butchC : output weights; only the first len(butchC) stages
                   are combined
        Returns
          sum_i butchC[i]*k_i, the increment such that
          y(t+dt) = y + self.dt*result
        Raises
          TypeError if model is neither a mitoODEModel nor callable
          (previously this silently produced a zero increment).
    '''

    if (isinstance(model,mitoODEModel)):
      rhs = model.evolStateVar
    elif (callable(model)):
      rhs = model
    else:
      raise TypeError("model must be a mitoODEModel or a callable f(t,y)")

    nstages = len(butchA)
    ks      = np.zeros([nstages]+list(y.shape),dtype=y.dtype)

    for ik in range(nstages):
      toAddState = np.zeros(y.shape,dtype=y.dtype)
      for ikk in range(ik):
        toAddState = toAddState + ks[ikk]*butchB[ik,ikk]
      ks[ik] = rhs(self.t+self.dt*butchA[ik],y+self.dt*toAddState)

    res = np.zeros(y.shape,dtype=y.dtype)
    for ik in range(len(butchC)):
      res = res + ks[ik]*butchC[ik]

    return res

  # =============================================================

  def evaluateNewTimeStep(self,err):
    '''
        Adaptive time stepping classes only.

        err is the error estimate of the step just computed with self.dt.
        The step is accepted when err <= self.tol, the same criterion the
        RKF45 retry loop uses to stop. Then self.nextdt is set to the step
        size to use next and self.dt is left untouched.
        Otherwise the step is rejected: self.dt is shrunk for a retry and
        self.nextdt is NOT modified.
        Proposed step sizes are clamped to [self.minTS,self.maxTS].
    '''

    if (abs(err)<1.e-20):
      # negligible error: grow the step as much as allowed
      self.nextdt = self.dt*0.9*self.ratioMax
      self.nextdt = min(self.maxTS,max(self.minTS,self.nextdt))
    elif (err <= self.tol):
      self.nextdt = self.dt*0.9*min(self.ratioMax,self.tol/err)
      self.nextdt = min(self.maxTS,max(self.minTS,self.nextdt))
    else:
      # We update dt as the solution was discarded
      self.dt = self.dt*0.9*max(self.tol/err,self.ratioMin) 
      self.dt = max(self.minTS,self.dt)

  def run(self,model,Y0,outputFile=None,outputTimes=None,quiet=False):
    '''
        Integrate model from the current time self.t up to self.tmax.

        Arguments
          model       : passed to nextStep; must also provide
                        writeOutputHeader/getOutputStateVar when outputFile
                        is given
          Y0          : initial stencil of states, shallow-copied into
                        self.Y; Y0[0] is written as the first output line
          outputFile  : path of a text file receiving every tOutStep-th state
          outputTimes : times at which states are collected; a state is kept
                        when |t - outputTime| < dt/2
          quiet       : disable the progress display
        Returns
          np.ndarray of the collected states when outputTimes is given,
          None otherwise. The final stencil is left in self.Y.
    '''

    wrOutput = isinstance(outputFile,str)
    outFile  = open(outputFile,'w') if wrOutput else None
    try:
      if wrOutput:
        model.writeOutputHeader(outFile)

      # converted once here rather than at every time step
      saveTimes = None if outputTimes is None else np.asarray(outputTimes)
      res       = []

      fixedDt = (self.timestepping == ODESolverTypes.FixedTS)
      if fixedDt and not quiet:
        PB = progressBar(np.floor((self.tmax-self.tmin)/self.dt))

      self.Y = copy.copy(Y0)
      niter  = 0
      tstr   = ''
      if wrOutput:
        outFile.write(model.getOutputStateVar(self.t,Y0[0]))

      while (self.t<self.tmax):
        Ynext  = self.nextStep(self.Y,model)
        niter  = niter + 1
        # assumes nextStep leaves in self.dt the step it actually took
        self.t = self.t + self.dt
        if (wrOutput and niter%self.tOutStep==0):
          outFile.write(model.getOutputStateVar(self.t,Ynext))

        if (saveTimes is not None and np.any(np.abs(self.t-saveTimes)<0.5*self.dt)):
          res.append(Ynext)

        if (self.timestepping == ODESolverTypes.VariableTS):
          self.dt = self.nextdt

        # shift the stencil and store the new state last
        for ik in range(self.nstencil-1):
          self.Y[ik] = self.Y[ik+1]
        self.Y[-1] = Ynext

        if not quiet:
          if (fixedDt):
            PB.update()
          elif (niter%500==0):
            print(len(tstr)*'\r',end='\r')
            tstr = "%12d/%12d ms, dt: %10.3f" % (int(self.t),int(self.tmax),self.dt)
            print(tstr,end='\r')
            sys.stdout.flush()

      # Clear progress display
      if not quiet:
        if (fixedDt):
          del PB
        else:
          print()
    finally:
      # always release the output file, even if the integration fails
      if outFile is not None:
        outFile.close()

    if outputTimes is not None:
      return np.asarray(res)

  def __call__(self,model,Y0,outputFile=None,outputTimes=None,quiet=False):
    ''' Shortcut for run(); forwards all options and returns its result. '''
    return self.run(model,Y0,outputFile,outputTimes,quiet)

  # =============================================================

  def checkConvergence(self):
    ''' 
        Runs a test vs an analytical solution for several time steps
        (used as initial time step for adaptive time stepping)

        tested equation is
        y' = 2k^2 t cos((kt)^2) y, y(0)=1
        k = 0.1, t in [0,100], solution y = exp(sin((kt)^2))

        error is L^2([0,100]), relative

        Returns
          cvRes  : array of rows [dt, relative L2 error], one per run
                   that completed without raising
          solSim : 1-D object array holding, per completed run, an
                   array of rows [t, numerical y, exact y]
        self.dt, self.t and self.tmax are restored before returning.
    '''

    dts    = [0.01,0.02,0.05,0.1,0.2,0.5]
    cvRes  = []
    solSim = []

    def cvFunc(t_,y):
      return 2*0.01*t_*np.cos((0.1*t_)**2)*y
    def cvSol(t_):
      return np.exp(np.sin((0.1*t_)**2))

    savedTime = (self.dt,self.t,self.tmax)
    try:
      for dt in dts:
        print(type(self),dt)
        try:
          self.dt   = dt
          self.t    = 0
          self.tmax = 100
          # stencil of past states of the scalar ODE, as nextStep expects;
          # nextStep returns a single state, not a stencil
          Y         = np.ones((self.nstencil,1))
          errtot    = 0.e0
          reftot    = 0.e0
          tmpSolSim = [np.array([self.t,Y[-1,0],cvSol(self.t)])]
          while (self.t<self.tmax):
            Ynext  = self.nextStep(Y,cvFunc)
            self.t = self.t + self.dt
            Y[:-1] = Y[1:]
            Y[-1]  = Ynext
            exact  = cvSol(self.t)
            errtot = errtot + self.dt*(Y[-1,0]-exact)**2
            reftot = reftot + self.dt*exact**2
            tmpSolSim.append(np.array([self.t,Y[-1,0],exact]))
            if (self.timestepping == ODESolverTypes.VariableTS):
              self.dt = self.nextdt
        except Exception as err:
          # keep testing the remaining time steps, but report the failure
          print("Convergence run failed for dt = %g: %r" % (dt,err))
          continue

        cvRes.append([dt,np.sqrt(errtot/reftot)])
        solSim.append(np.array(tmpSolSim))
    finally:
      self.dt,self.t,self.tmax = savedTime

    # runs have different lengths: store them in an object array
    solArr = np.empty(len(solSim),dtype=object)
    for irun,sol in enumerate(solSim):
      solArr[irun] = sol

    return np.array(cvRes),solArr

# =================================================================

class FwdEuler(ODESolver):
  ''' Simple one step forward Euler method: y(t+dt) = y(t) + dt*f(t,y(t)) '''

  def __init__(self,config):

    self.type         = ODESolverTypes.Explicit
    self.timestepping = ODESolverTypes.FixedTS
    self.initTime(config)
    self.nstencil     = 1   

  def nextStep(self,y,model):
    '''
        Advance the state y[0] (the only entry of the one-element stencil)
        by one forward Euler step of size self.dt.

        model is a mitoODEModel (its evolStateVar(t,y) is used) or a
        callable f(t,y). Raises TypeError for any other model, instead of
        failing with an obscure UnboundLocalError.
    '''

    self.updateDt = self.dt
    if (isinstance(model,mitoODEModel)):
      rhs = model.evolStateVar(self.t,y[0])
    elif callable(model):
      rhs = model(self.t,y[0])
    else:
      raise TypeError("model must be a mitoODEModel or a callable f(t,y)")

    return y[0] + self.dt*rhs

# =================================================================

class PredCorr(ODESolver):
  ''' Predictor (Euler) corrector (trapezoidal rule) scheme '''

  def __init__(self,config):

    self.type         = ODESolverTypes.Explicit
    self.timestepping = ODESolverTypes.FixedTS
    self.initTime(config)
    self.nstencil     = 1

  def nextStep(self,y,model):
    '''
        Advance the state y[0] by one predictor-corrector step:
          pred = y + dt*f(t,y)
          corr = y + dt*(f(t,y) + f(t+dt,pred))/2

        model is a mitoODEModel (its evolStateVar(t,y) is used) or a
        callable f(t,y). Raises TypeError for any other model, instead of
        failing with an obscure UnboundLocalError.
    '''

    self.updateDt = self.dt
    # resolve the right-hand side once so the scheme is written only once
    if (isinstance(model,mitoODEModel)):
      rhs = model.evolStateVar
    elif callable(model):
      rhs = model
    else:
      raise TypeError("model must be a mitoODEModel or a callable f(t,y)")

    ftn  = rhs(self.t,y[0])
    pred = y[0] + self.dt*ftn
    corr = y[0] + self.dt*(ftn+rhs(self.t+self.dt,pred))/2.

    return corr

# =================================================================

class RK4(ODESolver):
  ''' Classical fourth-order Runge-Kutta solver with a fixed time step '''

  def __init__ (self,config):

    self.type         = ODESolverTypes.Explicit
    self.timestepping = ODESolverTypes.FixedTS
    self.initTime(config)
    self.nstencil     = 1

    # Butcher table, with this file's naming convention:
    #   butcherA : stage times as fractions of dt
    #   butcherB : stage coefficients, non-zero just below the diagonal
    #   butcherC : output weights 1/6, 1/3, 1/3, 1/6
    self.butcherA = np.array([0.,0.5,0.5,1.])
    self.butcherB = np.diag([0.5,0.5,1.],k=-1)
    self.butcherC = np.array([1.,2.,2.,1.])/6.

  def nextStep(self,y,model):
    ''' Advance the state y[0] by one RK4 step of size self.dt '''

    self.updateDt = self.dt
    slope = self.getExplButcherEvol(y[0],model,
                                    self.butcherA,self.butcherB,self.butcherC)
    return y[0] + self.dt*slope

# =================================================================

class RKF45(ODESolver):
  ''' 
      Runge-Kutta-Fehlberg method.
      Adaptive time stepping.
      Error evaluation is done between two RK solutions 
      of order 4 and 5
  '''

  def __init__ (self,config):

    self.type         = ODESolverTypes.Explicit
    self.timestepping = ODESolverTypes.VariableTS
    self.initTime(config)
    self.nstencil     = 1

    # Butcher tables (butcherA: stage times, butcherB: stage coefficients,
    # butcherC4/butcherC5: weights of the 4th/5th order solutions; the 4th
    # order solution only uses the first five stages)
    self.butcherA  = np.array([0.e0,0.25e0,3.e0/8.e0,12.e0/13.e0,1.e0,0.5e0])
    self.butcherB  = np.array([[          0.0e0,           0.0e0,            0.e0,           0.e0,        0.e0,0.e0],\
                               [         0.25e0,           0.0e0,            0.e0,           0.e0,        0.e0,0.e0],\
                               [   3.e0/  32.e0,    9.e0/  32.e0,            0.e0,           0.e0,        0.e0,0.e0],\
                               [1932.e0/2197.e0,-7200.e0/2197.e0, 7296.e0/2197.e0,           0.e0,        0.e0,0.e0],\
                               [ 439.e0/ 216.e0,           -8.e0, 3680.e0/ 513.e0,-845.e0/4104.e0,        0.e0,0.e0],\
                               [  -8.e0/  27.e0,           2.0e0,-3544.e0/2565.e0,1859.e0/4104.e0,-11.e0/40.e0,0.e0]])
    self.butcherC4 = np.array([25.e0/216.e0,0.e0,1408.e0/ 2565.e0, 2197.e0/ 4104.e0,-1.e0/ 5.e0])
    self.butcherC5 = np.array([16.e0/135.e0,0.e0,6656.e0/12825.e0,28561.e0/56430.e0,-9.e0/50.e0,2.e0/55.e0])

  def nextStep(self,y,model):
    '''
        Compute one RKF45 step from the state y[0].

        The 4th and 5th order solutions are compared; while their relative
        difference exceeds self.tol the step is retried with the smaller
        self.dt proposed by evaluateNewTimeStep. Retrying stops after
        self.TSniterMax attempts, or as soon as a rejected step was already
        computed with dt <= self.minTS (a retry would recompute the same
        step). The last 5th order solution is then returned anyway and a
        warning is printed.

        On return, self.dt is the step size the returned state was
        computed with and self.nextdt is the step size proposed for the
        next step.
    '''

    niter = 0
    while True:
      dtUsed = self.dt
      res4 = y[0] + self.dt*self.getExplButcherEvol(y[0],model,\
                            self.butcherA[:-1],self.butcherB[:-1,:-1],self.butcherC4)
      res5 = y[0] + self.dt*self.getExplButcherEvol(y[0],model,\
                            self.butcherA,self.butcherB,self.butcherC5)

      # Relative error; fall back to the absolute one when the solution is
      # exactly zero, where 0/0 would otherwise yield nan (or raise) and
      # reject the step.
      diff  = norm2(res4-res5)
      scale = norm2(res5)
      errt  = diff/scale if scale > 0 else diff

      # evaluateNewTimeStep only sets nextdt when it accepts the step
      self.nextdt = None
      self.evaluateNewTimeStep(errt)
      niter = niter + 1

      if (errt <= self.tol):
        break
      if (niter >= self.TSniterMax or dtUsed <= self.minTS):
        print(bColors.warning() + "Maximum time step reduction reached. Consider changing tolerance or minimum time step.")
        break

    if (self.nextdt is None):
      # The last attempt was rejected by evaluateNewTimeStep, which already
      # shrank self.dt, but res5 is kept: report the dt it was actually
      # computed with and carry the shrunk value over to the next step.
      self.nextdt = self.dt
      self.dt     = dtUsed

    return res5


