#!/usr/local/bin/python3.9

# Python standard library imports
import os
import sys
from optparse import OptionParser

# Non-standard Python library imports
import numpy as np



def find_nexus_modules():
    """Append the Nexus library directory located relative to this file to sys.path.

    The directory is resolved as ``lib`` two path components above this
    file's path, i.e. the ``lib`` directory that sits beside the directory
    containing this script.  An AssertionError is raised when that directory
    does not exist.
    """
    lib_dir = os.path.join(__file__,os.pardir,os.pardir,'lib')
    lib_dir = os.path.abspath(lib_dir)
    # Sanity check on the expected layout before touching the import path
    assert os.path.exists(lib_dir)
    sys.path.append(lib_dir)
#end def find_nexus_modules


def import_nexus_module(module_name):
    """Import the module named `module_name` and return the module object."""
    from importlib import import_module
    module = import_module(module_name)
    return module
#end def import_nexus_module


# Load Nexus modules
try:
    # Attempt specialized path-based imports relative to this file.
    #  (The executable should still work even if Nexus is not installed)
    find_nexus_modules()

    # Each module is imported by name, the needed symbols are bound at
    # module level, and the temporary module reference is dropped again.
    versions = import_nexus_module('versions')
    nexus_version = versions.nexus_version
    del versions

    generic = import_nexus_module('generic')
    obj = generic.obj
    del generic

    developer = import_nexus_module('developer')
    log         = developer.log
    warn        = developer.warn
    error       = developer.error
    ci          = developer.ci
    del developer

    hdfreader = import_nexus_module('hdfreader')
    read_hdf = hdfreader.read_hdf
    del hdfreader
except Exception:
    # The path-based lookup failed (e.g. the lib directory is missing or a
    # module could not be imported): fall back to whatever is importable
    # from the regular module search path.  Only ordinary exceptions
    # trigger the fallback, so Ctrl-C and interpreter exit requests are not
    # swallowed here.
    from versions import nexus_version
    from generic import obj
    from developer import log,warn,error,ci
    from hdfreader import read_hdf
#end try



# Module-level option container: read by vlog and filled with the parsed
# command line options by interpret_nk_options.
opt = obj(verbose=False)

def vlog(*args,**kwargs):
    """Forward all arguments to log, but only when opt.verbose is set."""
    if not opt.verbose:
        return
    #end if
    log(*args,**kwargs)
#end def vlog


def user_error(msg):
    """Pass `msg` to error under the 'User' label with trace disabled.

    NOTE(review): callers in this file carry on as if this call does not
    return, so error() presumably terminates the program -- confirm against
    the developer module.
    """
    label = 'User'
    error(msg,label,trace=False)
#end def user_error



def read_eshdf_nofk_data(filename,Ef,ncore=0,nval=-1):
    """Gather kinetic energies and n(k) of the occupied orbitals in an ESHDF file.

    For every k-point and spin, the states inside the selected index window
    whose eigenvalue (converted from Ha to eV) lies below the Fermi energy
    are treated as occupied.  For each of them the squared plane wave
    coefficients are accumulated into n(k) and weighted by |k|^2/2 to give
    the orbital kinetic energy.

    Parameters
    ----------
    filename : path of the ESHDF file, handed to read_hdf.
    Ef       : Fermi energy in eV.  A small offset (1e-8) is added before the
               comparison with the eigenvalues.
    ncore    : index of the first state that is considered.
    nval     : when positive, the number of states considered starting at
               ncore; otherwise all states up to number_of_states are used.

    Returns
    -------
    obj with fields orbfile, E_fermi (including the offset), axes, kaxes,
    nkpoints, nspins and data.  data[k,s] holds, for k-point k and spin s:
      kpoint : reduced k-point coordinates read from the file
      kin    : kinetic energy of each occupied orbital (Hartree)
      eig    : eigenvalue of each occupied orbital (eV)
      k      : plane wave vectors shifted by the k-point (shared by all
               spins of the same k-point)
      nk     : squared coefficients summed over the occupied orbitals
      ne     : sum of nk, i.e. the summed norm of the occupied orbitals
    """
    from numpy import array,pi,dot,zeros
    from numpy.linalg import inv
    from unit_converter import convert

    def h5int(i):
        # scalars are read from the file as length-1 arrays
        return array(i,dtype=int)[0]
    #end def h5int

    E_fermi = Ef + 1e-8
    h       = read_hdf(filename,view=True)
    # The g-vector list of the first k-point is used for all k-points
    gvu     = array(h.electrons.kpoint_0.gvectors)
    axes    = array(h.supercell.primitive_vectors)
    kaxes   = 2*pi*inv(axes).T
    gv      = dot(gvu,kaxes)
    nk      = h5int(h.electrons.number_of_kpoints)
    ns      = h5int(h.electrons.number_of_spins)
    data    = obj()
    for k in range(nk):
        kp        = h.electrons['kpoint_'+str(k)]
        reduced_k = array(kp.reduced_k)
        # plane wave vectors shifted by this k-point
        gvk       = gv + dot(reduced_k,kaxes)
        kinetic   = (gvk**2).sum(1)/2 # Hartree units
        for s in range(ns):
            path = 'electrons/kpoint_{0}/spin_{1}'.format(k,s)
            spin = h.get_path(path)
            eig  = convert(array(spin.eigenvalues),'Ha','eV')
            nst  = h5int(spin.number_of_states)
            # NOTE(review): ncore+nval is not checked against nst, so an
            # oversized window fails when indexing eig below -- confirm
            # whether that should be reported as a user error instead.
            nlast = ncore + nval if nval > 0 else nst
            kin_s   = []
            eig_s   = []
            nk_s    = zeros(gv.shape[0],dtype=gv.dtype)
            nelec_s = 0
            for st in range(ncore, nlast):
                e = eig[st]
                if e<E_fermi:
                    stpath = path+'/state_{0}/psi_g'.format(st)
                    psi    = array(h.get_path(stpath))
                    # presumably axis 1 of psi_g holds the real and
                    # imaginary parts of each coefficient -- TODO confirm
                    nk_orb = (psi**2).sum(1)
                    nelec_s += nk_orb.sum()
                    nk_s    += nk_orb
                    kin_s.append((kinetic*nk_orb).sum())
                    eig_s.append(e)
                #end if
            #end for
            data[k,s] = obj(
                kpoint = reduced_k.copy(),
                kin    = array(kin_s),
                eig    = array(eig_s),
                k      = gvk,
                nk     = nk_s,
                ne     = nelec_s,
                )
        #end for
    #end for
    res = obj(
        orbfile  = filename,
        E_fermi  = E_fermi,
        axes     = axes,
        kaxes    = kaxes,
        nkpoints = nk,
        nspins   = ns,
        data     = data,
        )
    return res
#end def read_eshdf_nofk_data


def first_eshdf_file(args):
    """Return the single ESHDF file path contained in `args`.

    The first entry of `args` is skipped; exactly one further entry is
    expected.  A user error is reported when the number of remaining entries
    differs from one, when the path does not exist, or when it does not end
    in '.h5'.
    """
    candidates = sorted(args[1:])
    if len(candidates)!=1:
        user_error('exactly one ESHDF file is allowed as input.\nYou provided: {}'.format(candidates))
    #end if
    eshdf_filepath = candidates[0]
    found = os.path.exists(eshdf_filepath)
    if not found:
        user_error('ESHDF file does not exist.\nPlease check the path provided:\n  {}'.format(eshdf_filepath))
    elif not eshdf_filepath.endswith('.h5'):
        user_error('file provided is not an HDF5 file.\nAn ESHDF file must have a .h5 extension.\nPlease check the path provided:\n  {}'.format(eshdf_filepath))
    #end if
    return eshdf_filepath
#end def first_eshdf_file


def interpret_nk_options(parser):
    """Parse the command line with `parser` and normalize the shared options.

    The parsed option values are transferred into the module-level ``opt``
    container and every value equal to the string 'None' is replaced by
    None.  The help text is printed and the program exits when --help was
    given or when nothing follows the first positional argument.  The Fermi
    energy is mandatory and is converted to float; ncore and nval are
    converted to int.

    Parameters
    ----------
    parser : option parser providing at least the help, E_fermi, ncore and
             nval destinations.

    Returns
    -------
    (opt, args) : the module-level option container and the positional
                  arguments returned by the parser.
    """
    options,args = parser.parse_args()
    opt.transfer_from(options.__dict__)
    for k,v in opt.items():
        if v=='None':
            opt[k] = None
        #end if
    #end for
    eshdf_files = sorted(args[1:])
    if opt.help or len(eshdf_files)==0:
        print('\n'+parser.format_help().strip()+'\n')
        # sys.exit is used rather than the exit() helper, which only exists
        # when the site module has been loaded
        sys.exit()
    #end if
    # check Fermi energy
    if opt.E_fermi is None:
        user_error('please provide the Fermi energy via the --Ef option.')
    else:
        try:
            opt.E_fermi = float(opt.E_fermi)
        except (TypeError,ValueError):
            user_error('value provided for Fermi energy is not a real number.\nYou provided: {}'.format(opt.E_fermi))
        #end try
    #end if
    # set active space
    opt.ncore = int(opt.ncore)
    opt.nval = int(opt.nval)
    return opt, args
#end def interpret_nk_options


def kinetic():
    """Report kinetic energies of the occupied orbitals of an ESHDF file.

    Parses the command line of the ``kinetic`` operation, reads the occupied
    orbital data with read_eshdf_nofk_data and logs, per spin, the number of
    occupied orbitals, the summed orbital norm and the kinetic energy, along
    with the total kinetic energy.  A warning is issued when the summed norm
    of any spin deviates from its orbital count by more than --norm_tol.
    With -o/--orb the orbitals of each spin are also listed in order of
    increasing eigenvalue.
    """
    # read command line inputs
    usage = '''usage: %prog kinetic [options] [eshdf_file]'''
    parser = OptionParser(usage=usage,add_help_option=False,version='%prog {}.{}.{}'.format(*nexus_version))

    parser.add_option('-h','--help',dest='help',
                      action='store_true',default=False,
                      help='Print help information and exit (default=%default).'
                      )
    #parser.add_option('-v','--verbose',dest='verbose',
    #                  action='store_true',default=False,
    #                  help='Print detailed information (default=%default).'
    #                  )
    parser.add_option('--Ef',dest='E_fermi',
                      default='None',
                      help='Fermi energy in eV (default=%default).'
                      )
    parser.add_option('--norm_tol',dest='norm_tol',
                      action='store',type='float',default=1e-6,
                      help='Tolerance for the deviation of the summed orbital norm per spin from the number of orbitals (default=%default).'
                      )
    parser.add_option('-o','--orb',dest='orb_info',
                      action='store_true',default=False,
                      help='Print per orbital kinetic energies (default=%default).'
                      )
    parser.add_option('--ncore',dest='ncore',default=0,
                      help='number of core states to exclude (default=%default).'
                      )
    parser.add_option('--nval',dest='nval',default=-1,
                      help='number of valence states to include (default=%default, i.e. all).'
                      )


    opt, args = interpret_nk_options(parser)
    eshdf_filepath = first_eshdf_file(args)

    d = read_eshdf_nofk_data(eshdf_filepath,opt.E_fermi,ncore=opt.ncore,nval=opt.nval)

    nkpoints = d.nkpoints
    nspins   = d.nspins

    # Per spin totals and per spin lists of orbital data
    ktot  = 0.0
    nspin = np.zeros((nspins,),dtype=int)
    kspin = np.zeros((nspins,),dtype=float)
    normspin = np.zeros((nspins,),dtype=float)
    norb  = []
    korb  = []
    eorb  = []
    kpidx= []
    for s in range(nspins):
        # Flatten the orbitals of all k-points for this spin
        nsorb = []
        ksorb = []
        esorb = []
        kpsidx = []
        for k in range(nkpoints):
            dks = d.data[k,s]
            nsorb.append(len(dks.kin))
            ksorb.extend(dks.kin)
            esorb.extend(dks.eig)
            kpsidx.extend([k]*len(dks.kin))
            normspin[s] += dks.ne
        #end for
        nsorb = np.array(nsorb,dtype=int)
        ksorb = np.array(ksorb,dtype=float)
        esorb = np.array(esorb,dtype=float)
        kpsidx = np.array(kpsidx,dtype=int)

        # Sort the orbitals by eigenvalue, keeping the k-point index and
        # kinetic energy aligned with it
        order = esorb.argsort()
        esorb = esorb[order]
        ksorb = ksorb[order]
        kpsidx = kpsidx[order]

        ns = nsorb.sum()
        ks = ksorb.sum()

        ktot    += ks
        nspin[s] = ns
        kspin[s] = ks
        norb.append(nsorb)
        korb.append(ksorb)
        eorb.append(esorb)
        kpidx.append(kpsidx)
    #end for

    def arr2str(a):
        # array contents without the enclosing brackets
        return '{}'.format(a).strip().strip('[]')
    #end def arr2str

    log('\nNumber of spins              : {}'.format(nspins))
    log('Number kpoints               : {}'.format(nkpoints))
    log('Number of electrons per spin : {}'.format(arr2str(nspin)))
    log('Summed orbital norm per spin : {}'.format(arr2str(normspin)))
    log('Total kinetic energy         : {} Ha'.format(ktot))
    log('Kinetic energy per spin      : {} Ha'.format(arr2str(kspin)))

    # The summed norm of each spin should match its number of orbitals
    norm_diff = np.abs(normspin-nspin).max()
    if norm_diff>opt.norm_tol:
        warn('Orbitals are not properly normalized!\nError in the summed norm is: {}\nThis exceeds a tolerance of: {}'.format(norm_diff,opt.norm_tol))
    #end if

    spin_labels = {0:'up',1:'down'}
    if opt.orb_info:
        log('\nPer orbital kinetic energies')
        for s in range(nspins):
            log('  Spin {} energies'.format(spin_labels[s]))
            log('    index kpoint_index  KS eig (eV)  kinetic (Ha)')
            for i,(kpi,eig,kin) in enumerate(zip(kpidx[s],eorb[s],korb[s])):
                log('    {:>3}      {:>3}        {: 10.6f}   {: 10.6f}'.format(i,kpi,eig,kin))
            #end for
        #end for
    #end if
#end def kinetic


def write_nk():
    """Write the momentum distribution of one spin channel to an HDF5 file.

    Parses the command line of the ``write_nk`` operation, reads the occupied
    orbital data with read_eshdf_nofk_data and, for the spin channel selected
    by --ispin, stacks the shifted plane wave vectors of all k-points next
    to the accumulated n(k) values.  The resulting array (one row per plane
    wave: three vector components followed by n(k)) is stored as the
    compressed array ``data`` at the root of the file named by --outfile.
    """
    # read command line inputs
    usage = '''usage: %prog write_nk [options] [eshdf_file]'''
    parser = OptionParser(usage=usage,add_help_option=False,version='%prog {}.{}.{}'.format(*nexus_version))
    parser.add_option('-h','--help',dest='help',
                      action='store_true',default=False,
                      help='Print help information and exit (default=%default).'
                      )
    parser.add_option('--Ef',dest='E_fermi',
                      default='None',
                      help='Fermi energy in eV (default=%default).'
                      )
    parser.add_option('--ncore',dest='ncore',default=0,
                      help='number of core states to exclude (default=%default).'
                      )
    parser.add_option('--nval',dest='nval',default=-1,
                      help='number of valence states to include (default=%default, i.e. all).'
                      )
    parser.add_option('--ispin',dest='ispin',default=0,
                      help='spin channel (default=%default).'
                      )
    parser.add_option('--outfile',dest='outfile',
                      default=None,
                      help='Output 3D momentum distribution'
                      )
    opt, args = interpret_nk_options(parser)
    eshdf_filepath = first_eshdf_file(args)
    # Fail early with a clear message instead of handing None to PyTables
    if opt.outfile is None:
        user_error('please provide the output file via the --outfile option.')
    #end if

    d = read_eshdf_nofk_data(eshdf_filepath,opt.E_fermi,ncore=opt.ncore,nval=opt.nval)

    ispin0 = int(opt.ispin)
    log('Writing n(k) to HDF5 {}'.format(opt.outfile))
    import tables
    klist = []
    nklist = []
    for key, val in d.data.items():
        ik, ispin = key
        if ispin != ispin0: continue
        klist.append(val.k)
        nklist.append(val.nk)
    #end for
    # np.concatenate cannot handle an empty selection
    if len(klist)==0:
        user_error('no data found for spin channel {}.\nThe ESHDF file contains {} spin channel(s).'.format(ispin0,d.nspins))
    #end if
    kvecs = np.concatenate(klist, axis=0)
    nk = np.concatenate(nklist, axis=0)
    data = np.c_[kvecs, nk]
    # The keyword must be "filters"; any other name is not interpreted as
    # the compression settings and the array is written uncompressed.
    filters = tables.Filters(complevel=5, complib='zlib')
    # The context manager closes the file even if writing fails
    with tables.open_file(opt.outfile, mode='w', filters=filters) as fp:
        atom = tables.Atom.from_dtype(data.dtype)
        ca = fp.create_carray(fp.root, 'data', atom, data.shape)
        ca[:] = data
    #end with
#end def write_nk



# Dispatch table mapping the operation name given on the command line to the
# function that implements it.
operations = obj(kinetic=kinetic, write_nk=write_nk)



if __name__=='__main__':

    # The first command line argument selects the operation to run; the
    # selected function parses the remaining command line itself.
    argv = sys.argv
    if len(argv)<2:
        user_error('First argument must be operation to perform on ESHDF file.\ne.g. to examine kinetic energies, type "eshdf kinetic ..."\nValid operations are: {0}'.format(sorted(operations.keys())))
    #end if
    op_type = argv[1]
    if op_type not in operations:
        user_error('Unknown operation: {0}\nValid options are: {1}'.format(op_type,sorted(operations.keys())))
    else:
        run_operation = operations[op_type]
        run_operation()
    #end if

#end if
