# -----------------------------------------------------------------
# genetic.py
#
# Authors: B.Tarraf, M.Leguebe
#
# A simple implementation of a genetic algorithm
#
# -----------------------------------------------------------------

import os
import random
import shutil

import numpy as np
from matplotlib import pyplot as plt

def evaluatePopulation(params,fitnessFunc,fitnessArgs):

    ''' Evaluates the fitness function on every parameter set of a population.

        params      : (nSamples x nModelParams) array, one parameter set per
                      row. It is handed to the dispatcher transposed, i.e.
                      as an (nModelParams x nSamples) array.
        fitnessFunc : the function to evaluate. If an evaluation fails,
                      its result should be negative.
        fitnessArgs : sequence of extra arguments, forwarded after the
                      transposed parameters.

        Returns whatever the dispatcher `reduce` returns for the nSamples
        evaluations (presumably one fitness value per sample -- confirm
        against `reduce`).
    '''

    # NOTE(review): `reduce` and `nCpus` are neither defined nor imported in
    # this module (this is not functools.reduce, whose signature differs);
    # they must be made available before this function is called.
    nSamples     = params.shape[0]
    dispatchArgs = (params.T, *fitnessArgs)

    # Laptop version, with a progress bar. Cluster alternative:
    #   reduceOnSlurmCluster(fitnessFunc,dispatchArgs,nSamples,nCpus,timePerRun,tmpDir,deleteFiles=False)
    return reduce(fitnessFunc,dispatchArgs,nSamples,nCpus,progressBar=True)

def selection(Y,param,ratio):

    ''' Selects the best fraction of a batch of parameter sets, i.e. the
        ones with the lowest fitness values.

        Y      : values of the fitness function, one per parameter set.
                 Negative values flag failed evaluations and are discarded
                 (NaN values fail the same test and are discarded too).
        param  : (nSamples x nModelParams) parameter sets; row i matches Y[i]
        ratio  : fraction of the surviving sets to keep
                 (e.g. 0.5 keeps the better half)

        Returns (Ysel, paramsSel): the kept fitness values in increasing
        order and the matching parameter rows, so index 0 is the best set.
    '''

    # The body previously referred to an undefined `params`; bind it
    # explicitly while keeping the public parameter name `param`.
    Y      = np.asarray(Y)
    params = np.asarray(param)

    # Remove the failed simulations
    valid  = Y>=0
    params = params[valid,:]
    Y      = Y[valid]

    # TODO: optionally save each generation's results here.

    argSorted = np.argsort(Y)
    argSelect = argSorted[:int(np.size(argSorted)*ratio)]

    return Y[argSelect],params[argSelect,:]

def reproduction(parents,pMutate,pCross):

    ''' Generates a new population of parameters from 'parent' parameters.

        Each parent independently breeds (probability pCross), mutates
        (probability pMutate) or stays unchanged (otherwise).
        - Breeding parents are paired in order; each pair yields two
          children that are complementary linear combinations
          alpha*a+(1-alpha)*b and (1-alpha)*a+alpha*b, alpha in [0, 1).
          With an odd number of breeders, the first one is kept unchanged.
        - A mutating parent is scaled by a random factor in [0.7, 1.3),
          i.e. its values change by up to +-30%.

        parents : (nParents x nModelParams) array
        pMutate : probability of mutation
        pCross  : probability of cross-breeding

        Returns an (nParents x nModelParams) array: every parent contributes
        exactly one row (itself, its mutation, or one child of its pair).
    '''

    parents = np.asarray(parents)

    toBreed   = []
    toMutate  = []
    newParams = []

    # Determine the fate of each parent
    for parent in parents:
        u = np.random.uniform()
        if u<=pCross:
            toBreed.append(parent)
        elif u<=pCross+pMutate:
            toMutate.append(parent)
        else:
            newParams.append(parent)

    # Mutations: scale by (1+b) so values change by at most +-30%
    # (scaling by b alone would shrink them to within 30% of zero).
    for parent in toMutate:
        b = np.random.uniform(-0.3,0.3)
        newParams.append((1+b)*parent)

    # If odd number of parents for breeding,
    # do not change the first
    if len(toBreed)%2!=0:
        newParams.append(toBreed[0])
        toBreed = toBreed[1:]

    # Cross breeding
    for k in range(0,len(toBreed),2):
        alpha = np.random.uniform(0,1)
        newParams.append( alpha   *toBreed[k]+(1-alpha)*toBreed[k+1])
        newParams.append((1-alpha)*toBreed[k]+ alpha   *toBreed[k+1])

    print("- {:d} new parameters generated ({:d} from cross-breeding,".format(len(newParams),len(toBreed)))
    print("  {:d} from mutations and {:d} unchanged)".format(len(toMutate),len(parents)-len(toBreed)-len(toMutate)))

    if not newParams:
        # Keep the 2-D shape so the (empty) result can still be stacked
        # with other populations.
        return np.empty((0,)+parents.shape[1:])
    return np.asarray(newParams)

def geneticCalibration(params,fitnessFunction,fitnessArgs,maxGenerations,
                       selectRatio=0.5,pMutate=0.05,pCross=0.8,
                       startFromGen=None,paramsFile=None,fitnessFile=None,
                       saveDir=None,tol=1.e-5):

    ''' Calibrates the parameters with a genetic algorithm, starting from a
        random population (params) or from a partially optimized population
        (startFromGen).

        params          : (nSamples x nModelParams) population evaluated in
                          the first generation that is run
        fitnessFunction : function to minimize; negative results flag
                          failed evaluations
        fitnessArgs     : extra arguments passed to fitnessFunction
        maxGenerations  : index of the last generation to run
        selectRatio     : fraction of the population kept by each selection
        pMutate, pCross : mutation / cross-breeding probabilities
        startFromGen    : generation index to resume from; the previously
                          selected population is loaded with np.load from
                          paramsFile and fitnessFile
        saveDir         : directory where the selected parameters and
                          fitness values are saved after every generation.
                          Defaults to a module-level `tmpDir` if one has
                          been set, else the current directory.
        tol             : the run stops once the mean selected fitness
                          changes by less than tol between two generations

        Returns (bestParams, bestFitness), the lowest-fitness set found.

        Raises ValueError if startFromGen is given without both files, or
        if no generation is run and nothing was loaded.
    '''

    converged  = False
    generation = 1
    pNew       = params
    pSelected  = None
    ySelected  = None
    yMeanPrev  = None   # mean selected fitness of the previous generation

    if startFromGen is not None:
        if paramsFile is None or fitnessFile is None:
            raise ValueError("paramsFile and fitnessFile are required with startFromGen")
        generation = startFromGen
        pSelected  = np.load(paramsFile)
        ySelected  = np.load(fitnessFile)

    if saveDir is None:
        # The original code used a module-level `tmpDir` that this file never
        # defines; honour it if a caller has set one, else use the cwd.
        saveDir = globals().get("tmpDir", ".")

    while generation <= maxGenerations and not converged:

        print("===================")
        print("Generation", generation)
        print("- Evaluate fitness for {:d} sets of parameters".format(pNew.shape[0]))
        yNew = evaluatePopulation(pNew,fitnessFunction,fitnessArgs)

        if pSelected is None:
            ySelected,pSelected = selection(yNew,pNew,selectRatio)
        else:
            # Previously selected sets compete with the new ones, so good
            # sets from earlier generations can survive.
            ySelected,pSelected = selection(np.append(ySelected,yNew),
                                            np.vstack((pSelected,pNew)),
                                            selectRatio)

        print("- Selected {:d} sets of parameters for next generation".format(len(ySelected)))
        np.save(os.path.join(saveDir,"selectedParams.npy") ,pSelected)
        np.save(os.path.join(saveDir,"selectedFitness.npy"),ySelected)

        pNew = reproduction(pSelected,pMutate,pCross)

        # Check convergence (needs a previous generation to compare with)
        yMean = np.mean(ySelected)
        if yMeanPrev is not None:
            converged = abs(yMean-yMeanPrev)<tol
        yMeanPrev = yMean

        generation = generation+1

    if pSelected is None:
        raise ValueError("no generation was run: maxGenerations must be >= 1")

    # selection() sorts by increasing fitness, so index 0 is the best set
    return pSelected[0],ySelected[0]

