#!/usr/local/bin/python3.9

import os
import sys
from optparse import OptionParser

# numpy is required by everything below.  Only a failed import is reported
# here; the script then continues, as it does for the scipy check.
try:
    import numpy as np
except ImportError:
    print('qmc-fit error: numpy is not present on your machine.\n  Please install numpy and retry.')
#end try

try:
    from scipy.optimize import fmin
except ImportError:
    print('qmc-fit error: scipy is not present on your machine.\n  Please install scipy and retry.')
#end try

try:
    import matplotlib.pyplot as plt
    params = {'legend.fontsize':14,'figure.facecolor':'white','figure.subplot.hspace':0.,
          'axes.labelsize':16,'xtick.labelsize':14,'ytick.labelsize':14}
    plt.rcParams.update(params)
    plots_available = True
except (ImportError,RuntimeError):
    plots_available = False
#end try


def find_nexus_modules():
    """
    Add the Nexus library directory to the module search path.

    The directory is located relative to this file as <dir of file>/../lib.
    An AssertionError is raised if that directory does not exist; otherwise
    its absolute path is appended to sys.path.
    """
    # joining '..' onto the file path and normalizing yields the directory
    # that contains this file; the second '..' steps up to its parent
    relative_lib = os.path.join(__file__,'..','..','lib')
    lib_dir      = os.path.abspath(relative_lib)
    assert os.path.exists(lib_dir)
    sys.path.append(lib_dir)
#end def find_nexus_modules


def import_nexus_module(module_name):
    """
    Import a module by name and return the module object.

    module_name : name of the module, as accepted by importlib.import_module.

    Any exception raised by the import (e.g. ImportError) propagates to the
    caller.
    """
    import importlib
    return importlib.import_module(module_name)
#end def import_nexus_module


# Load Nexus modules
#   Two strategies are tried in order:
#     1) put the library directory found relative to this file on sys.path
#        and import the modules by name
#     2) plain imports of the same names
#   Either way the same set of module-level names is defined afterwards.
try:
    # Attempt specialized path-based imports.
    #  (The executable should still work even if Nexus is not installed)
    find_nexus_modules()

    versions = import_nexus_module('versions')
    nexus_version = versions.nexus_version
    del versions

    generic = import_nexus_module('generic')
    obj   = generic.obj
    log   = generic.log
    warn  = generic.warn
    error = generic.error
    del generic

    developer = import_nexus_module('developer')
    DevBase = developer.DevBase
    del developer

    numerics = import_nexus_module('numerics')
    jackknife            = numerics.jackknife
    jackknife_aux        = numerics.jackknife_aux
    simstats             = numerics.simstats
    equilibration_length = numerics.equilibration_length
    curve_fit            = numerics.curve_fit
    least_squares        = numerics.least_squares
    del numerics
except Exception:
    # Failing path-based imports, import installed Nexus modules.
    #  (Exception rather than a bare except, so that KeyboardInterrupt and
    #   SystemExit are not silently converted into a second import attempt)
    from versions import nexus_version
    from generic import obj,log,warn,error
    from developer import DevBase
    from numerics import jackknife,jackknife_aux
    from numerics import simstats,equilibration_length
    from numerics import curve_fit,least_squares
#end try



# Catalog of the available fits, organized as fit type -> fit name -> info.
# Each info entry holds:
#   nparam   : number of fitted parameters
#   function : model, evaluated as function(p,t)
#   format   : template used to print the fitted formula
#   params   : list of (name,function(p)) pairs for derived quantities
all_fit_functions = obj()

# timestep ("ts") fits; p[0] is the value of each model at t=0
all_fit_functions.ts = obj()

all_fit_functions.ts.linear = obj(
    nparam   = 2,
    function = lambda p,t: p[0]+p[1]*t,
    format   = '{0} + {1}*t',
    params   = [('intercept',lambda p: p[0])],
    )

all_fit_functions.ts.quadratic = obj(
    nparam   = 3,
    function = lambda p,t: p[0]+p[1]*t+p[2]*t*t,
    format   = '{0} + {1}*t + {2}*t^2',
    params   = [('intercept',lambda p: p[0])],
    )

all_fit_functions.ts.sqrt = obj(
    nparam   = 3,
    function = lambda p,t: p[0]+p[1]*np.sqrt(t)+p[2]*t,
    format   = '{0} + {1}*sqrt(t) + {2}*t',
    params   = [('intercept',lambda p: p[0])],
    )

# fit functions for the fit type selected at run time (filled in by the
# __main__ section before any fitting is done)
fit_functions = obj()



def qmcfit(q,E,fname='linear',minimizer=least_squares):
    """
    Fit E(q) with a named fit function and jackknife the fitted parameters.

    q         : abscissa values passed to np.polyfit and curve_fit.
    E         : ordinate data.  A list or tuple is converted to a float
                array.  A 2D array is transposed; the transposed array is
                kept as the jackknife data and its mean along axis 0 is the
                set of values that is fitted.
    fname     : key into the module-level fit_functions table.
    minimizer : passed through to curve_fit.

    Returns (pf,pmean,perror,auxres):
      pf           : parameters from curve_fit applied to the mean data
      pmean,perror : results returned by jackknife for the parameters
      auxres       : obj mapping each derived-parameter name to the result
                     of jackknife_aux for that parameter
    """
    # coerce plain sequences into a float array
    if isinstance(E,(list,tuple)):
        E = np.array(E,dtype=float)
    #end if

    # for 2D input, keep the transposed samples and fit their mean
    samples = None
    if len(E)!=E.size and len(E.shape)==2:
        samples = E.T
        E       = samples.mean(axis=0)
    #end if

    # look up the fit form and the functions for its derived parameters
    fit_info  = fit_functions[fname]
    fit_model = fit_info.function
    derived   = dict(fit_info.params)

    # initial guess: polynomial coefficients in ascending order,
    # padded with zeros up to the number of fit parameters
    degree = 2 if fname=='quadratic' else 1
    coeffs = np.polyfit(q,E,degree)
    npad   = fit_info.nparam-len(coeffs)
    p0     = tuple(coeffs[::-1])+npad*(0,)

    # optimized fit of the mean data
    pf = curve_fit(q,E,fit_model,p0,minimizer)

    # jackknife the fit; position=1 marks the args slot (the None below)
    # that jackknife is asked to fill
    captured = obj()
    pmean,perror = jackknife(
        data     = samples,
        function = curve_fit,
        args     = [q,None,fit_model,pf,minimizer],
        position = 1,
        capture  = captured,
        )

    # jackknife estimates of the derived parameters
    auxres = obj()
    if len(derived)>0:
        jsamples = captured.jsamples
        for pname,pfunc in derived.items():
            auxres[pname] = jackknife_aux(jsamples,pfunc)
        #end for
    #end if

    return pf,pmean,perror,auxres
#end def qmcfit



# Reads scalar.dat files and extracts energy series
def process_scalar_files(scalar_files,equils=None,reblock_factors=None,series_start=None):
    """
    Read the LocalEnergy column of each scalar file and reblock the series
    to a common length.

    scalar_files    : sequence of paths; each must exist and end with
                      '.scalar.dat'.
    equils          : number of leading blocks to discard, either a single
                      integer used for every file or one value per file.
                      If None, equilibration_length is called on each series.
    reblock_factors : reblocking factor, either a single integer used for
                      every file or one value per file.  If None, the
                      autocorrelation estimate returned by simstats is used.
    series_start    : if given, files listed before the first file with this
                      series number are dropped.  The series number is read
                      from the third-from-last dot separated field of the
                      file name, skipping its first character.

    Returns (Edata,Emean,Eerr,scalar_files):
      Edata        : 2D float array with one row of reblocked energies per file
      Emean,Eerr   : arrays of the mean and error reported by simstats for
                     each equilibrated series (computed before reblocking)
      scalar_files : the files actually used (after any series_start cut)

    Problems with the inputs are reported through error().
    """
    if len(scalar_files)==0:
        error('must provide at least one scalar file')
    #end if
    for scalar_file in scalar_files:
        if not os.path.exists(scalar_file):
            error('scalar file does not exist: {0}'.format(scalar_file))
        #end if
        if not scalar_file.endswith('.scalar.dat'):
            error('file must be of type scalar.dat: {0}'.format(scalar_file))
        #end if
    #end for

    # drop files that precede the requested starting series
    if series_start is not None:
        for n,scalar_file in enumerate(scalar_files):
            filename = os.path.split(scalar_file)[1]
            series = int(filename.split('.')[-3][1:])
            if series==series_start:
                scalar_files = scalar_files[n:]
                break
            #end if
        #end for
    #end if

    # a single integer (python or any numpy integer type) applies to all files
    if isinstance(equils,(int,np.integer)):
        equils = len(scalar_files)*[equils]
    elif equils is not None and len(equils)!=len(scalar_files):
        error('must provide one equilibration length per scalar file\nnumber of equils provided: {0}\nnumber of scalar files provided: {1}\nequils provided: {2}\nscalar files provided: {3}'.format(len(equils),len(scalar_files),equils,scalar_files))
    #end if

    if isinstance(reblock_factors,(int,np.integer)):
        reblock_factors = len(scalar_files)*[reblock_factors]
    elif reblock_factors is not None and len(reblock_factors)!=len(scalar_files):
        error('must provide one reblocking factor per scalar file\nnumber of reblock_factors provided: {0}\nnumber of scalar files provided: {1}\nreblock_factors provided: {2}\nscalar files provided: {3}'.format(len(reblock_factors),len(scalar_files),reblock_factors,scalar_files))
    #end if

    # extract energy data from scalar files
    Edata = []
    Emean = []
    Eerr  = []
    Ekap  = []
    for n,scalar_file in enumerate(scalar_files):
        # quantity names are the header tokens after the first two
        with open(scalar_file,'r') as fobj:
            quantities = fobj.readline().split()[2:]
        #end with
        # drop the first data column so rows line up with the quantity names
        rawdata = np.loadtxt(scalar_file)[:,1:].transpose()
        qdata = obj()
        for i,q in enumerate(quantities):
            qdata[q] = rawdata[i,:]
        #end for
        E = qdata.LocalEnergy

        # exclude blocks marked as equilibration
        if equils is not None:
            nbe = equils[n]
        else:
            nbe = equilibration_length(E)
        #end if
        if nbe>len(E):
            error('equilibration cannot be applied\nequilibration length given is greater than the number of blocks in the file\nfile name: {0}\n# blocks present: {1}\nequilibration length given: {2}'.format(scalar_file,len(E),nbe))
        #end if
        E = E[nbe:]

        mean,var,err,kap = simstats(E)
        Emean.append(mean)
        Eerr.append(err)
        Ekap.append(kap)

        Edata.append(E)
    #end for
    Emean = np.array(Emean)
    Eerr  = np.array(Eerr)
    Ekap  = np.array(Ekap)

    # reblock data into target length
    if reblock_factors is None:
        # find block targets based on autocorrelation time, if needed
        block_targets = [len(E)//kap for E,kap in zip(Edata,Ekap)]
    else:
        block_targets = [len(E)//rf for E,rf in zip(Edata,reblock_factors)]
    #end if

    # every series is reduced to the smallest block target
    bt = np.array(block_targets,dtype=int).min()
    if bt<1:
        # without this check the failure surfaces later as a reshape error
        error('reblocking cannot be applied\nat least one file has fewer blocks than its reblocking factor\nblock targets: {0}\nscalar files provided: {1}'.format(block_targets,scalar_files))
    #end if
    for n,E in enumerate(Edata):
        # discard leading blocks so the length is a multiple of bt
        nbe = len(E)%bt
        E = E[nbe:]
        reblock = len(E)//bt
        E = E.reshape(bt,reblock)
        if reblock>1:
            E = E.sum(1)/reblock
        #end if
        Edata[n] = E.reshape(bt)
    #end for
    Edata = np.array(Edata,dtype=float)

    return Edata,Emean,Eerr,scalar_files
#end def process_scalar_files



def stat_strings(mean,error):
    """
    Format a mean and its error bar as strings with matching precision.

    Both values are printed in fixed-point notation with the same number of
    decimal places: enough to show two significant digits of the error, and
    never fewer than two.  When the error is zero, negative, or not finite,
    the minimum of two decimal places is used (a zero error previously
    raised OverflowError).

    Note: the parameter name 'error' shadows the module-level error()
    function inside this function.

    Returns (mstr,estr), the formatted mean and error.
    """
    min_decimals = 2
    if error>0 and np.isfinite(error):
        # position of the leading digit of the error sets the precision
        d = int(max(min_decimals,1-np.floor(np.log10(error))))
    else:
        d = min_decimals
    #end if
    fmt = '{0:16.'+str(d)+'f}'
    mstr = fmt.format(mean).strip()
    estr = fmt.format(error).strip()
    return mstr,estr
#end def stat_strings


def parse_list(opt,name,dtype,len1=False):
    """
    Convert a whitespace separated option string into a numpy array, in place.

    opt   : mapping of option names to values; opt[name] is replaced.
    name  : key of the option to convert.
    dtype : numpy dtype of the resulting array.
    len1  : if True and exactly one value is present, store that single
            value instead of a length-1 array.

    A value of None is left untouched.  A value that cannot be split and
    converted is reported through error().
    """
    raw = opt[name]
    if raw is None:
        return
    #end if
    try:
        values = np.array(raw.split(),dtype=dtype)
    except (AttributeError,TypeError,ValueError,OverflowError):
        # not a string, or tokens that do not convert to dtype
        error('{0} list misformatted: {1}'.format(name,raw))
        return
    #end try
    if len1 and len(values)==1:
        values = values[0]
    #end if
    opt[name] = values
#end def parse_list


def timestep_fit():
    """
    Command line driver for the timestep ("ts") fit.

    Parses the command line, reads the scalar files via
    process_scalar_files, fits the energies against the timesteps with
    qmcfit, logs the fitted formula and derived parameters, and plots the
    fit when matplotlib was imported successfully and --noplot is not given.

    If no scalar files are given, the help text is logged and the program
    exits.  Invalid input is reported through error().
    """
    # read command line inputs
    usage = '''usage: %prog ts [options] [scalar files]'''
    parser = OptionParser(usage=usage,add_help_option=False,version='%prog {}.{}.{}'.format(*nexus_version))

    parser.add_option('-f','--fit',dest='fit_function',
                          default='linear',
                          help='Fitting function, options are {0} (default=%default).'.format(sorted(fit_functions.keys()))
                      )
    parser.add_option('-s', '--series_start',dest='series_start',
                      type="int", default=None,
                      help='Series number for first DMC run.  Use to exclude prior VMC scalar files if they have been provided (default=%default).'
                      )
    parser.add_option('-t','--timesteps',dest='timesteps',
                          default=None,
                          help='Timesteps corresponding to scalar files, excluding any prior to --series_start (default=%default).'
                      )
    parser.add_option('-e','--equils',dest='equils',
                          default=None,
                          help='Equilibration lengths corresponding to scalar files, excluding any prior to --series_start.  Can be a single value for all files.  If not provided, equilibration periods will be estimated.'
                      )
    parser.add_option('-b','--reblock_factors',dest='reblock_factors',
                          default=None,
                          help='Reblocking factors corresponding to scalar files, excluding any prior to --series_start.  Can be a single value for all files.  If not provided, reblocking factors will be estimated.'
                      )
    parser.add_option('--noplot',dest='noplot',
                      action='store_true',default=False,
                      help='Do not show plots. (default=%default).'
                      )

    options,args = parser.parse_args()
    # the first positional argument is the fit type; the rest are scalar files
    scalar_files = sorted(args[1:])
    opt = obj(**options.__dict__)

    if len(scalar_files)==0:
        log('\n'+parser.format_help().strip()+'\n')
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to be defined in every interpreter configuration
        sys.exit()
    #end if

    if opt.fit_function not in fit_functions:
        error('invalid fitting function: {0}\nvalid options are: {1}'.format(opt.fit_function,sorted(fit_functions.keys())))
    #end if

    # an absent timestep list is parsed as an empty list so that the
    # length check below reports it
    if opt.timesteps is None:
        opt.timesteps = ''
    #end if
    parse_list(opt,'timesteps',float)
    parse_list(opt,'equils',int,len1=True)
    parse_list(opt,'reblock_factors',int,len1=True)
    
    # read in scalar energy data
    Edata,Emean,Eerror,scalar_files = process_scalar_files(
        scalar_files    = scalar_files,
        series_start    = opt.series_start,
        equils          = opt.equils,
        reblock_factors = opt.reblock_factors,
        )

    if len(Edata)!=len(opt.timesteps):
        error('must provide one timestep per scalar file\nnumber of timesteps provided: {0}\nnumber of scalar files provided: {1}\ntimesteps provided: {2}\nscalar files provided: {3}'.format(len(opt.timesteps),len(scalar_files),opt.timesteps,scalar_files))
    #end if

    # perform jackknife analysis of the fit
    pf,pmean,perror,auxres = qmcfit(opt.timesteps,Edata,opt.fit_function)

    # print text info about the fit results
    func_info = fit_functions[opt.fit_function]
    pvals = []
    for pm,pe in zip(pmean,perror):
        pvals.append('({0} +/- {1})'.format(*stat_strings(pm,pe)))
    #end for

    log('\nfit function  : '+opt.fit_function)
    log('fitted formula: '+func_info.format.format(*pvals))
    for pname,pfunc in func_info.params:
        pm,pe = stat_strings(*auxres[pname])
        log('{0:<14}: {1} +/- {2}  Ha\n'.format(pname,pm,pe))
    #end for

    # plot the fit (if available)
    if plots_available and not opt.noplot:
        lw = 2
        ms = 10

        ts = opt.timesteps
        tsmax = ts.max()
        E0,E0err = auxres.intercept
        # fitted curve from zero timestep to slightly past the largest one
        tsfit = np.linspace(0,1.1*tsmax,400)
        Efit  = func_info.function(pmean,tsfit)
        plt.figure()
        plt.plot(tsfit,Efit,'k-',lw=lw)
        plt.errorbar(ts,Emean,Eerror,fmt='b.',ms=ms)
        plt.errorbar([0],[E0],[E0err],fmt='r.',ms=ms)
        plt.xlim([-0.1*tsmax,1.1*tsmax])
        plt.xlabel('DMC Timestep (1/Ha)')
        plt.ylabel('DMC Energy (Ha)')
        plt.show()
    #end if
#end def timestep_fit



if __name__=='__main__':
    # the first command line argument selects the type of fit
    fit_types = sorted(all_fit_functions.keys())
    if len(sys.argv)<2:
        error('first argument must be type of fit\ne.g. for a timestep fit, type "qmc-fit ts ..."\nvalid fit types are: {0}'.format(fit_types))
    #end if
    fit_type = sys.argv[1]

    # make the functions of the selected fit type the active set
    if fit_type not in fit_types:
        error('unknown fit type: {0}\nvalid options are: {1}'.format(fit_type,fit_types))
    else:
        fit_functions.clear()
        fit_functions.transfer_from(all_fit_functions[fit_type])
    #end if

    # hand off to the driver for the selected fit type
    fit_drivers = {'ts':timestep_fit}
    fit_driver  = fit_drivers.get(fit_type)
    if fit_driver is None:
        error('unsupported fit type: {0}'.format(fit_type))
    else:
        fit_driver()
    #end if
#end if
