#!/usr/local/bin/python3.9

import gc
import os
from optparse import OptionParser


def find_nexus_modules():
    """Put the Nexus ``lib`` directory on ``sys.path``.

    The directory is located relative to this script: joining ``'..'``
    twice onto ``__file__`` and normalizing resolves to a ``lib``
    directory that is a sibling of the directory containing this file.

    Raises
    ------
    ImportError
        If that directory does not exist.  This is an explicit check
        rather than an ``assert`` so it still runs under ``python -O``.
    """
    import sys
    nexus_lib = os.path.abspath(os.path.join(__file__,'..','..','lib'))
    if not os.path.isdir(nexus_lib):
        raise ImportError('Nexus library directory not found: {0}'.format(nexus_lib))
    #end if
    # avoid stacking duplicate entries if this is called more than once
    if nexus_lib not in sys.path:
        sys.path.append(nexus_lib)
    #end if
#end def find_nexus_modules


def import_nexus_module(module_name):
    """Import the module called *module_name* (absolute name) and return it."""
    from importlib import import_module
    module = import_module(module_name)
    return module
#end def import_nexus_module


# Load Nexus modules
try:
    # Attempt specialized path-based imports.
    #  (The executable should still work even if Nexus is not installed)
    find_nexus_modules()

    versions = import_nexus_module('versions')
    nexus_version = versions.nexus_version
    del versions

    generic = import_nexus_module('generic')
    obj = generic.obj
    del generic

    developer = import_nexus_module('developer')
    DevBase     = developer.DevBase
    error       = developer.error
    ci          = developer.ci
    unavailable = developer.unavailable
    del developer

    memory = import_nexus_module('memory')

    hdfreader = import_nexus_module('hdfreader')
    HDFreader = hdfreader.HDFreader
    del hdfreader

    fileio = import_nexus_module('fileio')
    XsfFile    = fileio.XsfFile
    ChgcarFile = fileio.ChgcarFile
    del fileio

    structure = import_nexus_module('structure')
    read_structure = structure.read_structure
    Structure      = structure.Structure
    del structure

    numerics = import_nexus_module('numerics')
    simplestats = numerics.simplestats
    simstats    = numerics.simstats
    del numerics

    qmcpack_input = import_nexus_module('qmcpack_input')
    QmcpackInput    = qmcpack_input.QmcpackInput
    spindensity_xml = qmcpack_input.spindensity
    spindensity_new_xml = qmcpack_input.spindensity_new  # temporary
    density_xml = qmcpack_input.density  # temporary
    del qmcpack_input
except:
    from versions import nexus_version
    from generic import obj
    from developer import DevBase,error,ci,unavailable
    import memory
    from hdfreader import HDFreader
    from fileio import XsfFile,ChgcarFile
    from structure import read_structure,Structure
    from numerics import simplestats,simstats
    from qmcpack_input import QmcpackInput
    from qmcpack_input import spindensity as spindensity_xml
    from qmcpack_input import spindensity_new as spindensity_new_xml # temporary
    from qmcpack_input import density as density_xml
#end try


try:
    import h5py
except:
    h5py = unavailable('h5py')
#end try
try:
    import numpy as np
except:
    np = unavailable('numpy')
#end try


# Plotting setup: try each GUI backend in gui_envs in turn (any failure
# moves on to the next), then import pyplot and apply shared styling.
# NOTE(review): the `warn` keyword of matplotlib.use was removed in
# matplotlib 3.3; on such versions every use() call below raises
# TypeError, the bare except swallows it, and pyplot falls back to
# matplotlib's default backend.  Confirm which behavior is intended
# before changing it (this script saves figures to files).
try:
    import matplotlib
    gui_envs = ['GTKAgg','TKAgg','Qt4Agg','WXAgg']
    for gui in gui_envs:
        try:
            matplotlib.use(gui,warn=False, force=True)
            from matplotlib import pyplot
            # `success` is not read anywhere in the visible code
            success = True
            break
        except:
            continue
        #end try
    #end for
    import matplotlib.pyplot as plt
    # shared plot styling: font sizes, white figure background and no
    # vertical spacing between stacked subplots
    params = {'legend.fontsize':14,'figure.facecolor':'white','figure.subplot.hspace':0.,
          'axes.labelsize':16,'xtick.labelsize':14,'ytick.labelsize':14}
    plt.rcParams.update(params)
except:
    # matplotlib missing or failing to import: bind a placeholder from
    # developer.unavailable in its place (presumably it reports the
    # missing module when used -- semantics not visible here)
    plt = unavailable('matplotlib.pyplot')
#end try





def h5_int(val):
    """Convert a single-valued HDF5 entry to a Python int.

    *val* may be anything ``np.array`` accepts (an h5py dataset, a NumPy
    scalar, a 0-d or single-element array, a Python number).  Fractional
    values are truncated toward zero, as ``int()`` does.

    Raises TypeError if *val* holds more than one element (or none).
    """
    arr = np.array(val)
    if arr.size!=1:
        raise TypeError('only size-1 arrays can be converted to Python scalars, received shape {0}'.format(arr.shape))
    #end if
    # Reshape to 0-d first: int() on an array with ndim > 0 is deprecated
    # since NumPy 1.25, and HDF5 scalars often come back with shape (1,).
    return int(arr.reshape(()))
#end def h5_int

def h5_float(val):
    """Convert a single-valued HDF5 entry to a Python float.

    *val* may be anything ``np.array`` accepts (an h5py dataset, a NumPy
    scalar, a 0-d or single-element array, a Python number).

    Raises TypeError if *val* holds more than one element (or none).
    """
    arr = np.array(val)
    if arr.size!=1:
        raise TypeError('only size-1 arrays can be converted to Python scalars, received shape {0}'.format(arr.shape))
    #end if
    # Reshape to 0-d first: float() on an array with ndim > 0 is deprecated
    # since NumPy 1.25, and HDF5 scalars often come back with shape (1,).
    return float(arr.reshape(()))
#end def h5_float

def h5_array(val):
    """Materialize *val* (e.g. an HDF5 dataset) as a new NumPy array."""
    arr = np.array(val)
    return arr
#end def h5_array





def comma_list(s):
    """Normalize a comma- and/or space-separated string to 'a,b,c' form.

    Commas are treated as spaces and the ends are trimmed.  If any space
    remains, the text is split on whitespace and rejoined with single
    commas; otherwise the trimmed text is returned unchanged.
    """
    text = s.replace(',',' ').strip()
    if ' ' not in text:
        return text
    #end if
    return ','.join(text.split())
#end def comma_list


def input_list(s):
    """Split a comma- and/or space-separated string into a list of tokens.

    Commas are treated as spaces and the ends are trimmed.  If no space
    remains, the trimmed text is returned as the only element (so an
    empty input yields ``['']``).
    """
    cleaned = s.replace(',',' ').strip()
    return cleaned.split() if ' ' in cleaned else [cleaned]
#end def input_list



def qmcpack_filepath(basepath,tokens,g=None):
    """Join *tokens* with '.' into a file name located under *basepath*.

    When *g* is given, every token of the form 'g' followed by digits
    (e.g. 'g003') is replaced by *g*.
    """
    def is_series_token(tok):
        return tok.startswith('g') and tok[1:].isdigit()
    #end def is_series_token
    parts = [g if (g is not None and is_series_token(tok)) else tok for tok in tokens]
    return os.path.join(basepath,'.'.join(parts))
#end def qmcpack_filepath


def get_grid(s,dr=None,delta=None,x_min=None,x_max=None,y_min=None,y_max=None,z_min=None,z_max=None):
    """Return the number of grid points along each of three axes.

    For axis n the count is ceil(|a_n| / spacing[n]).  Without *delta*,
    a_n are the cell vectors of structure *s* (after converting a copy of
    it to units 'B') and spacing is *dr*.  With *delta*, a_n are the unit
    vectors and spacing is *delta*.  The x/y/z min/max arguments are
    accepted but not used.
    """
    s = s.copy()
    s.change_units('B')
    if delta is not None:
        dr = delta
        cell = np.eye(3,dtype=int)
    else:
        cell = s.axes.copy()
    #end if
    counts = []
    for n,vec in enumerate(cell):
        length = np.sqrt(np.dot(vec,vec))
        counts.append(int(np.ceil(length/dr[n])))
    #end for
    return tuple(counts)
#end def get_grid


def reblock(data,eq,rb=None,filepath=None):
    """Drop equilibration blocks and optionally coarsen by block averaging.

    Parameters
    ----------
    data : ndarray whose leading axis indexes blocks.
    eq : number of leading blocks to discard.
    rb : reblocking factor; None or < 2 means only discard *eq* blocks.
    filepath : optional source path, used only in the error message.

    With reblocking, the discard count is increased so that the remaining
    length is a multiple of *rb*.  Groups of *rb* consecutive blocks are
    then averaged.  The module-level ``error`` is called if fewer than two
    coarse blocks would result.
    """
    if rb is None or rb<2:
        return data[eq:,...]
    #end if
    npoints = len(data)
    # push the start forward so the kept length divides evenly by rb
    start = eq+(npoints-eq)%rb
    trimmed = data[start:,...]
    nblocks = len(trimmed)//rb
    if nblocks<2:
        where = '' if filepath is None else '\nfor file: {0}'.format(filepath)
        error('reblocking data results in too few points{0}\npoint in file: {1}\npoints after equilibration: {2}\npoints after reblocking: {3}'.format(where,npoints,npoints-start,nblocks),'reblock')
    #end if
    # split the leading axis into (nblocks, rb) and average over rb
    blocked = trimmed.reshape((nblocks,rb)+trimmed.shape[1:])
    return blocked.mean(axis=1)
#end def reblock


class QDBase(DevBase):
    """Common base for the qdens classes.

    Class-level defaults hold the tool ``name``, the ``verbose`` flag,
    the command-line ``options`` and the option ``parser``; the methods
    provide verbose logging, help output, exit and error reporting.
    """
    name    = 'qdens'
    verbose = None
    options = obj()
    parser  = None

    def vlog(self,*args,**kwargs):
        """Log through DevBase.log, but only when verbose output is on."""
        if self.verbose:
            DevBase.log(self,*args,**kwargs)
        #end if
    #end def vlog

    def vmlog(self,*args,**kwargs):
        """Verbose log with the current resident memory (in MB) appended."""
        # skip the memory query entirely when nothing would be printed
        if not self.verbose:
            return
        #end if
        mem_mb = memory.resident(children=True)/1e6
        args = list(args)+[' (memory %3.2f MB)'%mem_mb]
        self.vlog(*args,**kwargs)
    #end def vmlog

    def help(self):
        """Log the option parser's formatted help text."""
        self.log('\n'+self.parser.format_help().strip()+'\n')
    #end def help

    def exit(self):
        """Log completion (when verbose) and terminate the program."""
        import sys
        self.vlog('\n{0} finished\n'.format(self.name))
        # sys.exit rather than the site-provided exit() builtin, which is
        # missing when Python runs without the site module (e.g. -S)
        sys.exit()
    #end def exit

    def error(self,msg,loc=None):
        """Report *msg* via the module-level ``error`` function.

        *loc* labels the message; it defaults to the tool name.
        """
        if loc is None:
            loc = self.name
        #end if
        error(msg,loc)
    #end def error

    # options accessor functions
    def get_cell(self):
        """Return the simulation cell.

        Uses the axes of ``options.structure`` when a structure was
        given, otherwise ``options.cell``.
        """
        opt = self.options
        if opt.structure is not None:
            cell = opt.structure.axes
        else:
            cell = opt.cell
        #end if
        return cell
    #end def get_cell
#end class QDBase



class SingleDensity(QDBase):
    """Mean and error bar of a single density channel.

    An instance is built either from precomputed ``mean``/``error``
    arrays or from raw per-block ``data``, which ``analyze`` reduces
    with ``simplestats`` along axis 0.  ``write`` emits the mean and the
    mean +/- error in dat, xsf or chgcar format.

    Note: the ``error`` instance attribute (the error bar) shadows the
    inherited ``QDBase.error`` method, so this class reports problems
    through ``self._error`` (presumably provided by DevBase).
    """
    def __init__(self,
                 filepath  = None,
                 format    = None,
                 mean      = None,
                 error     = None,
                 data      = None,
                 grid      = None,
                 structure = None,
                 extension = None,
                 ):
        """
        Parameters
        ----------
        filepath, format : passed to ``self.read`` when filepath is given.
            NOTE(review): no ``read`` method is defined on this class or
            on QDBase in the visible code -- confirm it exists upstream.
        mean, error : list or ndarray, stored as float arrays.
        data : list or ndarray of per-block samples, see ``analyze``.
        grid : length-3 sequence of grid point counts, stored as int array.
        structure : Structure, stored as a copy.
        extension : str appended to output file prefixes (default '').
        """
        # initialize instance variables
        self.mean      = None
        self.error     = None
        self.grid      = None
        self.structure = None
        self.extension = ''
        # read in data if provided
        if filepath is not None:
            self.read(filepath,format)
        #end if
        # check that inputted types meet expectations
        self.check_type('mean'     ,mean     ,'list/ndarray'      ,(list,np.ndarray))
        self.check_type('error'    ,error    ,'list/ndarray'      ,(list,np.ndarray))
        self.check_type('grid'     ,grid     ,'tuple/list/ndarray',(tuple,list,np.ndarray))
        self.check_type('structure',structure,'Structure'         ,Structure)
        self.check_type('extension',extension,'string'            ,str)
        # process other inputs
        if mean is not None:
            self.mean = np.array(mean,dtype=float)
        #end if
        if error is not None:
            self.error = np.array(error,dtype=float)
        #end if
        if grid is not None:
            if len(grid)!=3:
                self._error('inputted grid must have length 3, received {0}'.format(grid))
            #end if
            self.grid = np.array(grid,dtype=int)
        #end if
        if structure is not None:
            self.structure = structure.copy()
        #end if
        if extension is not None:
            self.extension = extension
        #end if
        if data is not None:
            self.analyze(data)
        #end if
    #end def __init__


    def check_type(self,name,value,stypes,types,allow_none=True):
        """Report an error unless *value* is an instance of *types*.

        *stypes* is the human-readable type list used in the message;
        None passes when *allow_none* is true.
        """
        if allow_none and value is None:
            return
        #end if
        if not isinstance(value,types):
            self._error('expected {0} for {1}, but received type "{2}"'.format(stypes,name,value.__class__.__name__))
        #end if
    #end def check_type


    def has(self,name):
        """True if the item *name* is set (not None)."""
        return self[name] is not None
    #end def has


    def require(self,name,action):
        """Report an error naming *action* if the item *name* is unset."""
        if self[name] is None:
            self._error('cannot {0}, {1} is not present'.format(action,name))
        #end if
    #end def require


    def analyze(self,data):
        """Set mean and error from per-sample *data* (statistics over axis 0)."""
        self.check_type('data',data,'list/ndarray',(list,np.ndarray),False)
        self.mean,self.error = simplestats(data,dim=0)
    #end def analyze


    def write(self,prefix,format):
        """Write output files for *prefix* in *format* ('xsf', 'dat' or 'chgcar')."""
        self.vlog('      writing files for {0}{1} (format={2})'.format(prefix,self.extension,format))
        if format=='xsf':
            self.write_xsf(prefix)
        elif format=='dat':
            self.write_dat(prefix)
        elif format=='chgcar':
            self.write_chgcar(prefix)
        else:
            self._error('invalid density format requested\nformat requested: {0}\nallowed options: xsf, dat, chgcar'.format(format))
        #end if
    #end def write


    def write_dat(self,prefix):
        """Write '<prefix><extension>.dat': one 'mean  error' line per point."""
        density     = self.mean
        density_err = self.error
        filepath = '{0}{1}.dat'.format(prefix,self.extension)
        # context manager so the file is closed even if formatting fails
        with open(filepath,'w') as f:
            for d,de in zip(density.ravel(),density_err.ravel()):
                f.write('{0: 16.8e}  {1: 16.8e}\n'.format(d,de))
            #end for
        #end with
    #end def write_dat


    def _new_xsf_file(self,format):
        """Validate grid/structure and return ``(XsfFile, cell)`` for writing.

        The structure is copied, converted to Cartesian positions if all
        coordinates lie strictly inside (0,1), and put in units 'A'.
        *format* only names the output format in error messages.
        """
        if self.grid is None:
            self._error('grid must be specified (via --grid or --input) for output format {0}'.format(format))
        #end if
        if self.structure is None:
            self._error('structure must be specified (via --structure) for output format {0}'.format(format))
        #end if
        s = self.structure.copy()
        p = s.pos.ravel()
        # heuristic: positions all strictly within (0,1) are taken to be
        # fractional coordinates
        if p.min()>0 and p.max()<1.0:
            s.pos_to_cartesian()
        #end if
        s.change_units('A')
        cell = s.axes
        f = XsfFile()
        f.incorporate_structure(s)
        return f,cell
    #end def _new_xsf_file


    def write_xsf(self,prefix):
        """Write mean, mean+err and mean-err densities as xsf files.

        Files are '<prefix><extension>.xsf', '...+err.xsf', '...-err.xsf'.
        """
        density     = self.mean
        density_err = self.error
        f,cell = self._new_xsf_file('xsf')
        prefix = '{0}{1}'.format(prefix,self.extension)
        centered  = 1
        add_ghost = 1

        # mean
        f.add_density(cell,density,centered=centered,add_ghost=add_ghost)
        f.write(prefix+'.xsf')

        # mean + errorbar
        f.add_density(cell,density+density_err,centered=centered,add_ghost=add_ghost)
        f.write(prefix+'+err.xsf')

        # mean - errorbar
        f.add_density(cell,density-density_err,centered=centered,add_ghost=add_ghost)
        f.write(prefix+'-err.xsf')
    #end def write_xsf


    @staticmethod
    def _write_chgcar_file(xsf,filepath):
        """Convert XsfFile *xsf* to a ChgcarFile and write it to *filepath*."""
        chg = ChgcarFile()
        chg.incorporate_xsf(xsf)
        chg.write(filepath)
    #end def _write_chgcar_file


    def write_chgcar(self,prefix):
        """Write mean, mean+err and mean-err densities as CHGCAR files.

        Files are '<prefix><extension>.CHGCAR', '...+err.CHGCAR' and
        '...-err.CHGCAR'; each is produced via an intermediate XsfFile.
        """
        density     = self.mean
        density_err = self.error
        f,cell = self._new_xsf_file('chgcar')
        prefix = '{0}{1}'.format(prefix,self.extension)
        # kept separate from the ChgcarFile objects so that every file,
        # not just the first, is built with the same centering flag
        centered  = 1
        add_ghost = 1

        # mean
        f.add_density(cell,density,centered=centered,add_ghost=add_ghost)
        self._write_chgcar_file(f,prefix+'.CHGCAR')

        # mean + errorbar
        f.add_density(cell,density+density_err,centered=centered,add_ghost=add_ghost)
        self._write_chgcar_file(f,prefix+'+err.CHGCAR')

        # mean - errorbar
        f.add_density(cell,density-density_err,centered=centered,add_ghost=add_ghost)
        self._write_chgcar_file(f,prefix+'-err.CHGCAR')
    #end def write_chgcar

#end class SingleDensity



class StatFile(QDBase):
    """Density or spin-density data from one QMCPACK stat.h5 file.

    ``read`` sorts the file's top-level groups into ``spin_density_data``
    (names containing both "spin" and "density") and ``density_data``
    (names containing "density" only).  It sets ``use_dens`` when any
    plain density group is present.  ``write``, ``accumulate``,
    ``normalize``, ``analyze`` and ``write_output_files`` then operate on
    that one kind of data.  ``line_plot`` uses spin-density data only.
    """
    def __init__(self,filepath=None):
        """Create a container, reading *filepath* when it is given."""
        self.filepath = filepath
        if filepath is None:
            self.file_prefix = None
        else:
            # output files are named after the input minus '.stat.h5'
            self.file_prefix = filepath.replace('.stat.h5','')
        #end if
        self.density_data    = None
        self.density_results = None
        self.spin_density_data    = None
        self.spin_density_results = None

        if filepath is not None:
            self.read(filepath)
        #end if
    #end def __init__

    def has_density_data(self):
        """True if plain density data is present."""
        return self.density_data is not None
    #end def has_density_data

    def has_density_results(self):
        """True if plain density results have been computed."""
        return self.density_results is not None
    #end def has_density_results

    def has_data(self):
        """True if spin-density data is present."""
        return self.spin_density_data is not None
    #end def has_data

    def has_results(self):
        """True if spin-density results have been computed."""
        return self.spin_density_results is not None
    #end def has_results


    def read(self,filepath):
        """Load (spin-)density groups from *filepath* and set ``use_dens``.

        If exactly one group of the selected kind is present and only a
        single ``options.grid`` was given, that grid is recorded in
        ``options.grids`` under the group's name.
        """
        if not os.path.exists(filepath):
            self.error('attempted to read stat.h5 file that does not exist\nfile path: {0}'.format(filepath))
        #end if
        hdf = HDFreader(filepath)
        if not hdf._success:
            self.error('read failed for stat.h5 file\nfile path: {0}'.format(filepath))
        #end if
        h5 = hdf.obj
        h5._remove_hidden()
        self.spin_density_data = obj()
        self.density_data = obj()
        for name in h5.keys():
            lname = name.lower()
            if 'spin' in lname and 'density' in lname:
                self.spin_density_data[name] = h5[name]
            elif 'density' in lname:
                self.density_data[name] = h5[name]
            #end if
        #end for
        # plain density takes precedence when both kinds are present
        self.use_dens = len(self.density_data)>0
        if len(self.spin_density_data)==0 and not self.use_dens:
            self.error('spin density is not present in stat.h5 file\nfile path: {0}'.format(filepath))
        #end if
        opt = self.options
        data = self.density_data if self.use_dens else self.spin_density_data
        if len(data)==1 and opt.grids is None and opt.grid is not None:
            opt.grids = obj()
            opt.grids[list(data.keys())[0]] = opt.grid
        #end if
    #end def read


    def write(self,filepath):
        """Write the 'value' arrays of the selected data kind to an HDF5 file."""
        # context manager so the file is closed even if a write fails
        with h5py.File(filepath,'w') as h5:
            if not self.use_dens:
                for name,spin_density in self.spin_density_data.items():
                    h5_spin_density = h5.create_group(name)
                    for s,spin in spin_density.items():
                        h5_spin = h5_spin_density.create_group(s)
                        h5_spin.create_dataset('value',data=spin.value)
                    #end for
                #end for
            else:
                for name,density in self.density_data.items():
                    h5_density = h5.create_group(name)
                    h5_density.create_dataset('value',data=density.value)
                #end for
            #end if
        #end with
    #end def write


    def accumulate(self,stat,weight=1.0):
        """Add *weight* times the data of StatFile *stat* into this one.

        The first call copies *stat*'s data and scales it by *weight*.
        Later calls add to it in place.  ``total_weight`` tracks the summed
        weights for ``normalize``.
        """
        if 'total_weight' not in self:
            self.total_weight = 0.0
        #end if
        if 'use_dens' not in self:
            # an accumulator created without a filepath never ran read();
            # adopt the data kind of the first file, since normalize()
            # and write() depend on use_dens being set
            self.use_dens = stat.use_dens
        #end if
        self.total_weight += weight
        if not stat.use_dens:
            if self.spin_density_data is None:
                self.spin_density_data = stat.spin_density_data.copy()
                for spin_density in self.spin_density_data:
                    for spin in spin_density:
                        spin.value *= weight
                    #end for
                #end for
            else:
                for name,spin_density in self.spin_density_data.items():
                    sd = stat.spin_density_data[name]
                    for s,spin in spin_density.items():
                        spin.value += weight*sd[s].value
                    #end for
                #end for
            #end if
        else:
            if self.density_data is None:
                self.density_data = stat.density_data.copy()
                for density in self.density_data:
                    density.value *= weight
                #end for
            else:
                for name,density in self.density_data.items():
                    sd = stat.density_data[name]
                    density.value += weight*sd.value
                #end for
            #end if
        #end if
    #end def accumulate


    def normalize(self):
        """Divide accumulated values by ``total_weight``, then discard it."""
        if not self.use_dens:
            for spin_density in self.spin_density_data:
                for spin in spin_density:
                    spin.value /= self.total_weight
                #end for
            #end for
        else:
            for density in self.density_data:
                density.value /= self.total_weight
            #end for
        #end if
        del self.total_weight
    #end def normalize


    def analyze(self,equilibration=0,reblock_factor=None):
        """Reblock the data and build SingleDensity results per group.

        For each group: if ``options.grids`` has an entry for it, check
        the cell count and reshape the data onto that grid.  Then drop
        *equilibration* blocks and coarsen by *reblock_factor* (see
        ``reblock``).  Results are stored in ``spin_density_results``
        (channels u, d, tot=u+d, pol=u-d) or in ``density_results``
        (channel q).
        """
        # check for data of the kind this file holds; spin data is
        # legitimately absent when only plain density was accumulated
        use_dens = 'use_dens' in self and self.use_dens
        if use_dens:
            has_data = self.has_density_data()
        else:
            has_data = self.has_data()
        #end if
        if not has_data:
            self.error('cannot analyze results, data is not present')
        #end if
        opt = self.options
        if not use_dens:
            self.spin_density_results = obj()
            for name,spin_density in self.spin_density_data.items():
                g = None
                if opt.grids is not None and name in opt.grids:
                    g = opt.grids[name]
                    ncells = g[0]*g[1]*g[2]
                    for data in spin_density:
                        data_cells = data.value.shape[-1]
                        if ncells!=data_cells:
                            self.error('grid does not match number of data cells\ngrid provided: {0}\nnumber of cells in grid: {1}\nnumber of cells in data: {2}'.format(g,ncells,data_cells))
                        #end if
                        # split the flat trailing cell axis into the 3D grid
                        data.value.shape = len(data.value),g[0],g[1],g[2]
                    #end for
                #end if
                # reblock the data
                u_data = reblock(spin_density.u.value,equilibration,reblock_factor,self.filepath)
                d_data = reblock(spin_density.d.value,equilibration,reblock_factor,self.filepath)
                # save the reblocked data
                spin_density.u.clear()
                spin_density.u.value = u_data
                spin_density.d.clear()
                spin_density.d.value = d_data
                # make the single density means/errors
                sres = obj()
                sres.u = SingleDensity(
                    structure = opt.structure,
                    grid      = g,
                    data      = u_data,
                    extension = '_u',
                    )
                sres.d = SingleDensity(
                    structure = opt.structure,
                    grid      = g,
                    data      = d_data,
                    extension = '_d',
                    )
                sres.tot = SingleDensity(
                    structure = opt.structure,
                    grid      = g,
                    data      = u_data+d_data,
                    extension = '_u+d',
                    )
                sres.pol = SingleDensity(
                    structure = opt.structure,
                    grid      = g,
                    data      = u_data-d_data,
                    extension = '_u-d',
                    )
                self.spin_density_results[name] = sres
            #end for
        else:
            self.density_results = obj()
            for name,density in self.density_data.items():
                g = None
                if opt.grids is not None and name in opt.grids:
                    g = opt.grids[name]
                    ncells = g[0]*g[1]*g[2]
                    for data in density:
                        data_cells = data.shape[-1]*data.shape[-2]*data.shape[-3]
                        if ncells!=data_cells:
                            self.error('grid does not match number of data cells\ngrid provided: {0}\nnumber of cells in grid: {1}\nnumber of cells in data: {2}'.format(g,ncells,data_cells))
                        #end if
                        data.shape = len(data),g[0],g[1],g[2]
                    #end for
                #end if
                # reblock the data
                q_data = reblock(density.value,equilibration,reblock_factor,self.filepath)
                # save the reblocked data
                density.clear()
                density.value = q_data
                # make the single density means/errors
                sres = obj()
                sres.q = SingleDensity(
                    structure = opt.structure,
                    grid      = g,
                    data      = q_data,
                    extension = '_q',
                    )
                self.density_results[name] = sres
            #end for
        #end if
    #end def analyze


    def write_output_files(self,formats):
        """Write every SingleDensity result in each format in *formats*.

        Output prefixes are '<file_prefix>.<group name>'.
        """
        if not self.use_dens:
            results = self.spin_density_results
        else:
            results = self.density_results
        #end if
        for name,channels in results.items():
            prefix = '{0}.{1}'.format(self.file_prefix,name)
            for format in formats:
                for single_density in channels:
                    single_density.write(prefix,format)
                #end for
            #end for
        #end for
    #end def write_output_files


    def line_plot(self,dim):
        """Plot the up+down spin density summed onto axis *dim* (0, 1 or 2).

        For each spin-density group, writes '<prefix>.pdf' and a
        three-column '<prefix>.dat' (position, mean, error).  The data must
        already be shaped (blocks, nx, ny, nz).  With a structure the
        x-axis is in Angstrom and the profile is per unit length;
        otherwise it is per bin.
        """
        self.vlog('      making line plots for {0}'.format(self.file_prefix))
        opt = self.options
        ndim = 3
        permute = dim!=0
        if opt.structure is not None:
            s = opt.structure.copy()
            s.change_units('A')
            rmax = np.linalg.norm(s.axes[dim])
        else:
            rmax = None
        #end if
        if permute:
            # move axis `dim` to position 1: (block, dim, other, other)
            r = list(range(1,ndim+1))
            r.pop(dim)
            permutation = tuple([0,dim+1]+r)
        #end if
        sdim = 'xyz'[dim]
        for name,spin_density in self.spin_density_data.items():
            prefix = '{0}.{1}_lineplot_{2}'.format(self.file_prefix,name,sdim)
            data = spin_density.u.value+spin_density.d.value
            if permute:
                data = data.transpose(permutation)
            #end if
            # Sum over the two transverse axes.  reshape (not in-place
            # shape assignment) is required: after the transpose above the
            # trailing axes cannot always be merged without a copy.
            s = data.shape
            data = data.reshape(s[0],s[1],s[2]*s[3]).sum(2)
            lmean,lvar,lerror,lkappa = simstats(data,dim=0)
            if rmax is None:
                r = np.arange(len(lmean))
                xlab = 'bin number'
                ylab = '$N_e$/bin'
            else:
                dr = rmax/len(lmean)
                r = dr/2+dr*np.arange(len(lmean),dtype=float)
                lmean/=dr
                lerror/=dr
                xlab = sdim+r' ($\AA$)'
                ylab = r'$N_e/\AA$'
            #end if
            plt.figure(tight_layout=True)
            plt.errorbar(r,lmean,lerror,fmt='b.-')
            plt.xlabel(xlab)
            plt.ylabel(ylab)
            plt.title('{0}  (axis {1})  {2}'.format(name,dim,os.path.split(self.file_prefix)[1]))
            self.vlog('        creating file {0}.pdf'.format(prefix))
            plt.savefig(prefix+'.pdf')
            self.vlog('        creating file {0}.dat'.format(prefix))
            np.savetxt(prefix+'.dat',np.array(list(zip(r,lmean,lerror))))
        #end for
    #end def line_plot
#end class StatFile



class QMCDensityProcessor(QDBase):
    """
    Command line driver for post-processing QMCPACK stat.h5 density data.

    The script entry point calls read_command_line(), then process(), then
    exit().

    Input files are grouped three ways:
      - batches: directory plus the first '.'-separated token of the name
      - series: the 'sNNN' token
      - groups: the 'gNNN' token

    With --average, the groups of each series are combined into one
    weighted-average stat.h5 file.  Each remaining file is then analyzed,
    written out in the requested --formats, and/or plotted via --lineplot.
    """
    def __init__(self):
        self.file_list  = []     # stat.h5 paths accepted by read_command_line
        self.stat_files = obj()  # NOTE(review): not referenced by the visible methods
    #end def __init__


    def read_command_line(self):
        """
        Parse and validate the command line.

        Parsed options are transferred onto QDBase.options (used here as
        self.options).  They are then converted in place from their raw
        string form into the types process() expects:
          - formats: list
          - equilibration/reblock: int or obj keyed by series
          - weights: float array
          - structure: structure object
          - cell/grid: arrays
          - lineplot: int
          - grids: obj of grids per estimator

        The stat.h5 files named on the command line are appended to
        self.file_list.  Invalid input is reported through self.error.
        """
        usage = '''usage: %prog [options] [file(s)]'''
        parser = OptionParser(usage=usage,add_help_option=False,version='%prog {}.{}.{}'.format(*nexus_version))
        parser.add_option('-h','--help',dest='help',
                          action='store_true',default=False,
                          help='Print help information and exit (default=%default).'
                          )
        parser.add_option('-v','--verbose',dest='verbose',
                          action='store_true',default=False,
                          help='Print detailed information (default=%default).'
                          )
        parser.add_option('-f','--formats',dest='formats',
                          default='None',
                          help='Format or list of formats for density file output.  Options: dat, xsf, chgcar (default=%default).'
                          )
        parser.add_option('-e','--equilibration',dest='equilibration',
                          default='0',
                          help='Equilibration length in blocks (default=%default).'
                          )
        parser.add_option('-r','--reblock',dest='reblock',
                          default='None',
                          help='Block coarsening factor; use estimated autocorrelation length (default=%default).'
                          )
        parser.add_option('-a','--average',dest='average',
                          action='store_true',default=False,
                          help='Average over files in each series (default=%default).'
                          )
        parser.add_option('-w','--weights',dest='weights',
                          default='None',
                          help='List of weights for averaging (default=%default).'
                          )
        parser.add_option('-i','--input',dest='input',
                          default='None',
                          help='QMCPACK input file containing structure and grid information (default=%default).'
                          )
        parser.add_option('-s','--structure',dest='structure',
                          default='None',
                          help='File containing atomic structure (default=%default).'
                          )
        parser.add_option('-g','--grid',dest='grid',
                          default='None',
                          help='Density grid dimensions (default=%default).'
                          )
        parser.add_option('-c','--cell',dest='cell',
                          default='None',
                          help='Simulation cell axes (default=%default).'
                          )
        parser.add_option('--lineplot',dest='lineplot',
                          default='None',
                          help='Produce a line plot along the selected dimension: 0, 1, or 2 (default=%default).'
                          )
        parser.add_option('--noplot',dest='noplot',
                          action='store_true',default=False,
                          help='Do not show plots interactively (default=%default).'
                          )
        parser.add_option('--twist_info',dest='twist_info',
                          default='use',
                          help='Use twist weights in twist_info.dat files or not.  Options: "use", "ignore", "require".  "use" means use when present, "ignore" means do not use, "require" means must be used (default=%default).'
                          )

        options,files_in = parser.parse_args()

        QDBase.parser = parser
        QDBase.options.transfer_from(options.__dict__)

        opt = self.options

        QDBase.verbose = opt.verbose

        if opt.help or len(files_in)==0:
            self.help()
            self.exit()
        #end if

        self.vlog('\n{0} initializing'.format(self.name))

        if self.verbose:
            self.log('\noptions provided:')
            self.log(str(self.options))
        #end if

        # handle options
        #   options left at their "None" string default become real None
        for k,v in opt.items():
            if isinstance(v,str) and v=='None':
                opt[k] = None
            #end if
        #end for

        #   --format option (output file formats)
        if opt.formats is not None:
            opt.formats = input_list(opt.formats)
            allowed_formats = set(['dat','xsf','chgcar'])
            invalid = set(opt.formats)-allowed_formats
            if len(invalid)>0:
                error('invalid output file format(s) requested\ninvalid requests: {0}\nallowed options: {1}'.format(sorted(invalid),sorted(allowed_formats)))
            #end if
        #end if

        #   --equilibration option
        opt.equilibration = self._parse_series_option(opt.equilibration,'equilibration')

        #   --reblock option
        if opt.reblock is not None:
            opt.reblock = self._parse_series_option(opt.reblock,'reblock')
        #end if

        #   --average option
        if opt.average:
            if opt.weights is not None:
                # keep the user's text so the error message can echo it
                weights_in = opt.weights
                weight_failed = False
                try:
                    w = comma_list(weights_in)
                    # NOTE: eval runs on command line text from the user
                    #       invoking the script; never pass untrusted input.
                    w = np.array(eval(w),dtype=float)
                    opt.weights = w
                    weight_failed = len(w.shape)!=1
                except Exception:
                    weight_failed = True
                #end try
                if weight_failed:
                    self.error('weights must be a list of values\nyou provided: {0}'.format(weights_in))
                #end if
            #end if
        #end if

        # set grids
        opt.grids = None

        #   --input option
        structure = None
        cell      = None
        if opt.input is not None:
            structure,cell,grids = self._read_input_grids(opt.input)
            if len(grids)>0:
                opt.grids = grids
            else:
                self.error('Could not find any spin density estimators in the input file provided.\nInput file provided: {}'.format(opt.input))
            #end if
        #end if

        #   --structure option (overrides the structure from --input)
        if opt.structure is not None:
            if not os.path.exists(opt.structure):
                self.error('structure file does not exist: {0}'.format(opt.structure))
            #end if
            opt.structure = read_structure(opt.structure)
        else:
            opt.structure = structure
        #end if

        #   --cell option (overrides the cell from --input)
        if opt.cell is not None:
            cell = input_list(opt.cell)
            try:
                cell = np.array(cell,dtype=float)
            except Exception:
                self.error('--cell input misformatted\nexpected a list of real values\nreceived: {0}'.format(cell))
            #end try
            if len(cell)!=9:
                self.error('--cell input misformatted\nexpected 9 elements for 3x3 matrix\nreceived {0} elements with values: {1}'.format(len(cell),cell))
            #end if
            cell.shape = 3,3
        #end if
        opt.cell = cell

        #   --grid option
        if opt.grid is not None:
            grid = input_list(opt.grid)
            try:
                grid = np.array(grid,dtype=int)
            except Exception:
                self.error('--grid input misformatted\nexpected a list of integers\nreceived: {0}'.format(grid))
            #end try
            if len(grid)!=3:
                self.error('--grid input misformatted\nexpected 3 elements\nreceived {0} elements with values: {1}'.format(len(grid),grid))
            #end if
            opt.grid = grid
        #end if

        #   --lineplot option
        if opt.lineplot is not None:
            if opt.lineplot not in ('0','1','2'):
                self.error('--lineplot input misformatted\nexpected 0, 1, or 2\nreceived: {0}'.format(opt.lineplot))
            #end if
            opt.lineplot = int(opt.lineplot)
        #end if

        #   --twist_info option
        twist_info_options = ('use','ignore','require')
        if opt.twist_info not in twist_info_options:
            self.error('twist_info option is invalid\ntwist_info must be one of: {}\nyou provided: {}'.format(twist_info_options,opt.twist_info))
        #end if

        if opt.verbose and opt.grids is not None:
            self.log('grids found:')
            for name,grid in opt.grids.items():
                self.log('  {0} {1}'.format(name,grid))
            #end for
        #end if

        for filepath in files_in:
            if not os.path.exists(filepath):
                self.error('file does not exist: {0}'.format(filepath))
            #end if
            if not filepath.endswith('.stat.h5'):
                self.error('only stat.h5 files are allowed as inputs\nreceived file: {0}'.format(filepath))
            #end if
        #end for

        self.file_list.extend(files_in)
    #end def read_command_line


    def _parse_series_option(self,value,option_name):
        """
        Convert an --equilibration / --reblock value to an int or an obj.

        The value, after comma_list, may be any of:
          - an integer, applied to every series
          - a list or tuple, where entry i applies to series i (converted
            to an obj keyed by position, with integer entries)
          - a dict mapping series to value (transferred into an obj)

        Anything else is reported via self.error, echoing the value the
        user provided.
        """
        value = comma_list(value)
        try:
            return int(value)
        except Exception:
            pass
        #end try
        parsed = None
        try:
            # NOTE: eval runs on command line text from the user invoking
            #       the script; never pass untrusted input here.
            evaluated = eval(value)
            if isinstance(evaluated,int):
                parsed = evaluated
            elif isinstance(evaluated,(list,tuple)):
                evaluated = np.array(evaluated,dtype=int)
                parsed = obj()
                for s in range(len(evaluated)):
                    parsed[s] = evaluated[s]
                #end for
            elif isinstance(evaluated,dict):
                parsed = obj()
                parsed.transfer_from(evaluated)
            #end if
        except Exception:
            parsed = None
        #end try
        if parsed is None:
            self.error('cannot process {0} option\nvalue must be an integer, list, or dict\nyou provided: {1}'.format(option_name,value))
        #end if
        return parsed
    #end def _parse_series_option


    def _read_input_grids(self,input_file):
        """
        Read structure, cell axes, and density grids from a QMCPACK input.

        Estimators are collected from:
          - the hamiltonian (or the first entry of a hamiltonian collection)
          - every calculation, in sorted series order

        For each estimator name, only the first spin density / density
        instance is used; other instances sharing the name are assumed
        identical.

        Returns
        -------
        structure : the system structure from the input.  Its axes are
                    replaced when a density estimator defines the cell via
                    x/y/z bounds.
        cell      : copy of the structure axes, or a diagonal cell built
                    from x_max/y_max/z_max in that case.
        grids     : obj mapping estimator name to its grid (may be empty).
        """
        if not os.path.exists(input_file):
            self.error('qmcpack input file does not exist: {0}'.format(input_file))
        #end if
        qi = QmcpackInput(input_file)
        s = qi.return_system().structure
        cell = s.axes.copy()
        qi.pluralize()
        # collect every bunch of estimators strewn throughout the input file
        est_sources = []
        ham = qi.get('hamiltonian')
        if ham is not None:
            if 'estimators' in ham:
                est_sources.append(ham.estimators)
            elif len(ham)>0:
                ham = ham.first() # hamiltonian collection
                if 'estimators' in ham:
                    est_sources.append(ham.estimators)
                #end if
            #end if
        #end if
        calcs = qi.get('calculations')
        if calcs is not None:
            for series in sorted(calcs.keys()):
                qmc = calcs[series]
                if 'estimators' in qmc:
                    est_sources.append(qmc.estimators)
                #end if
            #end for
        #end if
        density_bounds = ('x_min','x_max','y_min','y_max','z_min','z_max')
        grids = obj()
        for est in est_sources:
            for name,xml in est.items():
                # Temporary hack: the (new) batched spin density and density
                # classes in qmcpack force these names in stat.h5 regardless
                # of user input, so mirror that here.
                if isinstance(xml,spindensity_new_xml):
                    name = 'SpinDensity' # override all user input, just like qmcpack
                elif isinstance(xml,density_xml):
                    name = 'Density' # override all user input, just like qmcpack
                #end if
                if name in grids:
                    continue # first instance with a given name wins
                #end if
                if isinstance(xml,(spindensity_xml,spindensity_new_xml)):
                    if 'grid' in xml:
                        grids[name] = xml.grid
                    elif 'dr' in xml:
                        grids[name] = get_grid(s,xml.dr)
                    else:
                        self.error('could not identify grid data for spin density named "{0}"\nin QMCPACK input file: {1}'.format(name,input_file))
                    #end if
                elif isinstance(xml,density_xml):
                    if 'grid' in xml:
                        grids[name] = xml.grid
                    elif 'delta' in xml:
                        if len(cell)==0:
                            # no cell in the system: take it from the density bounds
                            if any(b not in xml for b in density_bounds):
                                self.error('could not identify cell data for density named "{0}"\nin QMCPACK input file: {1}'.format(name,input_file))
                            else:
                                grids[name] = get_grid(s,
                                                       delta=xml.delta,
                                                       x_min=xml.x_min,
                                                       x_max=xml.x_max,
                                                       y_min=xml.y_min,
                                                       y_max=xml.y_max,
                                                       z_min=xml.z_min,
                                                       z_max=xml.z_max,
                                                       )
                            #end if
                            cell = np.array([[xml.x_max,0.,0.],
                                             [0.,xml.y_max,0.],
                                             [0.,0.,xml.z_max]])
                            s.axes = cell
                        else:
                            grids[name] = get_grid(s,delta=xml.delta)
                        #end if
                    else:
                        self.error('could not identify grid data for density named "{0}"\nin QMCPACK input file: {1}'.format(name,input_file))
                    #end if
                #end if
            #end for
        #end for
        return s,cell,grids
    #end def _read_input_grids


    def process(self):
        """
        Average (if requested) and analyze every stat.h5 file collected.

        Each batch of files is processed in sorted prefix order:
          - with --average, each multi-file series is replaced by its
            weighted average
          - when --formats or --lineplot is given, every file is analyzed
            and its outputs are written

        Plots are shown at the end unless --noplot is set.
        """
        opt = self.options
        self.vmlog('\nprocessing stat.h5 files')
        batch_files = self._group_stat_files()
        self.vlog('  {0} batches identified'.format(len(batch_files)))
        for batch_prefix in sorted(batch_files.keys()):
            self.vlog('    {0}'.format(batch_prefix))
        #end for

        # loop over file batches
        for batch_prefix in sorted(batch_files.keys()):
            stat_files = batch_files[batch_prefix]
            basepath = os.path.split(batch_prefix)[0]

            nseries = len(stat_files)
            ngroups = len(stat_files.first())

            self.vmlog('\n\nprocessing batch {0}, {1} series, {2} groups'.format(batch_prefix,nseries,ngroups))
            self.vlog(str(stat_files))

            # average files, if requested
            if opt.average:
                self._average_series(batch_prefix,basepath,stat_files)
            #end if

            # read files and create output data
            if opt.formats is not None or opt.lineplot is not None:
                self._analyze_series(stat_files)
            #end if
        #end for
        if not opt.noplot and opt.lineplot is not None:
            plt.show()
        #end if
    #end def process


    def _group_stat_files(self):
        """
        Sort self.file_list into obj[batch_prefix][series][group] = filepath.

        The batch prefix is the file's directory joined with the first
        '.'-separated token of its name.  Tokens of the form 'gNNN' and
        'sNNN' set the group (default 0) and series (default None).
        """
        batch_files = obj()
        for filepath in self.file_list:
            path,filename = os.path.split(filepath)
            tokens = filename.split('.')
            batch_prefix = os.path.join(path,tokens[0])
            if batch_prefix not in batch_files:
                batch_files[batch_prefix] = obj()
            #end if
            stat_files = batch_files[batch_prefix]
            group  = 0
            series = None
            for t in tokens:
                if t.startswith('g') and t[1:].isdigit():
                    group = int(t[1:])
                #end if
                if t.startswith('s') and t[1:].isdigit():
                    series = int(t[1:])
                #end if
            #end for
            if series not in stat_files:
                stat_files[series] = obj()
            #end if
            stat_files[series][group] = filepath
        #end for
        return batch_files
    #end def _group_stat_files


    def _averaging_weights(self,batch_prefix,series,sfiles):
        """
        Return one averaging weight per group file of a series.

        Weights are chosen in this order:
          - user --weights, whose count must match the number of files
          - otherwise, the first number in each
            <batch_prefix>.gNNN.twist_info.dat file
          - uniform weights, when --twist_info=ignore or when some
            twist_info files are missing/unreadable and twist_info is not
            'require' (with 'require', that case is an error)
        """
        opt = self.options
        if opt.weights is not None:
            if len(sfiles)!=len(opt.weights):
                self.error('weights provided do not match number of files in series {0}\nnumber of weights provided: {1}\nnumber of files in series {0}: {2}\nfiles in series {0}:\n{3}'.format(series,len(opt.weights),len(sfiles),sfiles))
            #end if
            return opt.weights
        #end if
        uniform_weights = np.ones((len(sfiles),),dtype=float)
        if opt.twist_info=='ignore':
            return uniform_weights
        #end if
        weights = []
        for group in sorted(sfiles.keys()):
            twist_info_file = '{}.g{}.twist_info.dat'.format(batch_prefix,str(group).zfill(3))
            if os.path.exists(twist_info_file):
                try:
                    with open(twist_info_file) as fobj:
                        weights.append(float(fobj.read().split()[0]))
                    #end with
                except (OSError,ValueError,IndexError):
                    # unreadable/mis-formatted: skipped, caught by the count check
                    pass
                #end try
            #end if
        #end for
        if len(weights)!=len(sfiles):
            if opt.twist_info=='require':
                self.error('twist_info files are either missing or mis-formatted for batch prefix {}'.format(batch_prefix))
            else:
                weights = uniform_weights
            #end if
        #end if
        return weights
    #end def _averaging_weights


    def _average_series(self,batch_prefix,basepath,stat_files):
        """
        Replace the files of each multi-file series with their weighted average.

        The averaged StatFile is written next to the inputs with group token
        'avg'.  The series' entry in stat_files is then reset to contain only
        that file (key 0).  Series with a single file are left untouched.
        """
        self.vmlog('  averaging files')
        for series in sorted(stat_files.keys()):
            self.vmlog('    processing series {0} files'.format(series))
            sfiles = stat_files[series]
            if len(sfiles)<=1:
                continue
            #end if
            weights = self._averaging_weights(batch_prefix,series,sfiles)
            sref_tokens = os.path.split(sfiles.first())[1].split('.')
            avg_filepath = qmcpack_filepath(basepath,sref_tokens,g='avg')
            stat_avg = StatFile()
            use_dens = False
            for n,group in enumerate(sorted(sfiles.keys())):
                w = weights[n]
                self.vlog('      accumulating file with weight {0} {1}'.format(w,sfiles[group]))
                stat_tmp = StatFile(sfiles[group])
                if stat_tmp.use_dens:
                    use_dens = True
                #end if
                stat_avg.accumulate(stat_tmp,w)
                # release each file's data before reading the next
                del stat_tmp
                gc.collect()
            #end for
            stat_avg.use_dens = use_dens
            stat_avg.normalize()
            self.vlog('    writing averaged file: {0}'.format(avg_filepath))
            stat_avg.write(avg_filepath)
            # overwrite stat files to operate on
            sfiles.clear()
            sfiles[0] = avg_filepath
        #end for
    #end def _average_series


    def _analyze_series(self,stat_files):
        """
        Analyze every file of every series and write the requested outputs.

        The equilibration/reblock settings are either a single int applied
        to all series or an obj indexed by series number.
        """
        opt = self.options
        for series in sorted(stat_files.keys()):
            self.vmlog('  processing series {0} files'.format(series))
            sfiles = stat_files[series]
            for filepath in sfiles:
                self.vlog('    processing file {0}'.format(filepath))
                stat = StatFile(filepath)
                if isinstance(opt.equilibration,int):
                    eq = opt.equilibration
                else:
                    eq = opt.equilibration[series]
                #end if
                if opt.reblock is None or isinstance(opt.reblock,int):
                    rb = opt.reblock
                else:
                    rb = opt.reblock[series]
                #end if
                stat.analyze(eq,rb)
                if opt.formats is not None:
                    stat.write_output_files(opt.formats)
                #end if
                if opt.lineplot is not None:
                    stat.line_plot(opt.lineplot)
                #end if
            #end for
        #end for
    #end def _analyze_series
#end class QMCDensityProcessor


if __name__=='__main__':
    # Script entry point.  Runs three stages in order:
    #   read_command_line: parse and validate options and stat.h5 files
    #   process:           average and analyze the files
    #   exit:              shut down the processor
    qdens = QMCDensityProcessor()
    for stage in (qdens.read_command_line,qdens.process,qdens.exit):
        stage()
    #end for
#end if



