#!/usr/bin/env python

#======================================
#
# Compute the superconducting
# transition temperature Tc
# using the Allen-Dynes modified
# McMillan equation
#
# Shunhong Zhang
# shzang2@ustc.edu.cn
# Date: Mar 15, 2020
#
#======================================

import os
import numpy as np
from elphtk.arguments import *
import pickle
import glob
import re


def write_lambda_in(fil='lambda.in',mustar=0.1,filph='ph.out',elph_method='interpolated'):
    """Write a lambda.x-style input file (the layout read back by parse_input).

    Layout written:
        line 1  : "10  0.12  1" -> emax (freq cutoff in THz), degauss, smearing method
        line 2  : number of q-points nq
        nq lines: three q-point coordinates followed by the q weight
        then    : the per-q electron-phonon file names (see elph_method)
        last    : mustar

    Args:
        fil: name of the file to write.
        mustar: screened Coulomb parameter written on the last line.
        filph: ph.x output file passed to qe_phonon.get_qpt_from_phout to
            obtain the q-points and their weights.
        elph_method: 'interpolated' writes 'elph_dir/elph.inp_lambda.<iq>'
            for every q-point; 'lambda_tetra' or 'gamma_tetra' writes every
            file matching '*dyn*.elph*' whose name ends with the q index <iq>.
            Any other value writes no file names.
    """
    elph_dir='elph_dir'
    elph_prefix='elph.inp_lambda'
    is_tetra=elph_method in ('lambda_tetra','gamma_tetra')
    # Candidate per-q files for the tetrahedron methods, globbed once
    # rather than once per q-point.
    tetra_fils=glob.glob('*dyn*.elph*') if is_tetra else []
    # NOTE(review): qe_phonon is not defined in this file; presumably it is
    # provided by `from elphtk.arguments import *` -- confirm.
    qpts,qweights=qe_phonon.get_qpt_from_phout(filph=filph)
    nq=qpts.shape[0]
    with open(fil,'w') as fw:
        fw.write('10  0.12  1    ! emax (freq cutoff in THz), degauss, smearing method\n')
        fw.write('{0}\n'.format(nq))
        for iq,q1 in enumerate(qpts):
            fw.write(('{:12.8f}'*3).format(*tuple(q1)))
            fw.write('{:10.4f}\n'.format(qweights[iq]))
        for iqpt in range(nq):
            if elph_method=='interpolated':
                fw.write('{0}/{1}.{2}\n'.format(elph_dir,elph_prefix,iqpt+1))
            elif is_tetra:
                for elph_fil in tetra_fils:
                    # The trailing digits are the q index; this accepts both
                    # '...elph.<iq>' and '...elph<iq>' file names.
                    match=re.search(r'(\d+)$',elph_fil)
                    if match and int(match.group(1))==iqpt+1:
                        fw.write(elph_fil+'\n')
        # Escaped backslash: writes a literal "\mu" without an invalid escape sequence.
        fw.write('{:5.2f}  ! \\mu: screened Coloumb coefficient\n'.format(mustar))


def parse_input(fil='lambda.in'):
    """Read the lambda.x-style input file written by write_lambda_in.

    Args:
        fil: path of the input file.

    Returns:
        (emax, degaussq, ngaussq, qweight, mustar). The first three come from
        the first line (ngaussq cast to int), qweight is the last column of
        each q-point line normalised to sum to 1, and mustar is the first
        token of the line following the per-q file names.

    Exits via exit() when the file does not exist.
    """
    if not os.path.isfile(fil):
        exit('cannot find {0}'.format(fil))
    with open(fil) as handle:
        header=handle.readline().split()
        emax,degaussq,ngaussq=np.array(header[:3],float)
        ngaussq=int(ngaussq)
        nqpt=int(handle.readline())
        raw_weights=[]
        for _ in range(nqpt):
            # the weight is the last column after the three q coordinates
            raw_weights.append(handle.readline().split()[-1])
        qweight=np.array(raw_weights,float)
        qweight=qweight/qweight.sum()
        # skip the per-q electron-phonon file names
        for _ in range(nqpt):
            handle.readline()
        mustar=float(handle.readline().split()[0])
    return emax,degaussq,ngaussq,qweight,mustar


def get_elph_single_q(iqpt=1,elph_dir='elph_dir',elph_method='interpolated'):
    """Read the electron-phonon data file of one q-point.

    File layout expected (as parsed below):
        header line ending with "<nsig> <nmode>"
        nmode w2 values (may span several lines)
        then, for each of the nsig broadenings:
            "... Broadening: <degauss> Ry, ..."     (3rd token is degauss)
            "DOS = <N(Ef)> ... at Ef= <Ef> eV"
            nmode lines "lambda(  nu)= <lambda>   gamma= <gamma> GHz"

    Args:
        iqpt: 1-based q-point index used in the file name.
        elph_dir: directory of the 'elph.inp_lambda.<iqpt>' files
            (elph_method='interpolated').
        elph_method: 'interpolated', 'lambda_tetra' or 'gamma_tetra'; the
            tetrahedron methods read '*dyn<iqpt>.elph<iqpt>' from the cwd.

    Returns:
        freq (nmode,): sign(w2)*sqrt(|w2|)*3289.828 (Ry -> THz conversion).
        dosef, ef (nsig,): N(Ef) and Ef from each DOS line.
        lambdaq, gammaq (nsig, nmode): mode-resolved lambda and gamma.
        degauss (nsig,): broadenings as printed in the file.

    Exits via exit() for an unknown elph_method, when exactly one matching
    file is not found, or when the w2 block is truncated.
    """
    if elph_method=='interpolated':
        elph_prefix='elph.inp_lambda'
        fil=glob.glob('{0}/{1}.{2}'.format(elph_dir,elph_prefix,iqpt))
    elif elph_method=='lambda_tetra' or elph_method=='gamma_tetra':
        fil=glob.glob('*dyn{0}.elph{0}'.format(iqpt))
    else:
        exit('unknown elph_method: {0}'.format(elph_method))
    if len(fil)!=1: exit('elph file not found for q={0}'.format(iqpt))
    with open(fil[0]) as f:
        nsig,nmode=np.array(f.readline().split()[-2:],int)
        degauss=np.zeros(nsig,float)
        ef=np.zeros(nsig,float)
        dosef=np.zeros(nsig,float)
        lambdaq=np.zeros((nsig,nmode),float)
        gammaq=np.zeros((nsig,nmode),float)
        # Collect the nmode w2 values line by line (they may wrap over
        # several lines); avoids mixing np.fromfile with a text handle.
        w2_tokens=[]
        while len(w2_tokens)<nmode:
            raw=f.readline()
            if not raw:
                exit('truncated elph file {0}'.format(fil[0]))
            w2_tokens.extend(raw.split())
        w2=np.array(w2_tokens[:nmode],float)
        freq=np.sign(w2)*np.sqrt(abs(w2))*3289.828  # Ry to THz
        for isig in range(nsig):
            degauss[isig]=float(f.readline().split()[2])
            # "DOS = <N(Ef)> states/spin/Ry/Unit Cell at Ef= <Ef> eV"
            dos_fields=f.readline().split('=')
            dosef[isig]=float(dos_fields[1].split()[0])
            ef[isig]=float(dos_fields[2].split()[0])
            for imode in range(nmode):
                fields=f.readline().split('=')
                lambdaq[isig,imode]=float(fields[1].split()[0])
                gammaq[isig,imode]=float(fields[2].split()[0])
    return freq,dosef,ef,lambdaq,gammaq,degauss


def refine_degauss(degauss):
    """Replace printed broadenings by a uniform grid when they do not match one.

    The reference grid is max(degauss)*k/nsig for k=1..nsig. If degauss is
    not np.allclose to it, a warning is printed and the grid is returned;
    otherwise degauss itself is returned unchanged.
    """
    uniform_grid=np.linspace(0,np.max(degauss),degauss.shape[0]+1)[1:]
    if np.allclose(degauss,uniform_grid):
        return degauss
    print ('Warning: the el_ph_sigma output by QE might be inaccurate! (too small to display)')
    return uniform_grid



def get_elph(elph_dir='elph_dir',elph_method='interpolated'):
    """Read the electron-phonon files of all q-points.

    The number of q-points is the number of files matching
    '<elph_dir>/elph.inp_lambda.*' ('interpolated') or '*dyn*.elph*'
    ('lambda_tetra'/'gamma_tetra'); each is read via get_elph_single_q.

    N(Ef), Ef and the broadenings are taken from the first q-point; every
    later q-point is compared against them and a message is printed when
    they differ.

    Returns:
        freq (nqpt, nmode), dosef (nsig,), ef (nsig,),
        lambdaq (nqpt, nsig, nmode), gammaq (nqpt, nsig, nmode),
        degauss (nsig,) passed through refine_degauss.

    Exits via exit() for an unknown elph_method or when no files are found.
    """
    if elph_method=='interpolated':
        elph_prefix='elph.inp_lambda'
        fils=glob.glob('{0}/{1}.*'.format(elph_dir,elph_prefix))
    elif elph_method=='lambda_tetra' or elph_method=='gamma_tetra':
        fils=glob.glob('*dyn*.elph*')
    else:
        exit('unknown elph_method: {0}'.format(elph_method))
    if not fils:
        exit('no elph files found for elph_method={0}'.format(elph_method))
    nqpt=len(fils)
    with open(fils[0]) as f:
        nsig,nmode=np.array(f.readline().split()[-2:],int)
    freq=np.zeros((nqpt,nmode),float)
    lambdaq=np.zeros((nqpt,nsig,nmode),float)
    gammaq=np.zeros((nqpt,nsig,nmode),float)
    dosef=ef=degauss=None
    for iqpt in range(nqpt):
        freq_q,dosef_q,ef_q,lambda_q,gamma_q,degauss_q=get_elph_single_q(
            iqpt=iqpt+1,elph_dir=elph_dir,elph_method=elph_method)
        freq[iqpt],lambdaq[iqpt],gammaq[iqpt]=freq_q,lambda_q,gamma_q
        if iqpt==0:
            # reference values that every other q-point must agree with
            dosef,ef,degauss=dosef_q,ef_q,degauss_q
            continue
        if not np.allclose(dosef_q,dosef):
            print ('Error: N(E) inconsisitent at qpt 1 and {0}'.format(iqpt+1))
        if not np.allclose(ef_q,ef):
            print ('Error: Ef inconsisitent at qpt 1 and {0}'.format(iqpt+1))
        if not np.allclose(degauss_q,degauss):
            print ('Error: degauss inconsisitent at qpt 1 and {0}'.format(iqpt+1))
    degauss=refine_degauss(degauss)
    return freq,dosef,ef,lambdaq,gammaq,degauss


def get_elph_from_phout(outdir='./',filph='ph.out'):
    """Parse electron-phonon data directly from a ph.x text output file.

    Lines of '<outdir>/<filph>' are located by the substrings
    "Calculation of q =" (one per q-point), "number of atoms/cell",
    "Broadening", "gamma=" and "DOS".

    Returns:
        dosef, ef (nsig,): N(Ef) and Ef of the first q-point.
        lambdaq, gammaq (nqpt, nsig, nmode): mode-resolved lambda and gamma.
        degauss (nsig,): broadenings of the first q-point, passed through
            refine_degauss.

    Exits via exit() when the atom count is missing or the number of
    "Broadening" lines is not a positive multiple of the number of q-points.
    """
    with open('{0}/{1}'.format(outdir,filph)) as f:
        lines=np.array(f.readlines())

    def _find(pattern):
        # indices of the lines matching the regex `pattern`
        return np.where([re.search(pattern,line) is not None for line in lines])[0]

    fail_msg='Fail to fetch data from {0}/{1}'.format(outdir,filph)
    nqpt=len(_find("Calculation of q ="))
    idx_nat=_find("number of atoms/cell")
    if len(idx_nat)==0:
        exit(fail_msg)
    nat=int(lines[idx_nat[0]].split()[-1])
    nmode=nat*3
    idx=_find("Broadening")
    if nqpt==0 or len(idx)==0 or len(idx)%nqpt!=0:
        exit(fail_msg)
    # integer division: reshape() below requires an int
    nsig=len(idx)//nqpt
    degauss=np.array([item.split()[2] for item in lines[idx]],float).reshape(nqpt,nsig)[0]
    degauss=refine_degauss(degauss)
    gamma_fields=[item.rstrip('\n').split('=') for item in lines[_find("gamma=")]]
    lambdaq=np.array([item[1].split()[0] for item in gamma_fields],float).reshape(nqpt,nsig,nmode)
    gammaq=np.array([item[2].split()[0] for item in gamma_fields],float).reshape(nqpt,nsig,nmode)
    dos_fields=[item.rstrip('\n').split('=') for item in lines[_find('DOS')]]
    dosef=np.array([item[1].split()[0] for item in dos_fields],float).reshape(nqpt,nsig)[0]
    ef=np.array([item[2].split()[0] for item in dos_fields],float).reshape(nqpt,nsig)[0]
    return dosef,ef,lambdaq,gammaq,degauss


def reverse_imag_freq(freq,threshold=0.1):
    """Return a copy of `freq` with small negative frequencies made positive.

    If any frequency is negative, diagnostics are printed. When every
    negative value has magnitude below `threshold` (same unit as freq; the
    messages assume THz), their signs are flipped; otherwise they are left
    unchanged and a divergence warning is printed.

    Args:
        freq: array of frequencies, typically shape (nqpt, nmode); the first
            index is reported as the q-point number. It is not modified.
        threshold: largest |negative frequency| that is sign-reversed.

    Returns:
        A new float array with the same shape as freq.
    """
    # Work on a copy so the caller's array is not mutated.
    refined_freq=np.array(freq,dtype=float)
    idx=np.where(refined_freq<0)
    if len(idx[0])>0:
        qpts=sorted({int(i)+1 for i in idx[0]})
        print ('Some small negaitve frequencies found at qpt # {0}'.format(qpts))
        print ('The most negative one is {:10.5f} THz'.format(np.min(refined_freq)))
        if np.max(abs(refined_freq[idx]))<threshold:
            refined_freq[idx]*=-1.
            print ('This code reverses their signs to proceed with the calculation')
        else:
            print ('This might lead to severe divergence problem in Tc')
    return refined_freq


def calc_lambda(freq,lambdaq,degauss,degaussq,ngaussq,qweight,emax,nex=200):
    """Compute the Eliashberg function alpha^2F and the derived coupling quantities.

    For every broadening index s:
        alpha2F(E, s) = sum_q sum_nu qweight[q] * lambdaq[q,s,nu] * freq[q,nu]
                        * 0.5 * w0gauss(ngaussq, (E - freq[q,nu])/degaussq)/degaussq

    Args:
        freq: phonon frequencies, shape (nqpt, nmode), same unit as emax
            (THz elsewhere in this script); small negative values are
            handled by reverse_imag_freq (the caller's array is not modified
            by this function).
        lambdaq: mode-resolved coupling, shape (nqpt, nsig, nmode).
        degauss: electronic broadenings; not used here.
        degaussq, ngaussq: width and smearing type of the phonon delta function.
        qweight: q-point weights, shape (nqpt,).
        emax: upper end of the energy grid.
        nex: number of grid points (the grid includes 0 and emax).

    Returns:
        ens (nex,): energy grid.
        lambda_mode (nsig, nmode): q-weighted sum of lambdaq per mode.
        lambda_all (nsig,): lambda_mode summed over modes.
        lambda2 (nsig,): 2 * integral of alpha2F/E.
        alpha2F (nex, nsig).
        logavg (nsig,): exp((2/lambda2) * integral of alpha2F*ln(E)/E),
            scaled by 47.9924 (K per THz).
    """
    from elphtk import wgauss
    freq=reverse_imag_freq(freq)
    nqpt,nsig,nmode=lambdaq.shape
    ens=np.linspace(0,emax,nex)
    e_step=emax/(nex-1)
    alpha2F=np.zeros((nex,nsig),float)
    # lambda_mode[s,m] = sum_q qweight[q]*lambdaq[q,s,m]
    lambda_mode=np.einsum('q,qsm->sm',qweight,lambdaq)
    for imode in range(nmode):
        freq_m=freq[:,imode]
        for ie,en in enumerate(ens):
            # The smeared delta depends only on (E, freq), not on the
            # electronic broadening, so evaluate it once for all sigmas.
            smear_freq=np.array([wgauss.w0gauss(ngaussq,(en-ifreq)/degaussq)/degaussq for ifreq in freq_m])
            weight=qweight*freq_m*0.50*smear_freq           # shape (nqpt,)
            alpha2F[ie]+=np.dot(weight,lambdaq[:,:,imode])   # shape (nsig,)
    # Skip E=0 (first grid point) to avoid dividing by zero and log(0).
    e_pos=ens[1:,None]
    lambda2=np.sum(alpha2F[1:]/e_pos,axis=0)*e_step*2
    logavg=np.sum(alpha2F[1:]*np.log(e_pos)/e_pos,axis=0)*e_step*2
    lambda_all=np.sum(lambda_mode,axis=1)
    logavg=np.exp(logavg/lambda2)*47.9924
    return ens,lambda_mode,lambda_all,lambda2,alpha2F,logavg


def calc_Tc(logavg,lambda_all,mustar_list):
    """Evaluate the McMillan/Allen-Dynes Tc for each mu* value.

    Tc = logavg/1.2 * exp(-1.04*(1+lambda) / (lambda - mu*(1+0.62*lambda)))

    Args:
        logavg: omega_log values (broadcast against lambda_all).
        lambda_all: electron-phonon coupling values.
        mustar_list: iterable of screened Coulomb parameters.

    Returns:
        Array with one row per mu* in mustar_list.
    """
    prefactor=logavg/1.2
    return np.array([
        prefactor*np.exp(-1.04*(1+lambda_all)/(lambda_all-mustar*(1+0.62*lambda_all)))
        for mustar in mustar_list
    ])


def print_lambda_out(lambda_all,lambda2,logavg,dosef,Tc,degauss,fil=None):
    """Print a lambda.x-style summary.

    First one line per broadening (lambda, lambda2, <log w>, N(Ef), degauss),
    then a "lambda / omega_log / T_c" table pairing lambda_all, logavg and Tc
    element-wise. Tc values that are not < 1e3 (including NaN) are printed
    in exponential notation.

    Args:
        lambda_all, lambda2, logavg, dosef, degauss: per-broadening values,
            indexed over the entries of lambda_all.
        Tc: Tc values for one mu*.
        fil: open text stream to print to; None prints to stdout.
    """
    line_fmt='     lambda = {0:10.6f} (  {1:9.6f} )  <log w>=   {2:8.3f} K  N(Ef)={3:10.6f} at degauss={4:6.3f}'
    for isig,lam in enumerate(lambda_all):
        print(line_fmt.format(lam,lambda2[isig],logavg[isig],dosef[isig],degauss[isig]),file=fil)
    header=''.join('{:12s}'.format(name) for name in ("lambda","omega_log","T_c"))
    print(header,file=fil)
    for lam,wlog,tc in zip(lambda_all,logavg,Tc):
        tc_fmt='{2:14.7f}' if tc<1e3 else '{2:14.7e}'
        print(('{0:10.5f} {1:10.3f} '+tc_fmt).format(lam,wlog,tc),file=fil)


def print_Tc(ens,lambdaq,lambda_mode,lambda_all,lambda2,alpha2F,logavg,dosef,Tc,degaussq,ngaussq,degauss,qweight):
    """Write the lambda summary files to the current directory.

    Files written:
        'lambda'      : per broadening -> degauss, lambda_all, dos(Ef), omega_ln [K] (logavg).
        'alpha2F.dat' : energy grid vs alpha2F, one column per broadening.
        'lambda.dat'  : degauss, lambda, int alpha2F (lambda2), <log w>, N(Ef).
        'lambdaq.dat' : per q-point, lambdaq summed over modes for every
                        broadening, plus a q-weighted 'total' row.

    lambda_mode, Tc, degaussq and ngaussq are accepted but not used.
    """
    nqpt,nsig,nmode=lambdaq.shape
    fmt=' Broadening   {0:5.3f} lambda     {1:9.6f} dos(Ef) {2:7.4f} omega_ln [K]    {3:8.5f}\n'
    with open('lambda','w') as fw:
        fw.write('\nElectron-phonon coupling constant, lambda\n\n')
        for isig in range(nsig):
            # the column labelled omega_ln [K] holds logavg (already in K)
            fw.write(fmt.format(degauss[isig],lambda_all[isig],dosef[isig],logavg[isig]))
    with open('alpha2F.dat','w') as fw:
        fw.write('# E(THz)'+('{:9.5f}'*nsig).format(*tuple(degauss))+'\n')
        for ie,en in enumerate(ens):
            fw.write('{:8.4f}'.format(en)+' '.join(['{:9.5f}'.format(ia) for ia in alpha2F[ie]])+'\n')
    with open('lambda.dat','w') as fw:
        fw.write('# degauss   lambda    int alpha2F  <log w>     N(Ef)\n')
        row_fmt='{:12.5f}'*5+'\n'
        for item in zip(degauss,lambda_all,lambda2,logavg,dosef):
            fw.write(row_fmt.format(*tuple(item)))
    with open('lambdaq.dat','w') as fw:
        # one (degauss, lambdaq) column pair per broadening, however many there are
        fw.write('iqpt  '+' '.join(['{0:>7s} {1:>12s}'.format('degauss','lambdaq')]*nsig)+'\n')
        pair_fmt='{0:7.4f} {1:12.6f}'
        lambda_per_q=np.sum(lambdaq,axis=2)          # (nqpt, nsig): summed over modes
        for iqpt in range(nqpt):
            fw.write('{:4d}  '.format(iqpt+1))
            fw.write(' '.join([pair_fmt.format(degauss[isig],lambda_per_q[iqpt,isig]) for isig in range(nsig)])+'\n')
        lambda_total=np.dot(qweight,lambda_per_q)    # (nsig,): q-weighted sum
        fw.write('total ')
        fw.write(' '.join([pair_fmt.format(degauss[isig],lambda_total[isig]) for isig in range(nsig)])+'\n')


def plot_sigma_convergence(degauss,lambda_all,mustar_list,Tc):
    """Plot lambda and Tc versus the electronic broadening and save the figure.

    Left panel: lambda_all vs degauss. Right panel: one scatter series per
    mu* in mustar_list (paired with the rows of Tc), showing only Tc < 400.
    The figure is saved as 'sigma_convergence' (matplotlib's default
    format) at 500 dpi.

    Returns:
        (fig, ax2): the figure and the right-hand (Tc) axes.
    """
    import matplotlib.pyplot as plt
    fig,(ax1,ax2)=plt.subplots(1,2,figsize=(8,5))
    # raw strings: mathtext backslashes are not Python escape sequences
    ax1.scatter(degauss,lambda_all,facecolor='none',edgecolor='r',s=40,label=r'$\lambda$')
    colors=['r','g','b','m','c','orange']
    for ii,(mustar,Tc0) in enumerate(zip(mustar_list,Tc)):
        idx=np.where(Tc0<400)
        # :g drops float noise such as 0.12000000000000001 from the label
        slabel=r'$\mu^*={0:g}$'.format(mustar)
        # cycle colors so more than len(colors) mu* values do not raise IndexError
        ax2.scatter(degauss[idx],Tc0[idx],facecolor='none',edgecolor=colors[ii%len(colors)],s=20,label=slabel)
    for ax in (ax1,ax2):
        ax.set_xlabel(r'$\sigma\ (Ry)$')
        ax.set_xlim(0,np.max(degauss))
        ax.axhline(0,ls='--',color='gray',alpha=0.5)
        ax.set_xticks(np.linspace(0,np.max(degauss),5))
    ax2.legend(loc='lower right',scatterpoints=1)
    ax1.set_ylabel(r'$\lambda$')
    ax2.set_ylabel(r'$T_c\ (K)$')
    fig.tight_layout()
    fig.savefig('sigma_convergence',dpi=500)
    # NOTE(review): returns only the Tc axes, as the original did through the
    # leftover loop variable; callers in this file ignore the return value.
    return fig,ax2


def get_args(prog='qe_lambda'):
    """Build the command-line parser for this script and parse sys.argv.

    Argument groups come from elphtk.arguments; the script-specific options
    are added here. After parsing, args.freq_unit is set to 'cm^{-1}'.

    Returns:
        (parser, args): the argparse.ArgumentParser and the parsed namespace.
    """
    import argparse
    from elphtk import arguments
    desc_str='processing QE lambda (EPC)'
    parser = argparse.ArgumentParser(prog=prog, description = desc_str)
    arguments.add_control_arguments(parser)
    arguments.add_io_arguments(parser)
    arguments.add_fig_arguments(parser)
    arguments.add_plot_arguments(parser)
    arguments.add_phonon_arguments(parser)
    parser.add_argument('--prefix',type=str,default=None)
    # With nargs='?', `const` is the value used when the flag is given with no
    # value, so it must be True for a bare `--phlw` / `--plot_gamma` to enable it.
    parser.add_argument('--phlw',type=arguments.str2bool,nargs='?',const=True,default=False,help='phonon linewidth, need linewidth file')
    # NOTE(review): type=eval executes the command-line text as Python code;
    # only safe with trusted input (ast.literal_eval would suffice for a tuple).
    parser.add_argument('--sc',type=eval,default=(1,1,1),help='size of supercell')
    parser.add_argument('--plot_gamma',type=arguments.str2bool,nargs='?',const=True,default=False,help='plot gamma on phonon spectra')
    parser.add_argument('--isigma',type=int,default=1,help='index of gamma files, represent the smearing')
    args = parser.parse_args()
    args.freq_unit='cm^{-1}'
    return parser, args


def main(args):
    """Run the Tc workflow and write all output files.

    Steps: write and re-read 'lambda.in', read the per-q elph files, compute
    alpha2F/lambda/omega_log, evaluate Tc for several mu*, write the summary
    files, one 'lambda_<mu*>.out' per mu*, and the convergence plot.

    Args:
        args: parsed command-line namespace; this function reads
            args.filph and args.elph_method.
    """
    # Explicit mu* values: np.arange(0.1,0.17,0.02) produced float noise
    # (0.12000000000000001) that leaked into output file names.
    mustar_list=np.array([0.10,0.12,0.14,0.16])
    write_lambda_in(filph=args.filph,elph_method=args.elph_method)
    # NOTE: the mu* read back here is not used; mustar_list is scanned instead.
    emax,degaussq,ngaussq,qweight,mustar=parse_input()

    # option 1: read the per-q electron-phonon files
    freq,dosef,ef,lambdaq,gammaq,degauss=get_elph(elph_method=args.elph_method)

    # option 2 (disabled): parse everything from the ph.x output instead
    #   try:
    #       freq=qe_phonon.get_freq_from_phout(outdir='./',filph=args.filph)
    #   except:
    #       nq=len(glob.glob(args.prefix+'dyn*'))
    #       freq=np.array([qe_phonon.get_qe_dynmat(args.prefix,iq+1)[3]['THz'] for iq in range(nq)])
    #   dosef,ef,lambdaq,gammaq,degauss=get_elph_from_phout(filph=args.filph)

    ens,lambda_mode,lambda_all,lambda2,alpha2F,logavg=calc_lambda(freq,lambdaq,degauss,degaussq,ngaussq,qweight,emax)
    Tc=calc_Tc(logavg,lambda_all,mustar_list)
    print_Tc(ens,lambdaq,lambda_mode,lambda_all,lambda2,alpha2F,logavg,dosef,Tc,degaussq,ngaussq,degauss,qweight)
    for mustar_i,Tc_i in zip(mustar_list,Tc):
        # `with` guarantees each output file is flushed and closed
        with open('lambda_{0}.out'.format(mustar_i),'w') as filout:
            print_lambda_out(lambda_all,lambda2,logavg,dosef,Tc_i,degauss,fil=filout)
    plot_sigma_convergence(degauss,lambda_all,mustar_list,Tc)


# NOTE(review): command-line arguments are parsed at import time, so merely
# importing this module parses sys.argv (argparse may exit on unknown
# options). Consider moving this call under the __main__ guard below.
parser, args = get_args()
 
if __name__=='__main__':
    # Presumably prints elphtk package/version info (helper not visible here).
    from elphtk import __version__
    from elphtk.pkg_info import verbose_pkg_info
    verbose_pkg_info(__version__)
    main(args)
