#!/usr/bin/env python
#===========================================================================#
#                                                                           #
#  File:       post_epw.py                                                  #
#  Usage:      post-processing EPW data for plotting and analysis           #
#  Author:     Shunhong Zhang <szhang2@ustc.edu.cn>                         #
#                                                                           #
#===========================================================================#


from elphtk.epw_tool import *
from elphtk.epw_binary import __note__,__doc__
from elphtk.epw_binary import *


def verbose_grid_params(grid_params):
    """
    Print a labelled table of the six k-grid parameters.

    Parameters
    ----------
    grid_params : sequence of int
        Exactly six integers, in the order
        (nkftot, nkf1, nkf2, nkf3, nkfs, nbndfs).
    """
    labels = ('nkftot', 'nkf1', 'nkf2', 'nkf3', 'nkfs', 'nbndfs')
    # unpack up front so a wrongly-sized input fails before anything is printed
    nkftot, nkf1, nkf2, nkf3, nkfs, nbndfs = grid_params
    values = (nkftot, nkf1, nkf2, nkf3, nkfs, nbndfs)
    row = '{:>8s} = {:5d}'.format
    print('\nGriding-related Parameters')
    for label, value in zip(labels, values):
        print(row(label, value))
    print('\n')
 

def read_elph(outdir='.',prefix='epw',seedname='wannier'):
    """
    Serial driver that reads the EPW outputs used for electron-phonon analysis.

    Calls, in order: read_freq (result passed to write_freq),
    get_ephmat_lines, read_ikmap, read_egnv, verbose_grid_params,
    mpi_read_nks, calc_kqmap, calc_nnq and finally read_ephmat.
    Nothing is returned.

    Parameters
    ----------
    outdir : str
        Directory passed to the readers that take an ``outdir`` keyword.
    prefix : str
        EPW prefix forwarded to the prefix-aware readers.
    seedname : str
        Wannier90 seedname forwarded to every reader.

    Notes
    -----
    The fine q-mesh dimensions come from the module-level ``nqf1, nqf2, nqf3``.
    """
    loc = dict(outdir=outdir, seedname=seedname)

    # frequency array shape is unpacked as (nqtotf, nmodes)
    qpoints, freqs = read_freq(**loc)
    nqtotf, _nmodes = freqs.shape
    write_freq(qpoints, freqs)

    # NOTE(review): the returned line count is unused on the serial path;
    # the call is kept so the original call sequence is preserved.
    get_ephmat_lines(outdir=outdir, prefix=prefix)
    ixkf = read_ikmap(prefix=prefix, seedname=seedname)

    grid_params, en_params, _wkfs, xkfs, ekfs = read_egnv(seedname=seedname)
    _nkftot, nkf1, nkf2, nkf3, nkfs, nbndfs = grid_params
    _ef, ef0, _dosef, _degaussw, fsthick = en_params
    verbose_grid_params(grid_params)

    nks = mpi_read_nks(prefix=prefix, seedname=seedname)
    ixkqf, nqfs = calc_kqmap(nkfs, xkfs, nqtotf, nqf1, nqf2, nqf3,
                             nkf1, nkf2, nkf3, ixkf)
    # NOTE(review): nnq is computed but not handed to read_ephmat here.
    # Optional diagnostics used to sit here: plot_nkfs(nkf1,nkf2,nkf3,ixkf,ekfs)
    calc_nnq(ixkqf)

    read_ephmat(nks, ef0, fsthick, ekfs, ixkqf, nqfs, nbndfs,
                prefix=prefix, **loc)


def mpi_read_elph(outdir='.',prefix='epw',seedname='wannier',ipool=1):
    """
    MPI driver: read EPW frequencies, k-point maps, eigenvalues and
    electron-phonon data across the ranks returned by get_mpi_handles().

    Parameters
    ----------
    outdir : str
        Directory passed to the readers that take an ``outdir`` keyword.
    prefix : str
        EPW prefix forwarded to the prefix-aware readers.
    seedname : str
        Wannier90 seedname forwarded to every reader.
    ipool : int
        Pool index handed to read_ephmat_one_pool (default 1).

    Notes
    -----
    The fine q-mesh dimensions come from the module-level ``nqf1, nqf2, nqf3``.
    Nothing is returned.
    """
    comm,size,rank,node = get_mpi_handles()
    nlines = get_ephmat_lines(outdir=outdir,prefix=prefix)
    qpts,wf = read_freq(outdir=outdir,seedname=seedname,prefix=prefix)
    # frequency array shape is unpacked as (nqtotf, nmodes)
    nqtotf,nmodes=wf.shape

    # only rank 0 calls read_ikmap; comm.bcast(..., root=0) then hands rank 0's
    # value to every rank (assumes an mpi4py-style communicator -- TODO confirm)
    ixkf=None
    if not rank: ixkf = read_ikmap(prefix=prefix,seedname=seedname)
    ixkf=comm.bcast(ixkf,root=0)

    grid_params,en_params,wkfs,xkfs,ekfs = mpi_read_egnv(seedname=seedname)
    nkftot,nkf1,nkf2,nkf3,nkfs,nbndfs = grid_params
    ef,ef0,dosef,degaussw,fsthick = en_params

    nks=mpi_read_nks(outdir=outdir,seedname=seedname,prefix=prefix)
    ixkqf,nqfs = calc_kqmap(nkfs,xkfs,nqtotf,nqf1,nqf2,nqf3,nkf1,nkf2,nkf3,ixkf)
    nnq = calc_nnq(ixkqf)
    read_ephmat_one_pool(ipool,nks,ef0,fsthick,ekfs,ixkqf,nqfs,nbndfs,nnq,nlines,
    outdir=outdir,prefix=prefix,seedname=seedname)
    # NOTE(review): rank 0 calls calc_nnq a second time and discards the
    # result (nnq was already computed above); plot_nkfs is rank-0 only.
    if not rank: 
        calc_nnq(ixkqf)
        plot_nkfs(nkf1,nkf2,nkf3,ixkf,ekfs)
    # NOTE(review): judging by the names, matrix elements are read both by
    # read_ephmat_one_pool above and by mpi_read_ephmat here -- confirm both
    # passes are intended.
    mpi_read_ephmat(nks,ef0,fsthick,ekfs,ixkqf,nqfs,nbndfs,
    outdir=outdir,prefix=prefix,seedname=seedname)


def get_args(desc_str, argv=None):
    """
    Build the command-line parser for the EPW post-processing tool and parse it.

    Parameters
    ----------
    desc_str : str
        Description shown in the ``--help`` output.
    argv : list of str, optional
        Argument strings to parse. ``None`` (the default) parses the
        process command line, exactly as before. Passing a list lets other
        scripts and tests drive the parser.

    Returns
    -------
    argparse.Namespace
        The parsed options, including the shared groups from elphtk.arguments.
    """
    import argparse
    from elphtk import arguments
    parser = argparse.ArgumentParser(prog='epw_tools', description=desc_str)
    # shared option groups defined in elphtk.arguments
    arguments.add_control_arguments(parser)
    arguments.add_io_arguments(parser)
    arguments.add_wan_arguments(parser)
    arguments.add_fig_arguments(parser)
    arguments.add_plot_arguments(parser)
    parser.add_argument('--prefix',type=str,default=None,help='prefix for epw')
    parser.add_argument('--seedname',type=str,default='ZrN',help='seedname for wannier90 files')
    # type=bool is a trap: bool('False') is True, so any explicit value turned
    # iso on. Use the same str2bool convention as --sum_modes; a bare --iso
    # now means True and the default stays False.
    parser.add_argument('--iso',type=arguments.str2bool,nargs='?',const=True,default=False,help='iso a2f or not')
    parser.add_argument('--gtype',type=str,default='pade',help='type of superconducting gap')
    parser.add_argument('--ismear',type=int,default=0,help='index of smearing, 0 to 8')
    parser.add_argument('--mesh_style',type=str,default='interp',help='style to show 2D mesh')
    parser.add_argument('--sum_modes',type=arguments.str2bool,nargs='?',const=True,default=True,help='sum lambda for all bands')
    parser.add_argument('--freq_unit',type=str,default='meV',help='unit of freq')
    parser.add_argument('--max_freq',type=float,default=0,help='max freq plotted')
    parser.add_argument('--cmap',type=str,default='jet',help='color map for slice plot')
    parser.add_argument('--imode',type=int,default=0,help='mode index specified for plot')
    parser.add_argument('--iband',type=int,default=0,help='band index specified for plot')
    parser.add_argument('--kmode',type=str,default='mesh',help='mode of kpts, mesh or band')
    return parser.parse_args(argv)

def print_task_list(task_list):
    """Print every entry of *task_list*, one per line, framed by rule lines."""
    top = '=' * 20
    header = '\n' + top + '\nlist of tasks\n' + '-' * 20
    print(header)
    print('\n'.join('{}'.format(task) for task in task_list))
    print(top + '\n')


def dryrun():
    """Tell the user that no task was selected and how to pick one."""
    hint = 'use --task to specify your task'
    print(hint)


# Task names accepted by --task; printed to the user by print_task_list.
# None (no --task given) makes main() fall back to dryrun().
# NOTE(review): 'kgmap' and 'plot_nest_fn' have no branch in main() and
# currently reach its final else branch.
task_list=[
'a2f',
'wan_band',
'gap',
'qdos',
'qpr',
'phband',
'phdos',
'fermi_surface',
'phmesh',
'lambda_FS',
'elself',
'phself',
'lambda.phself',
'specfun',
'specfun_sup',
'kgmap',
'plot_nest_fn',
'elph',
'mpi_elph',
'freq',
'kqmap_demo',
'engv',  # handled by main() but previously missing from this listing
None,
]


def main(kws):
    """
    Print the banner and task list, then run the task selected by ``--task``.

    Parameters
    ----------
    kws : dict
        Keyword arguments (outdir, prefix, seedname) forwarded to the
        'elph', 'mpi_elph' and 'freq' tasks.

    Notes
    -----
    The parsed command line ``args`` and the banner text ``desc_str`` are
    read from module scope rather than passed in.

    Raises
    ------
    SystemExit
        If ``args.task`` does not match any branch below.
    """
    import sys

    print ('{0}\n'.format(desc_str))
    print_task_list(task_list)
    print ('task=',args.task)
    task = args.task
    if task is None:                 dryrun()
    elif task == 'a2f':              plot_a2f(args)
    elif task == 'wan_band':         plot_wan_band(args)
    elif task == 'gap':              plot_gap(args)
    elif task == 'qdos':             plot_qdos(args)
    elif task == 'qpr':              plot_qpr(args)
    elif task == 'phband':           plot_phband(args)
    elif task == 'fermi_surface':    plot_fermi_surface(args)
    elif task == 'phmesh':           plot_phmesh(args,ikz=0)
    elif task == 'lambda_FS':        plot_lambda_FS(args)
    elif task == 'elself':           plot_elself(args)
    elif task == 'phself':           plot_phself(args)
    elif task == 'lambda.phself':
        # NOTE(review): this is an attribute lookup on a name ``plot_lambda``;
        # unless the star imports provide such an object the branch raises
        # NameError. A function such as ``plot_lambda_phself`` was probably
        # intended -- confirm against elphtk.
        plot_lambda.phself(args)
    elif task == 'specfun':          plot_specfun(args)
    elif task == 'specfun_sup':      plot_specfun_sup(args)
    elif task == 'phdos':            plot_phdos(args)
    elif task == 'kqmap_demo':       test_kpmq()
    elif task == 'elph':             read_elph(**kws)
    elif task == 'mpi_elph':         mpi_read_elph(**kws)
    elif task == 'engv':
        grid_params,en_params,wkfs,xkfs,ekfs = read_egnv(outdir=args.outdir,seedname=args.seedname)
        write_grid_params(grid_params)
        write_en_params(en_params)
    elif task == 'freq':
        # NOTE(review): the read_freq_scipy result is overwritten on the next
        # line; confirm whether this call is still needed.
        qpts,freqs = read_freq_scipy(**kws)
        qpts,freqs = read_freq(**kws)
        write_freq(qpts,freqs)
    else:
        # sys.exit instead of the site-provided exit(), which is meant for the
        # interactive interpreter and is absent under ``python -S``.
        sys.exit('Unavailable task {} specified!'.format(task))



# Fine q-mesh dimensions read as globals by read_elph / mpi_read_elph
# (via calc_kqmap); they are hard-coded and echoed at start-up so the user
# knows to edit them for a different mesh.
nqf1, nqf2, nqf3 = 12, 12, 1
desc_str = 'post-processing for EPW'
# NOTE(review): get_args parses the process command line here, i.e. at import
# time, so importing this module from another script also parses that
# script's arguments.
args = get_args(desc_str)

# reader keywords forwarded by main() to the elph / mpi_elph / freq tasks
kws = {
    'outdir': args.outdir,
    'prefix': args.prefix,
    'seedname': args.seedname,
}


if __name__ == '__main__':
    from elphtk import __version__
    from elphtk.pkg_info import verbose_pkg_info
    verbose_pkg_info(__version__)
    print(__note__)
    print('nqf1,nqf2,nqf3 = {},{},{}'.format(nqf1, nqf2, nqf3))
    print('Change them if necessary')
    main(kws)
