# =================================================================
#  Some utilities. (M. Legube)
#  In particular, config parser and file management
# =================================================================
import os
import sys
import configparser
import errno
import re
import shutil
import socket
import numpy as NP
import scipy
import time
# ----------------------------------------
# debug: make NumPy raise FloatingPointError on every floating-point error
# (divide-by-zero, overflow, underflow, invalid operation), not only on
# divisions by 0.
# NOTE(review): this runs at import time and is not scoped to this module;
# it changes NumPy's error state for all code executed afterwards (NumPy
# keeps this state per thread/context) -- confirm this is intended.
NP.seterr(all='raise')

#==================================================================
# File management 
def which(file):
  ''' Look *file* up in the directories listed in $PATH (like shell `which`).

      Returns the first existing "<dir>/<file>" (checked with os.path.exists,
      so the file does not need to be executable), or None if it is not found
      or PATH is unset. An empty PATH entry means the current directory, as in
      POSIX shells. '''
  search_path = os.environ.get("PATH")
  if search_path is None:
    return None
  for path in search_path.split(os.pathsep):
    candidate = os.path.join(path, file)
    if os.path.exists(candidate):
      return candidate
  return None

def cutExtension(filename):
  ''' Return *filename* without its last extension ('a/b.tar.gz' -> 'a/b.tar').

      A name without an extension (including dot-files such as '.bashrc', and
      names whose only dot is in a directory component) is returned unchanged. '''
  return os.path.splitext(filename)[0]

def getExtension(filename):
  ''' Return the last extension of *filename* without the dot ('a.tar.gz' -> 'gz').

      Returns '' when there is no extension (e.g. 'Makefile', 'dir.d/file',
      '.bashrc'), so the result can be compared directly to a wanted suffix. '''
  return os.path.splitext(filename)[1][1:]

def removeDoubleSlashes(path):
    ''' Collapse every run of two or more '/' in *path* into a single '/'
        ('a///b//c' -> 'a/b/c'). '''
    return re.sub(r'/{2,}', '/', path)

def mkdir_p(path):
    ''' Create directory *path* and any missing parents, like `mkdir -p`.

        An already existing directory is not an error; an existing
        non-directory at *path* (or any other failure) raises OSError. '''
    os.makedirs(path, exist_ok=True)

def remove(filename):
  ''' Delete *filename*, whether it is a regular file or a directory tree.

      - file: removed; if it vanishes before the deletion the error is ignored,
        any other OSError propagates;
      - directory: removed recursively, errors ignored;
      - anything else (e.g. a missing path): nothing happens. '''
  if os.path.isfile(filename):
    try:
      os.remove(filename)
    except FileNotFoundError:
      pass  # already gone: nothing left to do
    return
  if os.path.isdir(filename):
    shutil.rmtree(filename, ignore_errors=True)

def purge(dir,pattern):
  ''' Delete every entry of directory *dir* whose name matches the regular
      expression *pattern* (re.search semantics: a match anywhere in the name). '''
  doomed = [name for name in os.listdir(dir) if re.search(pattern, name)]
  for name in doomed:
    os.remove(os.path.join(dir, name))

def getHostname():
  ''' Return the host name, qualified when possible.

      If socket.gethostname() already contains a dot it is returned as is;
      otherwise the primary name reported by socket.gethostbyaddr() for it
      is returned (this may perform a name-resolution lookup). '''
  name = socket.gethostname()
  if '.' in name:
    return name
  return socket.gethostbyaddr(name)[0]

#==================================================================

def testfunc(a):
  return NP.array(a*NP.ones((10,)))

#==================================================================

# Copied for CellML compatibility
def custom_piecewise(cases):
    """Evaluate a CellML-style piecewise function.

    *cases* alternates [cond0, value0, cond1, value1, ...]; for each element
    the value of the first true condition is taken, 0 where none holds
    (numpy.select semantics)."""
    conditions = cases[0::2]
    values = cases[1::2]
    return NP.select(conditions, values)

#==================================================================

class bColors:
    ''' ANSI terminal escape sequences for colored / styled output.

        Wrap text as bColors.X + text + bColors.ENDC; ENDC resets all
        attributes. '''
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red
    ENDC = '\033[0m'        # reset every attribute
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    @staticmethod
    def warning():
      ''' Return the prefix 'Warning: ' colored in WARNING, followed by a reset. '''
      return ''.join((bColors.WARNING, 'Warning: ', bColors.ENDC))

#==================================================================
# Extension of the ConfigParser module used to get easier access
# to parameters given in a text file
class myConfigParser(configparser.RawConfigParser):
    ''' RawConfigParser bound to the first section of one .init file.

        Instead of the has_option()/get() pair, call the object directly:
        config('key') or config('key', default). The accessors below
        (__call__, set, remove_option) all act on self.sections()[0]; other
        sections are ignored. Option names are case-sensitive. '''

    def __init__(self,filename):
      ''' Parse *filename* and remember it as the source file. '''
      super().__init__()
      self.optionxform = str      # identity: keep option names case-sensitive
      self.fileName_ = filename   # re-read by update(..., fileName=...)
      self.read(filename)

    def __call__(self,key,default_value=None):
      ''' Return the value of *key* from the first section.

          If the key is absent, *default_value* is used instead; if it is
          absent and default_value is None, IOError is raised. Newlines in a
          string result (read or default) are replaced by spaces. '''
      section = self.sections()[0]
      if self.has_option(section,key):
        value = self.get(section,key)
      elif default_value is None:
        raise IOError('No key '+key+' found in .init file and no default value was given.')
      else:
        value = default_value
      return value.replace('\n',' ') if isinstance(value,str) else value

    def set(self,key,string):
      ''' Set (or add) option *key* to *string* in the first section. '''
      super().set(self.sections()[0],key,string)

    def remove_option(self,key):
      ''' Remove option *key* from the first section (returns None). '''
      super().remove_option(self.sections()[0],key)

    def save(self,fileName):
      ''' Write the configuration to *fileName* and make it the source file. '''
      with open(fileName,'w') as out:
        self.write(out)
      self.fileName_ = fileName

    def update(self,keys,options,fileName=None):
      ''' Set keys[i] to options[i] for every i.

          fileName None : modify this parser in place and return None.
          fileName given: re-read self.fileName_ from disk into a NEW parser
          (unsaved in-memory changes of self are not carried over), apply the
          changes there, save it to *fileName* and return it; self is left
          unchanged. '''
      if len(keys) != len(options):
        raise Exception("Options and keys must have the same length")

      if fileName is None:
        for key,option in zip(keys,options):
          self.set(key,option)
        return None

      fresh = myConfigParser(self.fileName_)
      for key,option in zip(keys,options):
        fresh.set(key,option)
      fresh.fileName_ = fileName
      fresh.save(fileName)
      return fresh

#==================================================================
# Norm stuff

def dotProd(a,b,axis=None):
  ''' Sum of the element-wise (broadcast) product of a and b, taken over
      *axis* (over all axes when axis is None). '''
  product = NP.multiply(NP.array(a), NP.array(b))
  return NP.sum(product, axis=axis)

def norm(p,U,X=None):
  ''' Discrete L^p norm of U: (sum_i w_i |U_i|^p)^(1/p).

      Without X all weights are 1. With X (abscissae of the samples, same
      length as U, at least 2 points) the weights are the trapezoidal-rule
      weights of the grid: half of the adjacent spacing at both ends, the mean
      of the two adjacent spacings inside, so the result approximates the
      continuous L^p(X) norm. Complex U is supported (|.| is the modulus). '''
  if X is not None:
    X       = NP.asarray(X)
    dX      = NP.diff(X)                # dX[k] = X[k+1]-X[k]
    l       = NP.zeros(len(X))
    l[1:-1] = (dX[:-1]+dX[1:])/2
    l[0]    = dX[0]/2                   # previously l[1] was overwritten here
    l[-1]   = dX[-1]/2
  else:
    l = 1.e0

  return (NP.sum(l*(NP.abs(U))**(p*1.0)))**(1/(p*1.0))

def norm2(U,X=None):
  ''' L2 norm of U: Euclidean norm when X is None, otherwise the
      trapezoid-weighted norm computed by norm(2, U, X). '''
  if X is not None:
    return norm(2,U,X)
  magnitudes = NP.abs(NP.asarray(U))
  return NP.sqrt(NP.sum(magnitudes**2))


def getExtrema(U,x=None,type=None,interp=True):
  ''' Locate the local extrema of the sampled 1D function U.

      Interior sample i (1 <= i <= len(U)-2) is an extremum when
      (U[i]-U[i-1])*(U[i]-U[i+1]) >= 0, i.e. it is not strictly between its
      neighbours (plateau points therefore count).
      type   : None  -> every extremum,
               'MIN' -> only those with U[i] < U[i-1],
               'MAX' -> only those with U[i] > U[i-1].
      x      : abscissae of the samples (default 0, 1, ..., len(U)-1).
      interp : True  -> refine each extremum to the vertex of the parabola
                        through the three samples i-1, i, i+1 and return its
                        abscissa and value. A flat triplet
                        (U[i-1] == U[i] == U[i+1]) has no vertex and is
                        reported as (x[i], U[i]) instead of dividing by 0.
               False -> return the sample INDEX i and U[i] (not x[i]).
      Returns (positions, values) as numpy arrays. '''
  if x is None:
    x = range(len(U))

  positions = []
  values    = []

  for i in range(1,len(U)-1):
    f0, f1, f2 = U[i-1], U[i], U[i+1]

    is_extremum = ((f1-f0)*(f1-f2) >= 0)
    if type=='MIN':
      is_extremum = is_extremum and (f1<f0)
    elif type=='MAX':
      is_extremum = is_extremum and (f1>f0)

    if not is_extremum:
      continue

    if not interp:
      positions.append(i)
      values.append(f1)
      continue

    x0, x1, x2 = x[i-1], x[i], x[i+1]

    if f0 == f1 == f2:
      # zero curvature: the interpolating parabola degenerates to a line
      positions.append(x1)
      values.append(f1)
      continue

    # Lagrange form of the parabola: sum_k c_k * prod_{j!=k}(t - x_j)
    c0 = f0/((x0-x1)*(x0-x2))
    c1 = f1/((x1-x0)*(x1-x2))
    c2 = f2/((x2-x0)*(x2-x1))

    # vertex = root of the derivative
    xvertex = (x0*(c1+c2)+x1*(c0+c2)+x2*(c0+c1))/(2*(c0+c1+c2))
    positions.append(xvertex)
    values.append(c0*(xvertex-x1)*(xvertex-x2) + c1*(xvertex-x0)*(xvertex-x2) + c2*(xvertex-x0)*(xvertex-x1))

  return NP.asarray(positions),NP.asarray(values)

def getZeros(U,x=None,interp=True):
  ''' Locate the sign changes of the sampled 1D function U.

      Interval [i, i+1] is reported when U[i]*U[i+1] <= 0.
      x      : abscissae of the samples (default 0, 1, ..., len(U)-1).
      interp : True  -> return the zero of the secant through the two samples;
                        if both samples are exactly 0 the secant is undefined
                        and x[i] is returned instead of dividing by 0.
               False -> return the interval index i.
      NOTE(review): a sample that is exactly 0 between two non-zero samples is
      reported by both adjacent intervals (duplicate entries) -- confirm
      callers expect that.
      Returns a numpy array. '''
  if x is None:
    x = range(len(U))

  zeros = []

  for i in range(len(U)-1):
    f1, f2 = U[i], U[i+1]

    if f2*f1 <= 0:
      if not interp:
        zeros.append(i)
      elif f1 == f2:
        # both samples are 0 (f1*f2 <= 0 and f1 == f2): keep the left one
        zeros.append(x[i])
      else:
        x1, x2 = x[i], x[i+1]
        zeros.append(x1-f1*(x2-x1)/(f2-f1))

  return NP.asarray(zeros)

#==================================================================
# Some commonly used function profiles

def smoothRectangle(x,wmin,wmax,ratio_t):
  ''' Rectangle window on [wmin, wmax] with smooth cosine junctions.

      With d = 0.5*ratio_t*(wmax-wmin):
        - half-cosine ramp from 0 to 1 on (wmin-d/2, wmin+d/2),
        - 1 on the plateau [wmin+d/2, wmax-d/2] (bounds included),
        - half-cosine ramp from 1 to 0 on (wmax-d/2, wmax+d/2),
        - 0 elsewhere.
      x is a scalar or anything with a .shape; arrays are evaluated
      element-wise and a float ndarray is returned. ratio_t = 0 gives a sharp
      rectangle. '''
  d    = 0.5*ratio_t*(wmax-wmin)
  half = 0.5*d
  x0   = wmin-d/2.
  x1   = wmin+d/2.
  x2   = wmax-d/2.

  if hasattr(x,'shape'):
    x       = NP.asarray(x)
    res     = NP.zeros(x.shape)
    rising  = abs(x-wmin) < half
    falling = abs(x-wmax) < half
    # Ramps are evaluated only where they apply, so d == 0 (empty masks)
    # never divides by zero (NP.seterr(all='raise') would make it fatal).
    res[rising]  = 0.5*(1.e0 - NP.cos((x[rising]-x0)/d*NP.pi))
    res[falling] += 0.5*(1.e0 + NP.cos((x[falling]-x2)/d*NP.pi))
    # Inclusive plateau bounds so x1/x2 evaluate to 1; points already set by a
    # ramp (possible through rounding) are excluded to avoid double counting.
    plateau = (x>=x1) & (x<=x2) & ~rising & ~falling
    res[plateau] += 1.e0
    return res

  if abs(x-wmin) < half:
    return 0.5*(1.e0 - NP.cos((x-x0)/d*NP.pi))
  if abs(x-wmax) < half:
    return 0.5*(1.e0 + NP.cos((x-x2)/d*NP.pi))
  if x1 <= x <= x2:
    return 1.e0
  return 0.e0

def gaussian(x,mean,sd,amp=1):
  ''' Un-normalised Gaussian bump: amp * exp(-(x-mean)^2 / (2 sd^2)). '''
  exponent = (x-mean)**2/(2.e0*sd**2)
  return amp*NP.exp(-exponent)

def lorenz(x,mean,sd,amp=1):
  ''' Lorentzian profile amp / (1 + ((x-mean)/(sqrt(2)*sd))^2).

      Its full width at half maximum is 2*sqrt(2)*sd. '''
  scaled = (x-mean)/(NP.sqrt(2.)*sd)
  return amp*1.e0/(1.e0+scaled**2)

def sigmoid(x,threshold,slope,minVal,maxVal):
  ''' Smooth step from minVal to maxVal centred at *threshold*:
      minVal + (maxVal-minVal)/2 * (1 + tanh(slope*(x-threshold))). '''
  halfRange = (maxVal-minVal)*0.5e0
  return minVal+halfRange*(1+NP.tanh(slope*(x-threshold)))

def dsigmoid(x,threshold,slope,minVal,maxVal):
  ''' Derivative with respect to x of sigmoid(x,threshold,slope,minVal,maxVal):
      (maxVal-minVal)/2 * slope * (1 - tanh(slope*(x-threshold))^2). '''
  t = NP.tanh(slope*(x-threshold))
  return (maxVal-minVal)*0.5e0*slope*(1-t**2)

#==================================================================
def posPart(a):
  ''' Positive part max(a, 0), computed as (a+|a|)/2 (works element-wise on arrays). '''
  return 0.5*(a+abs(a))

def negPart(a):
  ''' Negative part min(a, 0), computed as (a-|a|)/2 (works element-wise on arrays). '''
  return 0.5*(a-abs(a))

#==================================================================

def integrateBoole(f,x=None,axis=-1):
  ''' adapatation of Boole's Newton Cotes formula to non equally spaced points,
      as scipy.integrate.simps '''

  itgd,x = prepareIntegration(f,x,axis)

  if (N-1)%4!=0:
    istop  = 4*(N-1)/4
    res1   = integrateBoole(itgd[...,:istop],x[:istop],axis=-1) \
             + scipy.integrate.simps(itgd[...,istop-1:],x[istop-1:],axis=-1)
    istart = (N-1)%4
    res2   = integrateBoole(itgd[...,istart:],x[istart:],axis=-1) \
             + scipy.integrate.simps(itgd[...,:istart+1],x[:istart+1],axis=-1)

    res    = 0.5*(res1+res2)

  else:

    res  = NP.zeros(itgd.shape[:-1],dtype=f.dtype)

    for i in range(5):
      pximxk  = NP.ones(x[:-1:4].shape)
      pxk     = 1.e0
      sxk     = 0.e0
      sxkxj   = 0.e0
      sxjxkxl = 0.e0

      if i==0:
        islice = slice(i,-1,4)
      else:
        islice = slice(i,None,4)

      for k in range(5):

        if k==0:
          kslice = slice(k,-1,4)
        else:
          kslice = slice(k,None,4)

        # 3 terms product
        tpxjxkxl = NP.ones(x[:-1:4].shape)
        for j in range(5):
          if j==0:
            jslice = slice(j,-1,4)
          else:
            jslice = slice(j,None,4)
          if j!=i and j!=k:
            tpxjxkxl = tpxjxkxl * x[jslice]

        if k!= i:
          sxjxkxl = sxjxkxl + tpxjxkxl
        #----------------

        if k!=i:
          pximxk = pximxk*(x[islice]-x[kslice])
          pxk    = pxk   * x[kslice]
          sxk    = sxk   + x[kslice]
          tsxj   = 0.e0
          for j in range(k+1,5):

            if j==0:
              jslice = slice(j,-1,4)
            else:
              jslice = slice(j,None,4)

            if j!=i:
              tsxj = tsxj + x[jslice]
          sxkxj   = sxkxj   + tsxj*x[kslice]

      x4   = x[4::4]
      x0   = x[:-1:4]

      num  = 0.2e0*(x4**5-x0**5) - 0.25e0*(x4**4-x0**4)*sxk + 1.e0/3.e0*(x4**3-x0**3)*sxkxj - 0.5*(x4*x4-x0*x0)*sxjxkxl + pxk*(x4-x0)
      res  = res + NP.sum(itgd[...,islice]*(num/pximxk),axis=-1)

  return res

def integrateSimpsons38(f,x=None,axis=-1):
  ''' see above
      if the total interval cannot be divided by 3,
      simpsons rules is apply on the last interval '''


  itgd,x = prepareIntegration(f,x,axis)

  if (N-1)%3!=0:
    istop  = 3*(N-1)/3
    res1   = integrateSimpsons38(itgd[...,:istop],x[:istop],axis=-1) \
             + scipy.integrate.simps(itgd[...,istop-1:],x[istop-1:],axis=-1)
    istart = (N-1)%3
    res2   = integrateSimpsons38(itgd[...,istart:],x[istart:],axis=-1) \
             + scipy.integrate.simps(itgd[...,:istart+1],x[:istart+1],axis=-1)

    res    = 0.5*(res1+res2)

  else:

    H  = x[3::3]-x[:-1:3]
    h0 = x[1::3]-x[:-1:3]
    h1 = x[2::3]-x[1::3]
    h2 = x[3::3]-x[2::3]

    div  = ((h1+h2)*(h2-h1-h0)-h0*h2)/(h0*(h0+h1))
    div2 = ((h1+h0)*(h0-h1-h2)-h0*h2)/(h2*(h2+h0))

    res = NP.sum( H/12.e0 *(itgd[...,:-1:3] * (3.e0+div)\
                          + itgd[..., 1::3] * (H*H*(h0+h1-h2)/(h0*h1*(h1+h2)))\
                          + itgd[..., 2::3] * (H*H*(h2+h1-h0)/(h2*h1*(h1+h0)))\
                          + itgd[..., 3::3] * (3.e0+div2)))

  return res

def prepareIntegration(f,x,axis):

  f   = NP.asarray(f)
  N   = f.shape[axis]
  if x is None:
    x = NP.arange(N)
  else:
    x = NP.asarray(x)

  if len(x.shape)!=1 or len(x)!=N:
    raise Exception('x must be a 1D vector of the same size as f.shape[axis].')

  if axis!=-1 and axis!=f.ndim-1:
    itgd = NP.rollaxis(f,axis,f.ndim)
  else:
    itgd = f

  return itgd,x

def evalFloat(strf):
  ''' Evaluate the expression *strf* and return the result as a float.

      Every standalone identifier "pi" is replaced by the value of pi before
      evaluation ('2*pi', 'pi/4'); names that merely contain "pi" (such as
      'NP.pi' or 'spin') are left intact instead of being corrupted.

      SECURITY: eval() runs arbitrary Python code with this module's globals.
      Only pass trusted strings (e.g. your own config files), never user- or
      network-supplied input. '''
  return float(eval(re.sub(r'(?<![\w.])pi(?!\w)', repr(float(NP.pi)), strf)))

def evalInt(strf):
  ''' Evaluate the Python expression *strf* and convert the result with int()
      (a float result is truncated toward zero).

      SECURITY: eval() runs arbitrary Python code with this module's globals.
      Only pass trusted strings (e.g. your own config files), never user- or
      network-supplied input. '''
  return int(eval(strf))
#==================================================================
# Coordinates conversion

def cartesianToSpherical(MC):
  ''' Cartesian (x, y, z) -> spherical (r, theta, phi), stacked along axis 0.

      theta is the angle from +z in [0, pi]; phi = atan2(y, x) wrapped to
      [0, 2*pi). MC has shape (3, ...). '''
  MC = NP.asarray(MC)
  x, y, z = MC[0,...], MC[1,...], MC[2,...]
  rho2  = x**2+y**2
  r     = NP.sqrt(rho2+z**2)
  theta = NP.arctan2(NP.sqrt(rho2),z)
  phi   = NP.arctan2(y,x) % (2*NP.pi)
  return NP.array([r,theta,phi])

def sphericalToCartesian(MS):
  ''' Spherical (r, theta, phi) -> Cartesian (x, y, z), stacked along axis 0.

      theta is the angle from +z, phi the azimuth. MS has shape (3, ...). '''
  MS = NP.asarray(MS)
  r, theta, phi = MS[0,...], MS[1,...], MS[2,...]
  rsin = r*NP.sin(theta)
  return NP.array([rsin*NP.cos(phi),
                   rsin*NP.sin(phi),
                   r*NP.cos(theta)])

def cartesianToCopolar(MC):
  ''' Cartesian (x, y) -> "copolar" (r, angle), stacked along axis 0.

      The angle is measured from the +y axis towards +x (atan2(x, y)) and
      wrapped into [0, 2*pi); this is the inverse of copolarToCartesian.
      MC has shape (2, ...) (array-like accepted). '''
  MC    = NP.asarray(MC)
  r     = NP.sqrt(MC[0,...]**2+MC[1,...]**2)
  angle = (NP.arctan2(MC[0,...],MC[1,...]) + 2*NP.pi) % (2*NP.pi)
  return NP.array([r,angle])

def copolarToCartesian(MS):
  ''' "Copolar" (r, angle) -> Cartesian (x, y), stacked along axis 0.

      The angle is measured from +y towards +x: x = r sin(angle),
      y = r cos(angle). MS has shape (2, ...) (array-like accepted). '''
  MS = NP.asarray(MS)
  return NP.array([MS[0,...]*NP.sin(MS[1,...]),
                   MS[0,...]*NP.cos(MS[1,...])])

def cylindricalToCartesian(MC):
  ''' Cylindrical (rho, phi, z) -> Cartesian (x, y, z), stacked along axis 0.
      MC has shape (3, ...). '''
  MC = NP.asarray(MC)
  rho, phi, z = MC[0,...], MC[1,...], MC[2,...]
  return NP.array([rho*NP.cos(phi),
                   rho*NP.sin(phi),
                   z])

def cartesianToCylindrical(MC):
  ''' Cartesian (x, y, z) -> cylindrical (rho, phi, z), stacked along axis 0.

      phi = atan2(y, x) lies in [-pi, pi] (not wrapped). MC has shape (3, ...). '''
  MC = NP.asarray(MC)
  x, y, z = MC[0,...], MC[1,...], MC[2,...]
  rho = NP.sqrt(x**2+y**2)
  phi = NP.arctan2(y,x)
  return NP.array([rho,phi,z])

def cylindricalToSpherical(MC):
  ''' Cylindrical (rho, phi, z) -> spherical (r, theta, phi), stacked along
      axis 0; theta = atan2(rho, z) is the angle from +z. MC has shape (3, ...). '''
  MC = NP.asarray(MC)
  rho, phi, z = MC[0,...], MC[1,...], MC[2,...]
  r     = NP.sqrt(rho**2+z**2)
  theta = NP.arctan2(rho,z)
  return NP.array([r,theta,phi])

def sphericalToCylindrical(MS):
  ''' Spherical (r, theta, phi) -> cylindrical (rho, phi, z), stacked along
      axis 0. MS has shape (3, ...). '''
  MS = NP.asarray(MS)
  r, theta, phi = MS[0,...], MS[1,...], MS[2,...]
  rho = r*NP.sin(theta)
  z   = r*NP.cos(theta)
  return NP.array([rho,phi,z])

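# Conversion of vector components between coordinate systems.
# The input arrays must have shape (3, ...): the three components are stacked
# along the first axis.
# theta and phi give the angular coordinates of the points where the vectors
# are located (scalars or arrays broadcastable against one component);
# use geom.GetXXXXXXXCoordsList() or geom.GetXXXXXCoordsMeshGrid()
# to get these arrays.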

def sphericalToCartesianVector(VS,theta,phi):
  ''' Convert vector components from the spherical basis (e_r, e_theta, e_phi)
      to the Cartesian basis (e_x, e_y, e_z).

      VS         : array (3, ...) of spherical components.
      theta, phi : angle from +z and azimuth of the points carrying the vectors
                   (scalars or arrays broadcastable against VS[0,...]).
      Returns an array shaped like VS. Integer/bool input is promoted to float
      (it was previously truncated by the int output buffer); floating and
      complex dtypes are kept. '''
  VS    = NP.asarray(VS)
  dtype = VS.dtype if NP.issubdtype(VS.dtype,NP.inexact) else float
  VC    = NP.zeros(VS.shape,dtype=dtype)

  ct = NP.cos(theta)
  st = NP.sin(theta)
  cf = NP.cos(phi)
  sf = NP.sin(phi)

  VC[0,...] = st*cf*VS[0,...] + ct*cf*VS[1,...] - sf*VS[2,...]
  VC[1,...] = st*sf*VS[0,...] + ct*sf*VS[1,...] + cf*VS[2,...]
  VC[2,...] =    ct*VS[0,...] -    st*VS[1,...]

  return VC

def sphericalToCylindricalVector(VS,theta):
  ''' Convert vector components from the spherical basis (e_r, e_theta, e_phi)
      to the cylindrical basis (e_rho, e_phi, e_z).

      VS    : array (3, ...) of spherical components.
      theta : angle from +z of the points carrying the vectors.
      Returns an array shaped like VS. Integer/bool input is promoted to float
      (previously truncated); floating and complex dtypes are kept. '''
  VS    = NP.asarray(VS)
  dtype = VS.dtype if NP.issubdtype(VS.dtype,NP.inexact) else float
  VC    = NP.zeros(VS.shape,dtype=dtype)

  ct = NP.cos(theta)
  st = NP.sin(theta)

  VC[0,...] = st*VS[0,...] + ct*VS[1,...]
  VC[1,...] = VS[2,...]
  VC[2,...] = ct*VS[0,...] - st*VS[1,...]

  return VC

def cartesianToSphericalVector(VC,theta,phi):
  ''' Convert vector components from the Cartesian basis (e_x, e_y, e_z) to
      the spherical basis (e_r, e_theta, e_phi).

      VC         : array (3, ...) of Cartesian components.
      theta, phi : angle from +z and azimuth of the points carrying the vectors.
      Returns an array shaped like VC. Integer/bool input is promoted to float
      (previously truncated); floating and complex dtypes are kept. '''
  VC    = NP.asarray(VC)
  dtype = VC.dtype if NP.issubdtype(VC.dtype,NP.inexact) else float
  VS    = NP.zeros(VC.shape,dtype=dtype)

  ct = NP.cos(theta)
  st = NP.sin(theta)
  cf = NP.cos(phi)
  sf = NP.sin(phi)

  VS[0,...] = st*cf*VC[0,...] + st*sf*VC[1,...] + ct*VC[2,...]
  VS[1,...] = ct*cf*VC[0,...] + ct*sf*VC[1,...] - st*VC[2,...]
  VS[2,...] =   -sf*VC[0,...] +    cf*VC[1,...]

  return VS

def cartesianToCylindricalVector(VCart,phiCyl):
  ''' Convert vector components from the Cartesian basis (e_x, e_y, e_z) to
      the cylindrical basis (e_rho, e_phi, e_z).

      VCart  : array (3, ...) of Cartesian components.
      phiCyl : azimuth (cylindrical coordinates) of the points carrying the vectors.
      Returns an array shaped like VCart. Integer/bool input is promoted to
      float (previously truncated); floating and complex dtypes are kept. '''
  VCart = NP.asarray(VCart)
  dtype = VCart.dtype if NP.issubdtype(VCart.dtype,NP.inexact) else float
  VCyl  = NP.zeros(VCart.shape,dtype=dtype)

  cf = NP.cos(phiCyl)
  sf = NP.sin(phiCyl)

  VCyl[0,...] =  cf*VCart[0,...] + sf*VCart[1,...]
  VCyl[1,...] = -sf*VCart[0,...] + cf*VCart[1,...]
  VCyl[2,...] = VCart[2,...]

  return VCyl

def cylindricalToSphericalVector(VC,theta):
  ''' Convert vector components from the cylindrical basis (e_rho, e_phi, e_z)
      to the spherical basis (e_r, e_theta, e_phi).

      VC    : array (3, ...) of cylindrical components.
      theta : angle from +z of the points carrying the vectors.
      Returns an array shaped like VC. Integer/bool input is promoted to float
      (previously truncated); floating and complex dtypes are kept. '''
  VC    = NP.asarray(VC)
  dtype = VC.dtype if NP.issubdtype(VC.dtype,NP.inexact) else float
  VS    = NP.zeros(VC.shape,dtype=dtype)

  ct = NP.cos(theta)
  st = NP.sin(theta)

  VS[0,...] = st*VC[0,...] + ct*VC[2,...]
  VS[1,...] = ct*VC[0,...] - st*VC[2,...]
  VS[2,...] = VC[1,...]

  return VS

def cylindricalToCartesianVector(VCyl,phiCyl):
  ''' Convert vector components from the cylindrical basis (e_rho, e_phi, e_z)
      to the Cartesian basis (e_x, e_y, e_z).

      VCyl   : array (3, ...) of cylindrical components.
      phiCyl : azimuth (cylindrical coordinates) of the points carrying the vectors.
      Returns an array shaped like VCyl. Integer/bool input is promoted to
      float (previously truncated); floating and complex dtypes are kept. '''
  VCyl  = NP.asarray(VCyl)
  dtype = VCyl.dtype if NP.issubdtype(VCyl.dtype,NP.inexact) else float
  VCart = NP.zeros(VCyl.shape,dtype=dtype)

  cf = NP.cos(phiCyl)
  sf = NP.sin(phiCyl)

  VCart[0,...] = cf*VCyl[0,...] - sf*VCyl[1,...]
  VCart[1,...] = sf*VCyl[0,...] + cf*VCyl[1,...]
  VCart[2,...] = VCyl[2,...]

  return VCart

def power_bit_length(x):
  ''' Smallest power of two >= x for an integer x >= 1 (x = 0 yields 2). '''
  exponent = (x-1).bit_length()
  return 1 << exponent

# =================================================================
# NOTE(review): the progressBar class below is disabled -- it is wrapped in a
# module-level string literal, so it is never defined or importable. Kept for
# reference only (it uses sys and NP); delete it or restore it deliberately.
'''
class progressBar(object):

  def __init__(self,N,bonus=None,w=50):

    self.progress    = 0.e0
    self.intProgress = 0
    self.end         = N

    self.w  = w
    str     = '  0% |'+(w*' ')+'|'
    if bonus is not None:
      self.bonus = bonus
      str = str + ' ' + bonus
    self.pbl = len(str)
    print(str,end='\r')

  def update(self,w=50):

    self.progress += 1
    val = self.progress
    percent = int(NP.floor(val/(1.0*self.end)*100))
    prctstr = '%3d%%' % percent
    nequal  = max(percent*self.w//100-1,0)
    if (percent!=self.intProgress):
      self.intProgress = percent

      print(self.pbl*'\r',end='\r')
      print(prctstr + ' |' + (nequal* '=') + '|' + ((self.w-nequal-1)*' ') + '|',end='\r')
      if hasattr(self,'bonus'):
        print(' '+self.bonus,end='\r')
      sys.stdout.flush()

  def __del__(self):
    print()
    sys.stdout.flush()
'''
