#!/usr/local/bin/python3.9

from optparse import OptionParser


def find_nexus_modules():
    """Make the Nexus library directory next to this script importable.

    The directory is resolved as ``<this file>/../../lib`` and appended to
    ``sys.path``.

    Raises:
        AssertionError: if the resolved directory does not exist.
    """
    # Both modules are imported locally: nothing at file level imports
    # ``os``, so referencing it without this import raises NameError.
    import os
    import sys
    nexus_lib = os.path.abspath(os.path.join(__file__,'..','..','lib'))
    assert(os.path.exists(nexus_lib))
    sys.path.append(nexus_lib)
#end def find_nexus_modules


def import_nexus_module(module_name):
    """Import the module called *module_name* and return the module object."""
    from importlib import import_module
    loaded_module = import_module(module_name)
    return loaded_module
#end def import_nexus_module


# Load Nexus modules
#
# Each module is imported, the names this script needs are bound at module
# level, and the temporary module reference is dropped again (``memory`` is
# kept as a module because it is used as ``memory.resident``).
try:
    # Attempt specialized path-based imports.
    #  (The executable should still work even if Nexus is not installed)
    find_nexus_modules()

    versions = import_nexus_module('versions')
    nexus_version = versions.nexus_version
    del versions

    generic = import_nexus_module('generic')
    obj = generic.obj
    del generic

    developer = import_nexus_module('developer')
    DevBase     = developer.DevBase
    error       = developer.error
    ci          = developer.ci
    unavailable = developer.unavailable
    del developer

    memory = import_nexus_module('memory')

    observables = import_nexus_module('observables')
    ChargeDensity = observables.ChargeDensity
    vlog = observables.vlog
    del observables

    fileio = import_nexus_module('fileio')
    XsfFile = fileio.XsfFile
    del fileio
except Exception:
    # Path-based loading failed (library directory missing, module or
    # attribute not found, ...): fall back to plain imports of the same names.
    # Exception is caught instead of using a bare except so that
    # KeyboardInterrupt and SystemExit are not swallowed here.
    from versions import nexus_version
    from generic import obj
    from developer import DevBase,error,ci,unavailable
    import memory
    from observables import ChargeDensity, vlog
    from fileio import XsfFile
#end try


# External imports
#
# If numpy cannot be imported, ``np`` is bound to the placeholder returned by
# ``unavailable('numpy')`` instead of failing at import time.
try:
    import numpy as np
except Exception:
    # Exception (not a bare except) so that KeyboardInterrupt and SystemExit
    # raised during the import still propagate.
    np = unavailable('numpy')
#end try


def comma_list(s):
    """Normalize a comma and/or space separated string into comma separated form.

    Commas are first turned into spaces and surrounding whitespace is
    stripped.  If the result still contains a space character it is split on
    whitespace and the pieces are joined with single commas; otherwise the
    stripped text is returned unchanged.
    """
    cleaned = s.replace(',',' ').strip()
    if ' ' not in cleaned:
        # Single token (or empty): nothing to join.
        return cleaned
    #end if
    return ','.join(cleaned.split())
#end def comma_list


class QDRBase(DevBase):
    """Shared state and logging/exit helpers for the qdens-radial tool.

    The class attributes below are shared by all subclasses; in this file
    ``verbose``, ``options`` and ``parser`` are filled in by
    ``QMCDensityRadialProcessor.read_command_line``.
    """
    # Tool name used in log messages and as the default error location.
    name    = 'qdens-radial'
    # Verbosity flag; ``vlog`` prints only when this is truthy.
    verbose = None
    # Command line options stored as a Nexus obj().
    options = obj()
    # The OptionParser instance, used by ``help``.
    parser  = None
    
    def vlog(self,*args,**kwargs):
        """Log through ``DevBase.log`` only when verbose output is enabled."""
        if self.verbose:
            DevBase.log(self,*args,**kwargs)
        #end if
    #end def vlog

    def vmlog(self,*args,**kwargs):
        """Verbose log with the value of ``memory.resident`` appended.

        The figure is ``memory.resident(children=True)/1e6`` and is labelled
        as MB in the message.
        """
        args = list(args)+[' (memory %3.2f MB)'%(memory.resident(children=True)/1e6)]
        self.vlog(*args,**kwargs)
    #end def vmlog

    def help(self):
        """Log the option parser's formatted help text."""
        self.log('\n'+self.parser.format_help().strip()+'\n')
    #end def help

    def exit(self):
        """Log a closing message (verbose only) and terminate the program.

        Raises:
            SystemExit: always.
        """
        # sys.exit is used instead of the ``exit`` builtin, which is added
        # by the ``site`` module for interactive use and is not guaranteed
        # to exist (e.g. under ``python -S``).
        import sys
        self.vlog('\n{0} finished\n'.format(self.name))
        sys.exit()
    #end def exit

    def error(self,msg,loc=None):
        """Report *msg* through the module-level ``error`` function.

        Args:
            msg: text of the error message.
            loc: location label passed on to ``error``; defaults to
                ``self.name``.
        """
        if loc is None:
            loc = self.name
        #end if
        # NOTE(review): presumably ``error`` itself stops execution, since the
        # explicit exit below is disabled -- confirm against developer.error.
        error(msg,loc)
        #self.exit()
    #end def error

#end class QDRBase


class QMCDensityRadialProcessor(QDRBase):
    """Driver that reads command line options and reports radial densities.

    ``read_command_line`` parses and validates the options; ``process`` reads
    the .xsf density file(s), optionally extrapolates and resamples, and logs
    the radial density value obtained for each species.
    """

    def __init__(self):
        pass
    #end def __init__

    def read_command_line(self):
        """Parse the command line and store validated options.

        The parser and the parsed options are stored on ``QDRBase`` (class
        level).  Options given as the string ``'None'`` become ``None``,
        ``species`` becomes a 1D numpy array of labels, ``radii`` a 1D numpy
        float array, ``random_seed`` and ``nsamples`` integers.  The single
        positional argument is stored as ``self.file_in``.  Help is printed
        and the program exits when ``-h`` is given or when the number of
        positional arguments is not exactly one.  Invalid option values are
        reported through ``self.error``.
        """
        usage = '''usage: %prog [options] xsf_file'''
        parser = OptionParser(usage=usage,add_help_option=False,version='%prog {}.{}.{}'.format(*nexus_version))
        parser.add_option('-h','--help',dest='help',
                          action='store_true',default=False,
                          help='Print help information and exit (default=%default).'
                          )
        parser.add_option('-v','--verbose',dest='verbose',
                          action='store_true',default=False,
                          help='Print detailed information (default=%default).'
                          )
        parser.add_option('-r','--radii',dest='radii',
                          default='None',
                          help='List of cutoff radii (default=%default).'
                          )
        parser.add_option('-s','--species',dest='species',
                          default='None',
                          help='List of species (default=%default).'
                          )
        parser.add_option('-a','--app',dest='app',
                          default='qmcpack',
                          help='Source that generated the .xsf file. Options: "pwscf", "qmcpack" (default=%default).'
                          )
        parser.add_option('-c','--cumulative',dest='cumulative',
                          action='store_true',default=False,
                          help='Evaluate cumulative radial density at cutoff radii (default=%default).'
                          )
        parser.add_option('--vmc',dest='vmc_file',
                          default='None',
                          help='Location of VMC to be used for extrapolating mixed-estimator bias (default=%default).'
                          )
        parser.add_option('--write',dest='write_extrap',
                          action='store_true',default=False,
                          help='Write extrapolated values to qmc-extrap.xsf (default=%default).'
                          )
        parser.add_option('--vmcerr',dest='vmc_err_file',
                          default='None',
                          help='Location of VMC+err to be used for extrapolating mixed-estimator bias and resampling (default=%default).'
                          )
        parser.add_option('--dmcerr',dest='dmc_err_file',
                          default='None',
                          help='Location of DMC+err to be used for resampling (default=%default).'
                          )
        parser.add_option('--seed',dest='random_seed',
                          default='None',
                          help='Random seed used for re-sampling. (default=%default).'
                          )
        parser.add_option('-n','--nsamples',dest='nsamples',
                          default='50',
                          help='Number of samples for resampling (default=%default).'
                          )
        parser.add_option('-p','--plot',dest='plot',
                          action='store_true',default=False,
                          help='Show plots interactively (default=%default).'
                          )

        options,file_in = parser.parse_args()

        QDRBase.parser = parser
        QDRBase.options.transfer_from(options.__dict__)

        # Command line options as Nexus obj() data structure
        opt = self.options

        # Report a non-integer value as a usage error rather than letting
        # int() raise a bare ValueError.
        try:
            opt.nsamples = int(opt.nsamples)
        except (TypeError,ValueError):
            self.error('nsamples must be an integer\nyou provided: {0}'.format(opt.nsamples))
        #end try

        QDRBase.verbose = opt.verbose

        if opt.help or len(file_in)!=1:
            self.help()
            self.exit()
        #end if

        self.vlog('\n{0} initializing'.format(self.name))
        self.vlog('\noptions provided:')
        self.vlog(str(self.options))

        # handle options
        #   options initialized to "None"
        for k,v in opt.items():
            if isinstance(v,str) and v=='None':
                opt[k] = None
            #end if
        #end for

        self.file_in = file_in[0]

        #   --species option
        if opt.species is not None:
            opt_failed = False
            spc = opt.species
            try:
                opt.species = np.array(comma_list(opt.species).split(','))
                opt_failed = len(opt.species.shape)!=1
            except Exception:
                opt_failed = True
            #end try
            if opt_failed:
                self.error('species must be a list of species labels\nyou provided: {0}'.format(spc))
            #end if
        #end if

        #   --radii option
        if opt.radii is not None:
            opt_failed = False
            rad = opt.radii
            try:
                # NOTE(review): eval() executes whatever expression is given
                # on the command line.  That is only acceptable while the
                # option text comes from the invoking user; do not feed this
                # option from untrusted input.
                opt.radii = eval(comma_list(opt.radii))
                if np.isscalar(opt.radii):
                    opt.radii = np.array([opt.radii],dtype=float)
                else:
                    opt.radii = np.array(opt.radii,dtype=float)
                #end if
                opt_failed = len(opt.radii.shape)!=1
                if opt.species is not None:
                    # If only 1 radius was given, we apply it globally to all species
                    opt_failed = opt_failed or \
                            (len(opt.radii)!=len(opt.species) and len(opt.radii)>1)
                #end if
                # Note: if the species list is not provided, the length of the
                # radii list must be one or the number of species in the .xsf
                # file; that is checked in process() once the file is read.
            except Exception:
                opt_failed = True
            #end try
            if opt_failed:
                self.error('radii must be a list of values and must be the same length as species list\nyou provided: {0}'.format(rad))
            #end if
        #end if

        #   --seed option
        if opt.random_seed is not None:
            opt_failed = False
            rseed = opt.random_seed
            try:
                opt.random_seed = int(opt.random_seed)
            except Exception:
                opt_failed = True
            #end try
            if opt_failed:
                self.error('seed must be an integer\nyou provided: {0}'.format(rseed))
            #end if
        #end if
    #end def read_command_line

    def process(self):
        """Compute and log the radial density value for each species.

        Steps, in order:
          1. check option combinations,
          2. read the density file (extrapolated as 2*DMC-VMC when a VMC
             file is given, optionally written to qmc-extrap.xsf),
          3. when a DMC+err file is given, resample the density to estimate
             an error bar per species,
          4. build the ChargeDensity, log its norm, optionally plot, and log
             the last entry of each species' radial density.

        Problems are reported through ``self.error``.
        """
        opt = self.options

        # Cut down on the printing, unless user provides -v option
        if not opt.verbose:
            vlog.set_verbosity('none')
        #end if

        if (opt.vmc_err_file is not None or opt.dmc_err_file is not None) and opt.app!='qmcpack':
            self.error('Error files (and subsequent resampling) are not compatable with {} app files.'.format(opt.app))
        #end if

        if opt.vmc_err_file is not None and opt.vmc_file is None:
            self.error('A VMC file (--vmc) must be provided along with the VMC+err file (--vmcerr)')
        #end if

        # Every evaluation below converts the radii with .tolist(); without
        # radii the run would fail with an AttributeError on None.
        if opt.radii is None:
            self.error('Cutoff radii must be provided via the -r,--radii option.')
        #end if

        # if VMC file is given, perform extrapolation before
        # instantiating ChargeDensity object
        if opt.vmc_file is not None:
            self.log('Extrapolating from VMC and DMC densities...')
            vmc_xsf_data = XsfFile(filepath=opt.vmc_file)
            dmc_xsf_data = XsfFile(filepath=self.file_in)
            # Initialize extrapolated XsfFile object
            extrap_xsf_data = dmc_xsf_data.copy()
            self._density_grid(extrap_xsf_data).values = \
                2.*self._density_grid(dmc_xsf_data).values-self._density_grid(vmc_xsf_data).values
            mean_xsf = extrap_xsf_data.copy()
            if opt.write_extrap:
                mean_xsf.write('qmc-extrap.xsf')
            #end if
        else:
            dmc_xsf_data = XsfFile(filepath=self.file_in)
            mean_xsf = dmc_xsf_data.copy()
            if opt.write_extrap:
                self.log('\nWARNING: You provided --write, however, extrapolated density will not be written since a VMC file was not provided (via --vmc=VMC_FILE).\n')
            #end if
        #end if

        # If error files are given, then user wants resampled error bars
        show_error = opt.dmc_err_file is not None
        std_dict = {}
        if show_error:
            self.log('Resampling to obtain error bar (NOTE: This can be slow)...')
            err_xsf = self._error_xsf()
            # We now have the means and the errors (either extrapolated or
            # not), so resample for the error bars.
            std_dict = self._resampled_std(mean_xsf,err_xsf)
        elif opt.vmc_err_file is not None:
            self.error('VMC and DMC densities provided but only one +err file was given.\nPlease either provide no +err file (which avoids resampling) or\nprovide a +err file for both VMC and DMC with --vmcerr and --dmcerr arguments.')
        #end if

        # Instantiate ChargeDensity object
        cd = ChargeDensity()
        cd.read_xsf(mean_xsf)
        equiv_atoms = list(cd.info.structure.equivalent_atoms().keys())
        self._validate_selection(equiv_atoms)
        self._normalize(cd)

        self.log('\nNorm:',cd.norm())

        if opt.plot:
            plot_kwargs = dict(rmax=opt.radii.tolist(),cumulative=opt.cumulative)
            if opt.species is not None:
                plot_kwargs['species'] = opt.species.tolist()
            #end if
            cd.plot_radial_density(**plot_kwargs)
        #end if

        cstr = 'Cumulative' if opt.cumulative else 'Non-Cumulative'
        rdens = self._radial_density(cd)

        # Report the requested species, or every species found in the file.
        labels = opt.species if opt.species is not None else equiv_atoms
        for ati,at in enumerate(labels):
            if len(opt.radii)>1:
                # Distinct radii for all species
                r = opt.radii[ati]
            else:
                # Global radius
                r = opt.radii[0]
            #end if
            if not show_error:
                self.log('{} Value of {} Species at Cutoff {} is: {}\n'.format(cstr,at,r,rdens.tot[at].density[-1]))
            else:
                self.log('{} Value of {} Species at Cutoff {} is: {}+/-{}\n'.format(cstr,at,r,rdens.tot[at].density[-1],std_dict[at]))
            #end if
        #end for
    #end def process

    @staticmethod
    def _density_grid(xsf):
        """Return the density grid entry of an XsfFile (``.values`` holds the data)."""
        # NOTE(review): assumes the density is stored at
        # data[3]['density']['density'] in every input file -- confirm
        # against XsfFile.
        return xsf.data[3]['density']['density']
    #end def _density_grid

    def _normalize(self,cd):
        """Apply the app-specific unit change or normalization to a ChargeDensity."""
        if self.options.app=='pwscf':
            cd.change_distance_units('B')
        else:
            cd.volume_normalize()
        #end if
    #end def _normalize

    def _radial_density(self,cd):
        """Evaluate the (cumulative) radial density of *cd* for the selected radii/species."""
        opt = self.options
        kwargs = dict(rmax=opt.radii.tolist())
        if opt.species is not None:
            kwargs['species'] = opt.species.tolist()
        #end if
        if opt.cumulative:
            return cd.cumulative_radial_density(**kwargs)
        #end if
        return cd.radial_density(**kwargs)
    #end def _radial_density

    def _validate_selection(self,equiv_atoms):
        """Check the species and radii options against the species found in the file."""
        opt = self.options
        if opt.species is not None and not set(opt.species).issubset(equiv_atoms):
            self.error('Input species (-s,--species) must be a subset of the species in the .xsf file.\nThe .xsf contains: {}\nYou provided: {}.'.format(equiv_atoms,opt.species.tolist()))
        #end if
        # Without a species list the radii can only be matched against the
        # file contents: one global radius, or one radius per species.
        if opt.species is None and len(opt.radii)>1 and len(opt.radii)!=len(equiv_atoms):
            self.error('radii must be a single value or one value per species in the .xsf file.\nThe .xsf contains: {}\nYou provided: {}.'.format(equiv_atoms,opt.radii.tolist()))
        #end if
    #end def _validate_selection

    def _error_xsf(self):
        """Build an XsfFile whose density values are the per-grid-point errors.

        Errors are obtained as (data+err file) minus (data file).  When both
        VMC and VMC+err files are given the result is the error of the
        extrapolated density, sqrt(4*sigma_dmc**2+sigma_vmc**2).
        """
        opt  = self.options
        grid = self._density_grid
        if opt.vmc_file is not None and opt.vmc_err_file is not None:
            # Data without errors
            vmc_xsf_data = XsfFile(filepath=opt.vmc_file)
            dmc_xsf_data = XsfFile(filepath=self.file_in)
            # Data with errors
            vmc_plus_err_xsf_data = XsfFile(filepath=opt.vmc_err_file)
            dmc_plus_err_xsf_data = XsfFile(filepath=opt.dmc_err_file)
            # Errors only
            vmc_err_xsf_data = vmc_plus_err_xsf_data.copy()
            grid(vmc_err_xsf_data).values = grid(vmc_plus_err_xsf_data).values-grid(vmc_xsf_data).values
            dmc_err_xsf_data = dmc_plus_err_xsf_data.copy()
            grid(dmc_err_xsf_data).values = grid(dmc_plus_err_xsf_data).values-grid(dmc_xsf_data).values

            # Get standard deviation corresponding to the extrapolation formula:
            #
            # std_dev = sqrt(4*sigma_dmc**2+sigma_vmc**2)
            #
            extrap_err_xsf_data = dmc_err_xsf_data.copy()
            grid(extrap_err_xsf_data).values = np.sqrt(4.0*grid(dmc_err_xsf_data).values**2+grid(vmc_err_xsf_data).values**2)
            return extrap_err_xsf_data.copy()
        elif opt.vmc_file is None:
            # Data without errors
            dmc_xsf_data = XsfFile(filepath=self.file_in)
            # Data with errors
            dmc_plus_err_xsf_data = XsfFile(filepath=opt.dmc_err_file)
            # Errors only
            dmc_err_xsf_data = dmc_plus_err_xsf_data.copy()
            grid(dmc_err_xsf_data).values = grid(dmc_plus_err_xsf_data).values-grid(dmc_xsf_data).values
            return dmc_err_xsf_data.copy()
        else:
            self.error('VMC and DMC densities provided but only one +err file was given.\nPlease either provide no +err file (which avoids resampling) or\nprovide a +err file for both VMC and DMC with --vmcerr and --dmcerr arguments.')
        #end if
    #end def _error_xsf

    def _resampled_std(self,mean_xsf,err_xsf):
        """Estimate per-species standard deviations by resampling the density.

        For each of ``nsamples`` samples every grid value is redrawn from a
        normal distribution centered on the value in *mean_xsf* with the
        width taken from *err_xsf*; the radial density value of each species
        is accumulated and the standard deviation over samples is returned
        as a dict keyed by species label.
        """
        opt = self.options

        nsamples = opt.nsamples
        if nsamples<1:
            # Zero samples would divide by zero when forming the averages.
            self.error('nsamples must be a positive integer\nyou provided: {0}'.format(nsamples))
        #end if

        if opt.random_seed is not None:
            np.random.seed(opt.random_seed)
        #end if

        self.log('Will compute {} samples...'.format(nsamples))

        # Means and widths are the same for every sample.
        # NOTE(review): assumes ``values`` are numpy arrays of equal shape,
        # as implied by the array arithmetic applied to them elsewhere.
        mu   = self._density_grid(mean_xsf).values
        sgma = self._density_grid(err_xsf).values

        sum_dict = {}
        sum_sq_dict = {}
        labels = ()
        for i in range(nsamples):
            self.log('sample: {}'.format(i))
            resample_xsf_data = mean_xsf.copy()
            # One normal deviate per grid point, drawn in a single call and
            # written into the copy's existing array.
            self._density_grid(resample_xsf_data).values[...] = np.random.normal(mu,sgma)

            # We've resampled each grid point, so now calculate desired quantity
            cds = ChargeDensity()
            cds.read_xsf(resample_xsf_data)
            equiv_atoms = list(cds.info.structure.equivalent_atoms().keys())
            self._normalize(cds)

            labels = opt.species if opt.species is not None else equiv_atoms
            if i==0:
                for at in labels:
                    sum_dict[at] = 0.0
                    sum_sq_dict[at] = 0.0
                #end for
            #end if

            self._validate_selection(equiv_atoms)
            rdens_samp = self._radial_density(cds)

            for at in labels:
                value = rdens_samp.tot[at].density[-1]
                sum_dict[at] += value
                sum_sq_dict[at] += value**2
            #end for
        #end for

        std_dict = {}
        for at in labels:
            variance = sum_sq_dict[at]/nsamples-(sum_dict[at]/nsamples)**2
            # Round-off can push the difference slightly below zero, which
            # would turn the square root into NaN.
            std_dict[at] = np.sqrt(np.maximum(variance,0.0))
        #end for
        return std_dict
    #end def _resampled_std

#end class QMCDensityRadialProcessor

# Script entry point: parse the command line, run the analysis, then exit
# through the tool's own exit routine.
if __name__=='__main__':
    processor = QMCDensityRadialProcessor()
    processor.read_command_line()
    processor.process()
    processor.exit()
#end if



