# -*- coding: UTF-8 -*-
# -----------------------------------------------------------------
# parallelTools.py
#
# Authors: M.Leguebe
# 
# Methods used to launch batch computations in parallel
# so that it works in a similar way as MPI reductions,
# but more efficiently than Python default parallel tool
# (ie without copying the entirety of the memory state to
# all processes...)
# -----------------------------------------------------------------

import itertools as it
import numpy     as np
import getpass
import os
import os.path
import sys
import pickle
import subprocess
import shutil
import shlex
import time

from multiprocessing import Pool,Lock,Value
from math            import floor,ceil

class progressBar(object):

    ''' A text progress bar written on stdout, that also works
        in parallel. Basic usage:

        pb = progressBar(Niters)
        for i in range(Niters):
          #...
          pb.update()
        del pb

        N     : number of calls to update() corresponding to 100%
        stype : 'parallel' keeps the counter in a multiprocessing Value and
                protects update() with the lock of the instance;
                any other value selects a plain counter ('serial')
        bonus : optional text displayed on the right of the bar
        w     : width of the bar (number of characters between the two '|')
    '''

    def __init__(self,N,stype='serial',bonus=None,w=50):

        if stype == 'parallel':
            self.stype    = 'parallel'
            self.progress = Value('i',0)
        else:
            self.stype    = 'serial'
            self.progress = 0.e0
        self.end  = N
        self.lock = Lock()

        self.w = w
        line   = '  0% |'+(w*' ')+'|'
        if bonus is not None:
            self.bonus = bonus
            line = line + ' ' + bonus
        # Length of the displayed line
        self.pbl = len(line)
        print(line,end='\r')

    def update(self,w=50):
        ''' Increments the counter and redraws the bar.

            w : unused, kept for backward compatibility (the width given
                to the constructor is the one that is used)
        '''

        parallel = (self.stype=='parallel')
        if parallel:
            self.lock.acquire()
        try:
            if parallel:
                self.progress.value += 1
                val = self.progress.value
            else:
                self.progress += 1
                val = self.progress

            # An empty job (N == 0) is displayed as complete
            # instead of dividing by zero
            if self.end:
                percent = int(floor(val/(1.0*self.end)*100))
            else:
                percent = 100
            prctstr = '%3d%%' % percent
            nequal  = max(percent*self.w//100-1,0)

            # The whole line is written in one go: printing the bonus after a
            # carriage return would overwrite the beginning of the bar
            line = prctstr + ' |' + (nequal* '=') + '|' + ((self.w-nequal-1)*' ') + '|'
            if hasattr(self,'bonus'):
                line = line + ' ' + self.bonus
            print('\r'+line,end='\r')
            sys.stdout.flush()
        finally:
            # Always give the lock back, even if drawing failed
            if parallel:
                self.lock.release()

    def __del__(self):
        ''' Ends the line of the bar so that next outputs start on a new line '''

        print()
        sys.stdout.flush()

###################################################################
# Top reduction routines

def reduce(funcname,args,n,nCpu,rtype=None,progressBar=False):
    '''
        Applies the type of reduction rtype on a pool map of nCpu processes
        n:     number of individual calls to funcname to be performed
        nCpu:  number of CPUs on which tasks will be dispatched
        args:  a tuple of arguments passed to funcname. Each np.array of args with last
               dimension equal to n is split amongst processors
        rtype: can be None 'SUM' 'MAX' 'MIN'
        progressBar: if True, the first process displays a progress bar

        Example: if the function to be call is defined by
        def f(a0,a1,a2,a3)
        and we want to perform the reduction on results of f for 4 different values of a3 on 2 CPUs, then

        reduce(f,(a0,a1,a2,np.asarray((a30,a31,a32,a33))),4,2)

        will make CPU0 compute r0 = f(a0,a1,a2,a30) then r1 = f(a0,a1,a2,a31)
        and CPU1 compute  r2 = f(a0,a1,a2,a32) then r3 = f(a0,a1,a2,a33)

        The result will be the array composed by (r0,r1,r2,r3), or just the max, min, sum.
        When rtype is None, the index of the call is the LAST axis of the result.

        Raises ValueError if rtype is not one of the supported values.
    '''

    # One worker routine per kind of reduction
    workers = {None : runningPar,
               'SUM': runningSumPar,
               'MIN': runningMinPar,
               'MAX': runningMaxPar}
    if rtype not in workers:
        raise ValueError("Unknown reduction type %r (expected None, 'SUM', 'MIN' or 'MAX')" % (rtype,))

    # Processes that would receive no call at all (n < nCpu) are skipped,
    # since there is nothing to reduce on an empty list of arguments
    argTable = [chunk for chunk in buildArgTable(args,n,nCpu,progressBar) if chunk]

    pool = Pool(nCpu)
    try:
        resPerProc = pool.map(workers[rtype],zip(it.repeat(funcname),argTable))
    finally:
        # Release the worker processes, even if one of the calls failed
        pool.close()
        pool.join()

    # Final reduction
    if rtype is None:
        # Flatten the per-process lists, keeping the order of the calls
        res2 = [np.array(r) for resProc in resPerProc for r in resProc]
        D    = np.array(res2)
        # Put the index of the call on the last axis
        res  = np.rollaxis(D,0,D.ndim)
    elif rtype == 'SUM':
        res = np.sum(resPerProc,axis=0)
    elif rtype == 'MIN':
        res = np.min(resPerProc,axis=0)
    else:
        res = np.max(resPerProc,axis=0)

    return np.asarray(res)

def apply(funcname,args,n,nCpu,progressBar=False):
    ''' Same dispatch as reduce, but what funcname returns is discarded.

        The calls are made in the worker processes of a Pool, and this
        function returns once all of them are done. An exception raised by
        funcname is raised again here.
        NOTE(review): the original intent is "only modifies content of array";
        whether a modification made in a worker is visible from the caller
        depends on how the array is shared between processes - to be confirmed.
    '''

    # Processes that would receive no call (n < nCpu) are skipped
    argTable = [chunk for chunk in buildArgTable(args,n,nCpu,progressBar) if chunk]

    pool = Pool(nCpu)
    try:
        # map (and not apply_async) sends ONE (funcname, arguments) pair to each
        # call of applyPar, waits for completion and reports worker errors
        pool.map(applyPar,zip(it.repeat(funcname),argTable))
    finally:
        pool.close()
        pool.join()

###################################################################
# Reductions on clusters

def reduceOnSlurmCluster(func,args,n,nCpu,
                         time,tmpDir,rType=None,
                         deleteFiles=True,parallelRegroup=1,
                         machine="miriel"):

    '''
        See also function 'reduce' for common arguments.
        This functions has the same goal as reduce,
        but computations are performed on a computing center
        with the slurm scheduler.
        >> WARNING <<
        The default argument miriel corresponds to our
        usual machine. Also, the batch files generated by this
        method correspond to the setting of our machine.
        This routine is highly non-robust!

        time            : expected time in hours for each CPU to perform
                          a single call to funcname
                          (this name hides the module 'time' inside this function)
        tmpDir          : directory in which to write temporary files
        deleteFiles     : delete temporary files generated by each CPUs.
                          If False, the identifier of the run is written in
                          the file 'reduceID' of the current directory
        parallelRegroup : (int) when regrouping results from each process,
                          specify the number of CPUs to use to do so.
                          This number must not be too large as you are running
                          this function on an access machine.
        machine         : select machine on which to launch your job.
                          The number of CPUs per node will be deduced from this name

        The generated jobs run the script 'reduceOnCluster.py' from the
        current working directory.

        Returns the reduced result, or None if a job could not be submitted
        or terminated abnormally.
    '''

    # rand() without argument gives a float: converting a 1-element array
    # to an int is deprecated in recent versions of numpy
    ID   = int(np.random.rand()*1.e8)
    rDir = tmpDir + '/reduceData/%d' %ID
    if not deleteFiles:
        with open('reduceID','w') as elf:
            elf.write('%d\n' % ID)
    # Clear existing tmp files
    if deleteFiles and os.path.exists(rDir):
        shutil.rmtree(rDir,ignore_errors=True)
    # Creates rDir and rDir/jobLogs (no shell involved, any path is accepted)
    os.makedirs(rDir+'/jobLogs',exist_ok=True)

    # Save function and arguments
    argsFile = rDir + '/arguments.pkl'
    funcFile = rDir + '/function.pkl'
    saveObject(args,argsFile)
    saveObject(func,funcFile)

    # Get the number of cpus per machine to compute the number of needed nodes
    out = runCommand("sinfo -o \"%%30N %%10c %%60f\" | grep %s | awk '{print $2}' | head -n 1" % machine)
    try:
        nCpuPerNode = int(out)
    except ValueError:
        raise ValueError("Could not read the number of CPUs per node of machine '%s' "
                         "(sinfo answered %r)" % (machine,out))
    # First CPU index of each node, closed by nCpu
    cpuOffsets  = list(range(0,nCpu,nCpuPerNode))
    nNodes      = len(cpuOffsets)
    cpuOffsets.append(nCpu)

    # Evaluate walltime for each job: time of one call x largest number of calls per CPU
    starts      = partition(n,nCpu)
    hours       = int(np.ceil(time*np.max(np.diff(starts))))
    cwd         = os.getcwd()

    # Now generate jobs (one per node)
    jobIDs = []
    for iJob in range(nNodes):
        subFileName  = rDir + '/toSubmit%s.sub' % iJob
        with open(subFileName,'w') as QSF:
            QSF.write('#!/bin/bash\n')
            QSF.write('#SBATCH -J %s%d\n'% (func.__name__,iJob))
            QSF.write('#SBATCH -N 1\n')
            QSF.write('#SBATCH --ntasks %d\n' % (cpuOffsets[iJob+1]-cpuOffsets[iJob]))
            QSF.write('#SBATCH -C %s\n' % machine)
            QSF.write('#SBATCH --time=%d:00:00\n' % hours)
            QSF.write('#SBATCH -o %s/jobLogs/%%j.out\n' % rDir)
            QSF.write('#SBATCH -e %s/jobLogs/%%j.err\n' % rDir)

            QSF.write('cd %s\n' % cwd)
            # NOTE: a 'module load language/python/3.8.0' line used to be written here
            # One background python process per CPU of the node
            for iCpu in range(cpuOffsets[iJob],cpuOffsets[iJob+1]):
                QSF.write('python3 -W ignore reduceOnCluster.py %s %s %d %d %d %s %s &\n'%
                          (funcFile,argsFile,iCpu,nCpu,n,rType,rDir))
            QSF.write('wait\n')

        # Submit
        out,err = runCommand('sbatch %s' % subFileName,stdErr=True)

        if err:
            print("Job couldn't be submitted.\n",err)
            return
        # Get jobId ("Submitted batch job XXXX")
        jobIDs.append(out.split()[3])

    # Wait for all jobs to finish
    print("Submitted jobs to the cluster. Now waiting for completion.")
    nNotOk = waitSlurmJobs(jobIDs)

    # Check missing files
    if nNotOk != 0:
        print("\nWARNING: %d jobs abnormally terminated! Result won't be computed\n" % nNotOk)
        return
    else:
        return reducePartialReduction(rDir,n,nCpu,rType,deleteFiles,parallelRegroup)

###################################################################
# Routines to unroll the list of arguments (pool cannot pass several arguments)

def applyPar(args):
    ''' Pool entry point: unpacks args, a (funcname, list of arguments) pair, and passes it to apppply '''
    return apppply(*args)

def runningPar(args):
    ''' Pool entry point: unpacks args, a (funcname, list of arguments) pair, and passes it to running '''
    return running(*args)

def runningSumPar(args):
    ''' Pool entry point: unpacks args, a (funcname, list of arguments) pair, and passes it to runningSum '''
    return runningSum(*args)

def runningMinPar(args):
    ''' Pool entry point: unpacks args, a (funcname, list of arguments) pair, and passes it to runningMin '''
    return runningMin(*args)

def runningMaxPar(args):
    ''' Pool entry point: unpacks args, a (funcname, list of arguments) pair, and passes it to runningMax '''
    return runningMax(*args)

###################################################################
# Process or reduction routines

def apppply(funcname,args):
    ''' Calls funcname once per entry of args, results are discarded.

        args : list of tuples (flag, a1, a2, ...). funcname is called as
               funcname(a1, a2, ...). A progress bar is displayed when the
               flag of the first entry is true.
    '''

    # A process may receive no work at all (more processes than calls)
    if not args:
        return

    PG = None
    if args[0][0]:
        PG = progressBar(len(args),'serial',bonus=' (parallel)')

    for entry in args:
        funcname(*(entry[1:]))
        if PG is not None and entry[0]:
            PG.update()

    if PG is not None:
        # Deleting the bar ends its line
        del PG

def running(funcname,args):
    ''' Returns the list of the results of funcname, one per entry of args.

        args : list of tuples (flag, a1, a2, ...). funcname is called as
               funcname(a1, a2, ...). A progress bar is displayed when the
               flag of the first entry is true.
    '''
    res = []
    # A process may receive no work at all (more processes than calls)
    if not args:
        return res

    PB = None
    if args[0][0]:
        PB = progressBar(len(args),'serial',bonus=' (parallel)')
    for entry in args:
        res.append(funcname(*(entry[1:])))
        if PB is not None and entry[0]:
            PB.update()
    if PB is not None:
        # Deleting the bar ends its line
        del PB
    return res

def runningSum(funcname,args):
    ''' Computes the sum of the results of funcname over all entries of args.

        args : list of tuples (flag, a1, a2, ...). funcname is called as
               funcname(a1, a2, ...). A progress bar is displayed when the
               flag of the first entry is true.

        Returns 0 when args is empty.
    '''
    res = 0
    # A process may receive no work at all (more processes than calls)
    if not args:
        return res

    PG = None
    if args[0][0]:
        PG = progressBar(len(args),'serial',bonus=' (parallel)')
    for entry in args:
        # Not an in-place '+=': the accumulator must be allowed to change type
        # (e.g. integer array followed by a float array)
        res = res + funcname(*(entry[1:]))
        if PG is not None and entry[0]:
            PG.update()
    if PG is not None:
        # Deleting the bar ends its line
        del PG
    return res

def runningMin(funcname,args):
    ''' Computes the (elementwise) minimum of the results of funcname over all entries of args.

        args : non-empty list of tuples (flag, a1, a2, ...). funcname is
               called as funcname(a1, a2, ...). A progress bar is displayed
               when the flag of the first entry is true.

        Raises ValueError when args is empty (a minimum of nothing is undefined).
    '''
    if not args:
        raise ValueError("runningMin needs at least one set of arguments")

    # The first result initializes the reduction
    res = funcname(*(args[0][1:]))
    PG  = None
    if args[0][0]:
        PG = progressBar(len(args),'serial',bonus=' (parallel)')
        PG.update()
    for entry in args[1:]:
        res = np.minimum(res,funcname(*(entry[1:])))
        if PG is not None and entry[0]:
            PG.update()
    if PG is not None:
        # Deleting the bar ends its line
        del PG
    return res

def runningMax(funcname,args):
    ''' Computes the (elementwise) maximum of the results of funcname over all entries of args.

        args : non-empty list of tuples (flag, a1, a2, ...). funcname is
               called as funcname(a1, a2, ...). A progress bar is displayed
               when the flag of the first entry is true.

        Raises ValueError when args is empty (a maximum of nothing is undefined).
    '''
    if not args:
        raise ValueError("runningMax needs at least one set of arguments")

    # The first result initializes the reduction
    res = funcname(*(args[0][1:]))
    PG  = None
    if args[0][0]:
        PG = progressBar(len(args),'serial',bonus=' (parallel)')
        PG.update()
    for entry in args[1:]:
        res = np.maximum(res,funcname(*(entry[1:])))
        if PG is not None and entry[0]:
            PG.update()
    if PG is not None:
        # Deleting the bar ends its line
        del PG
    return res

###################################################################
# Reduction preparation routines

def buildArgTable(args,N,nbproc,progressBar):
    ''' Prepares the arguments of each individual call, grouped by process.

        args        : arguments of the function to call. Any argument having
                      'shape' and 'ndim' attributes and a last dimension equal
                      to N is split: call number j receives arg[j] (1D) or
                      arg[...,j]. All other arguments are passed as they are.
        N           : total number of calls
        nbproc      : number of processes
        progressBar : if true, the calls of the first process are flagged
                      to display a progress bar

        Returns a list of nbproc lists. Each of them contains one tuple
        (flag, a1, a2, ...) per call assigned to the process, where flag is
        the progress bar flag and a1, a2, ... the arguments of the call.
    '''

    starts = partition(N,nbproc)
    # prepare list of arguments
    arg_table = []
    for i in range(nbproc):
        # Only the first process reports progress
        showBar = bool(progressBar and i==0)

        arg_table_proc = []
        for j in range(starts[i],starts[i+1]):

            list_args = [showBar]
            for arg in args:
                # Detect if the argument is a numpy array with last dimension N.
                # Anything that does not behave like such an array is passed
                # unchanged (best effort), but interruptions are not swallowed
                try:
                    if (arg.shape[-1]==N):
                        piece = arg[j] if arg.ndim == 1 else arg[...,j]
                    else:
                        piece = arg
                except Exception:
                    piece = arg
                list_args.append(piece)
            arg_table_proc.append(tuple(list_args))

        arg_table.append(arg_table_proc)
    return arg_table

def partition(N,Nprocs):
    ''' Splits the indices 0..N-1 in Nprocs contiguous chunks.

        Returns the list of the Nprocs+1 boundaries of the chunks: chunk i
        goes from starts[i] (included) to starts[i+1] (excluded). Boundaries
        are the rounded multiples of N/Nprocs, the last one being at most N.
    '''

    chunk  = N / float(Nprocs)
    bounds = [int(round(chunk * k)) for k in range(Nprocs)]
    bounds.append(min(int(round(chunk * Nprocs)),N))
    return bounds

########################################################################

def loadPartialReduction(outDir,iStart,iEnd):
    ''' Loads the array stored in file res_<iStart>_<iEnd>.npy of directory outDir '''
    fileName = '%s/res_%d_%d.npy'%(outDir,iStart,iEnd)
    return np.load(fileName)

def reducePartialReduction(rDir,n,nProc,rType,deleteFiles=True,parallelRegroup=1):
    ''' Regroups partial results files.

        rDir            : directory containing one file res_<start>_<end>.npy
                          per process (see loadPartialReduction)
        n               : total number of calls that were performed
        nProc           : number of processes that produced the files
        rType           : None 'SUM' 'MIN' or 'MAX', as in reduce
        deleteFiles     : if True, rDir is removed once results are regrouped
        parallelRegroup : number of processes used to read the files
                          (files are read one after the other if <= 1)

        Returns the regrouped array. When rType is None the partial results
        are concatenated and the index of the call is moved to the LAST axis.

        Raises ValueError if rType is not one of the supported values.
    '''

    if rType not in (None,'SUM','MIN','MAX'):
        raise ValueError("Unknown reduction type %r (expected None, 'SUM', 'MIN' or 'MAX')" % (rType,))

    starts = partition(n,nProc)
    print("Regrouping results")

    if parallelRegroup <= 1:

        PB  = progressBar(nProc,'serial')
        res = loadPartialReduction(rDir,starts[0],starts[1])
        PB.update()

        # Without reduction, pieces are gathered and concatenated only once
        pieces = [res]
        for i in range(1,nProc):
            resTmp = loadPartialReduction(rDir,starts[i],starts[i+1])
            if rType is None:
                pieces.append(resTmp)
            elif rType == 'SUM':
                res += resTmp
            elif rType == 'MIN':
                res = np.minimum(res,resTmp)
            elif rType == 'MAX':
                res = np.maximum(res,resTmp)
            PB.update()
        del PB

        if rType is None:
            res = np.concatenate(pieces)

    else:
        ## Special reduce
        nRegroup = int(floor(parallelRegroup))
        if nRegroup>4:
            print("Maybe you asked for too much nodes on an interactive machine !")

        # One worker routine per kind of reduction
        workers = {None : runningPar,
                   'SUM': runningSumPar,
                   'MIN': runningMinPar,
                   'MAX': runningMaxPar}

        # Each call loads one file; the first process displays a progress bar.
        # Processes that would have no file to read are skipped
        argTable = buildArgTable((rDir,np.array(starts[:nProc]),np.array(starts[1:])),nProc,nRegroup,True)
        argTable = [chunk for chunk in argTable if chunk]

        pool = Pool(nRegroup)
        try:
            resPerProc = pool.map(workers[rType],zip(it.repeat(loadPartialReduction),argTable))
        finally:
            # Release the worker processes, even if one of the loads failed
            pool.close()
            pool.join()

        # Final reduction
        if rType is None:
            # resPerProc[i][j][k]: process i, file j, individual result k
            res2 = [np.array(r) for resProc in resPerProc for resFile in resProc for r in resFile]
            res  = np.array(res2)
        elif rType == 'SUM':
            res = np.sum(resPerProc,axis=0)
        elif rType == 'MIN':
            res = np.min(resPerProc,axis=0)
        else:
            res = np.max(resPerProc,axis=0)
        # End parallel regroup

    res = np.array(res)

    if deleteFiles:
        shutil.rmtree(rDir,ignore_errors=True)

    if rType is None:
        return np.rollaxis(res,0,res.ndim)
    else:
        return res

def waitSlurmJobs(jobIDs):
    ''' Wait for jobs to finish and check their status.

        jobIDs : list of slurm job identifiers

        Returns the number of jobs that ended in a failure state
        (0 if jobIDs is empty).
    '''

    if not jobIDs:
        return 0
    ids = [str(j) for j in jobIDs]

    # Small wait delay to launch the jobs
    time.sleep(10)

    # Counting remaining jobs and keep printing dots until finished.
    # stderr of squeue is dropped: it complains about job ids it does not
    # know anymore, which only means that these jobs are over
    command  = "squeue -j " + ",".join(ids) + " 2>/dev/null | wc -l"
    pings    = 0
    # More than 1 line, because of line "JOBID PARTITION NAME ..."
    # (no line at all also means that nothing is left in the queue)
    while int(runCommand(command)) > 1:
        # Ask every 10 s during 2 minutes, then every minute
        # during about an hour, then every 10 minutes
        if pings < 12:
            time.sleep(10)
        elif pings < 69:
            time.sleep(60)
        else:
            time.sleep(600)
        pings += 1
        print('.',end='',flush=True)
    print()

    # Check job status: slurm states meaning that the job did not complete
    badStates = ("FAILED","CANCELLED","TIMEOUT","NODE_FAIL",
                 "OUT_OF_MEMORY","BOOT_FAIL","DEADLINE")
    nNotOk = 0
    for j in ids:
        out = runCommand("scontrol show job " + j + " 2>/dev/null | grep JobState ")
        if any(state in out for state in badStates):
            nNotOk+=1
    return nNotOk

########################################################################

def runCommand(command,stdErr=False):
    '''
        runs given command on system, through bash

        command : command line (string). It is interpreted by the shell, so
                  pipes and redirections are allowed; it must be built from
                  trusted values only
        stdErr  : if True, returns the tuple (stdout, stderr) and lets the
                  caller deal with errors
                  else returns stdout only, and the script stops (with a
                  non-zero exit status) if stderr contains anything
    '''

    # Use the bash found in PATH, or the historical location if there is none
    bash = shutil.which('bash') or "/usr/bin/bash"
    subP = subprocess.Popen(command,shell=True,executable=bash,stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,encoding='UTF-8')
    out,err = subP.communicate()

    if stdErr:
        return out,err

    if err:
        print("Problem while running command line:\n {}\n".format(command))
        print("Error was:\n {}".format(err))
        sys.exit(1)
    return out

def saveObject(thing,filename):
    ''' Pickles thing in file filename (overwritten), with the highest available protocol '''
    with open(filename,'wb') as handle:
        pickler = pickle.Pickler(handle,pickle.HIGHEST_PROTOCOL)
        pickler.dump(thing)

def loadObject(filename):
    ''' Returns the object pickled in file filename.

        Unpickling can run arbitrary code: only load files written by
        trusted programs (e.g. saveObject).
    '''
    with open(filename,'rb') as handle:
        unpickler = pickle.Unpickler(handle)
        return unpickler.load()

def testParallelFunction(a):
  ''' Toy workload to try the parallel routines: waits 10 seconds, then returns 2*a '''
  time.sleep(10)
  return 2*a
