#!/usr/bin/env python2.7
import sys
import argparse
import bz2
import datetime as dt
from hapflk import InputOutput as IO
from hapflk import popgen
from hapflk import missing
import os
from fastphase import fastphaseCythonMT as myfph
import numpy as np

def formatTD(td):
    """Format a timedelta as 'HH:MM:SS', counting whole days into the hours.

    Microseconds are ignored. Durations of 100 hours or more simply get a
    wider hour field (e.g. '100:00:00').
    """
    # td.seconds alone wraps every 24h; fold the days back into the total
    total_seconds = td.days * 86400 + td.seconds
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '%02d:%02d:%02d' % (hours, minutes, seconds)

class Stepper(object):
    """Print numbered progress messages with the wall-clock time elapsed since creation.

    Each ``new`` call prints ``<prefix> <n> . [ HH:MM:SS ]  <msg>``, where n
    counts the calls made so far.
    """

    def __init__(self, prefix=''):
        self.start = dt.datetime.now()
        self.ncalls = 0
        self.prefix = prefix
        self.write(msg="Start @ " + self.start.strftime("%Y-%m-%d %H:%M:%S"))

    def new(self, msg='Step'):
        """Print step *msg* with its call number and elapsed time.

        Returns the elapsed ``timedelta`` that was printed.
        """
        self.ncalls += 1
        # measure once so the printed and the returned elapsed times agree
        elapsed = dt.datetime.now() - self.start
        print('%s %d . [ %s ]  %s' % (self.prefix, self.ncalls, formatTD(elapsed), msg))
        return elapsed

    def write(self, msg):
        """Print *msg* indented by 10 + len(prefix) spaces (aligned under step text)."""
        print((10 + len(self.prefix)) * ' ' + msg)

    def end(self):
        """Print a final step stamped with the current date and time; return the elapsed timedelta."""
        return self.new(msg="The End @ " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

class _OutputFile(object):
    """Context manager returning a writable stream for *filename*.

    When *filename* is None, ``sys.stdout`` is returned and left open on
    exit. Otherwise *filename* is opened for writing and closed on exit.
    """

    def __init__(self, filename):
        self.filename = filename
        self.fout = None

    def __enter__(self):
        if self.filename is None:
            return sys.stdout
        self.fout = open(self.filename, 'w')
        return self.fout

    def __exit__(self, exc_type, exc_value, traceback):
        # only close what we opened; never close sys.stdout
        if self.fout is not None:
            self.fout.close()
            self.fout = None
        return False


class Results(object):
    """Dictionary-like store of analysis results, with writers for the output files.

    Values are stored under string keys ('pops', 'freqs', 'map', 'dataset',
    'reynolds', 'FLK', 'hapflk', ...). Every ``write_*`` method taking a
    ``filename`` writes to that file, or to ``sys.stdout`` when it is None.
    Files opened here are always closed, and stdout is never closed.
    """

    def __init__(self):
        self.namespace = {}

    def __getitem__(self, item):
        return self.namespace[item]

    def __setitem__(self, item, value):
        self.namespace[item] = value

    def update(self, new_item):
        """Merge the mapping *new_item* into the stored results."""
        self.namespace.update(new_item)

    def show_contents(self):
        """Print the names of all stored results to stdout."""
        tw = 'Available results :\n'
        tw += '\n'.join(['\t -- ' + str(x) for x in self.namespace.keys()])
        sys.stdout.write(tw + '\n')

    def write_SNP_reynolds(self, filename=None):
        """Write the Reynolds distance matrix: population name, then its row of distances."""
        popnames = self['pops']
        D = self['reynolds']
        with _OutputFile(filename) as fout:
            for i in range(D.shape[0]):
                tw = [popnames[i]] + [str(D[i, j]) for j in range(D.shape[1])]
                fout.write(' '.join(tw) + '\n')

    def write_allele_frequencies(self, filename=None):
        """Write per-population allele frequencies, one line per SNP in map order.

        Columns: rs, chr, pos, all_ref (``alleles[1]``), all_alt
        (``alleles[0]``), then ``freqs[pop, snp]`` for each population.
        """
        frq = self['freqs']
        popnames = self['pops']
        carte = self['map']
        data = self['dataset']
        sorted_snps = carte.sort_loci(data.snp.keys())
        with _OutputFile(filename) as fout:
            fout.write(' '.join(['rs', 'chr', 'pos', 'all_ref', 'all_alt', ' '.join(popnames)]) + '\n')
            for s in sorted_snps:
                sidx = data.snpIdx[s]
                spos = carte.position(s)
                alleles = data.snp[s].alleles
                tw = [s, str(spos[0]), str(spos[2]), str(alleles[1]), str(alleles[0])]
                tw.extend(str(frq[ip, sidx]) for ip in range(len(popnames)))
                fout.write(' '.join(tw) + '\n')

    def write_cluster_frequencies(self, filename, outgroup):
        """Write cluster frequencies to one bz2 file per model fit.

        ``self['cluster.freqs']`` is indexed [fit, cluster, pop, snp], with
        SNPs in map order and populations excluding *outgroup*. Fit i goes
        to ``<filename>.fit_<i>.bz2`` with columns pop, locus, position,
        cluster, prob.
        """
        carte = self['map']
        sorted_snps = carte.sort_loci(self['dataset'].snp.keys())
        Kprob = self['cluster.freqs']
        nfit, nclus, npop, nsnp = Kprob.shape
        pops = [x for x in self['pops'] if x != outgroup]
        # SNP positions are the same for every fit and population: look them up once
        positions = [carte.position(s)[2] for s in sorted_snps]
        for ifit in range(nfit):
            with bz2.BZ2File(filename + '.fit_' + str(ifit) + '.bz2', 'w') as fout:
                fout.write('pop locus position cluster prob\n')
                for ipop in range(npop):
                    for i, s in enumerate(sorted_snps):
                        for ik in range(nclus):
                            tw = (pops[ipop], s, positions[i], ik, Kprob[ifit, ik, ipop, i])
                            fout.write(' '.join([str(x) for x in tw]) + '\n')

    def write_hapflk_results(self, filename=None):
        """Write the hapFLK statistic per SNP in map order.

        ``self['hapflk']`` is indexed by map (sorted) position.
        """
        carte = self['map']
        sorted_snps = carte.sort_loci(self['dataset'].snp.keys())
        hapflk = self['hapflk']
        with _OutputFile(filename) as fout:
            fout.write('rs chr pos hapflk\n')
            for i, s in enumerate(sorted_snps):
                spos = carte.position(s)
                tw = [s, spos[0], spos[2], hapflk[i]]
                fout.write(' '.join([str(x) for x in tw]) + '\n')

    def write_flk_results(self, filename=None):
        """Write pzero, FLK and its p-value per SNP in map order.

        The FLK arrays are indexed by dataset SNP index (``snpIdx``).
        """
        data = self['dataset']
        carte = self['map']
        sorted_snps = carte.sort_loci(data.snp.keys())
        pzero = self['pzero']
        flk = self['FLK']
        pval = self['pval.FLK']
        with _OutputFile(filename) as fout:
            fout.write('rs chr pos pzero flk pvalue\n')
            for s in sorted_snps:
                sidx = data.snpIdx[s]
                spos = carte.position(s)
                tw = [s, spos[0], spos[2], pzero[sidx], flk[sidx], pval[sidx]]
                fout.write(' '.join([str(x) for x in tw]) + '\n')

    def _write_per_snp_components(self, key, filename, map_ordered):
        """Write the rows of ``self[key]`` ([n_components, n_snp]) as PC columns, one line per SNP.

        *map_ordered* selects how SNP columns are indexed: by map (sorted)
        position when True, by dataset index (``snpIdx``) when False.
        """
        eig = self[key]
        npc = eig.shape[0]
        data = self['dataset']
        carte = self['map']
        sorted_snps = carte.sort_loci(data.snp.keys())
        with _OutputFile(filename) as fout:
            header = ' '.join(['PC' + str(i + 1) for i in range(npc)])
            fout.write(' '.join(['rs', 'chr', 'pos', header]) + '\n')
            for i, s in enumerate(sorted_snps):
                col = i if map_ordered else data.snpIdx[s]
                spos = carte.position(s)
                tw = [s, str(spos[0]), str(spos[2])]
                tw.extend(str(eig[ipc, col]) for ipc in range(npc))
                fout.write(' '.join(tw) + '\n')

    def write_eigen_flk(self, filename=None):
        """Write the per-component FLK decomposition (``eigen.FLK``, dataset-indexed)."""
        self._write_per_snp_components('eigen.FLK', filename, map_ordered=False)

    def write_eigen_hapflk(self, filename=None):
        """Write the per-component hapFLK decomposition (``eigen.hapflk``).

        Like ``hapflk``, this array is indexed by map (sorted) position, not
        by dataset index.
        """
        self._write_per_snp_components('eigen.hapflk', filename, map_ordered=True)

    def write_eigen_decomposition(self, filename=None):
        """Write eigenvalues and eigenvectors of the kinship decomposition.

        One line per principal component: its number, its eigenvalue, then
        its loading for each population other than the outgroup.
        """
        popnames = [pop for pop in self['pops'] if pop != self['outgroup']]
        eigvec = self['eigvec']
        eigval = self['eigval']
        with _OutputFile(filename) as fout:
            fout.write(' '.join(['PC', 'Lambda', ' '.join(popnames)]) + '\n')
            for i in range(eigvec.shape[1]):
                tw = [str(i + 1), str(eigval[i])]
                tw.extend(str(eigvec[j, i]) for j in range(eigvec.shape[0]))
                fout.write(' '.join(tw) + '\n')

def populate_parser(parser):
    """Register the hapflk command line options on an argparse *parser*.

    General options go on *parser* itself, followed by a "Population
    kinship" group and a "hapFLK and LD model" group. The ``--annot``
    shortcut is added last.
    """
    # keyword set shared by every boolean on/off flag
    switch = dict(default=False, action='store_true')
    add = parser.add_argument

    add('-p', '--prefix', dest='prefix', default='hapflk', help='prefix for output files')
    add('--ncpu', metavar='N', type=int, default=1, help='Use N processors when possible')
    add('--eigen', help='Perform eigen decomposition of tests', **switch)
    add('--reynolds', help='Force writing down Reynolds distances', **switch)
    # hidden flags: --future for testing a future release, --debug for debugging
    add('--future', help=argparse.SUPPRESS, **switch)
    add('--debug', help=argparse.SUPPRESS, **switch)

    kin = parser.add_argument_group('Population kinship ',
                                    'Set parameters for getting the population kinship matrix')
    kin.add_argument('--kinship', metavar='FILE', default=None,
                     help='Read population kinship from file (if None, kinship is estimated)')
    kin.add_argument('--reynolds-snps', dest='reysnps', type=int, default=100000, metavar='L',
                     help='Number of SNPs to use to estimate Reynolds distances')
    kin.add_argument('--outgroup', metavar="POP", default=None,
                     help='Use population POP as outgroup for tree rooting (if None, use midpoint rooting)')
    kin.add_argument('--keep-outgroup', dest='keepOG', help='Keep outgroup in population set', **switch)

    ld = parser.add_argument_group('hapFLK and LD model',
                                   'Switch on hapFLK calculations and set parameters of the LD model ')
    # hapFLK only runs when K > 0, so K = 0 disables it too
    ld.add_argument('-K', type=int, default=-1,
                    help='Set the number of clusters to K. hapFLK calculations switched off if K<=0')
    ld.add_argument('--nfit', type=int, default=20, help='Set the number of model fit to use')
    ld.add_argument('--phased', '--inbred', dest='inbred', help='Haplotype data provided', **switch)
    ld.add_argument('--kfrq', dest='kfrq', help='Write Cluster frequencies (Big files)', **switch)
    ld.add_argument('--write-params', dest='wparams', help=argparse.SUPPRESS, **switch)

    add('--annot', help='Shortcut for --eigen --reynolds --kfrq', **switch)

def _tohap(x):
    if x==1:
        return missing
    else:
        return x/2

tohap=np.vectorize(_tohap)

def _add_individual(model, name, genotypes, inbred):
    """Add one individual's data (a dataset row restricted to SNPs in map order) to the fastPHASE *model*.

    With *inbred* set, the row is converted with ``tohap`` and added as a
    haplotype. Otherwise it is added as a genotype.
    """
    row = np.array(genotypes, dtype=int)
    if inbred:
        # plain ``int`` dtype: the ``np.int`` alias was removed in NumPy 1.24
        model.addHaplotype(name, np.array(tohap(row), dtype=int))
    else:
        model.addGenotype(name, row)


def _accumulate_cluster_freqs(model, fph_params, pop_cluster_freq, ipop,
                              n_pop_indiv, inbred, ncpu):
    """Add the cluster probabilities of the individuals loaded in *model* to population *ipop*, then flush *model*.

    For every fitted parameter set ``fph_params[ifit]``, each imputed
    individual's ``probZ[0]`` is added to ``pop_cluster_freq[ifit, :, ipop, :]``.
    The weight is 1/n_pop_indiv for haplotypes. For genotypes it is
    0.5/n_pop_indiv applied to the sum of the axis-1 and axis-2 marginals of
    ``probZ[0]``.
    """
    for ifit, params in enumerate(fph_params):
        imputations = model.impute([params], nthread=ncpu)
        for dat in imputations.values():
            probZ = dat[1][0]
            if inbred:
                contrib = probZ / n_pop_indiv
            else:
                contrib = (0.5 / n_pop_indiv) * (np.sum(probZ, axis=1) + np.sum(probZ, axis=2))
            pop_cluster_freq[ifit, :, ipop, :] += np.transpose(contrib)
    model.flush()


def main():
    """Run the hapflk command line program.

    Steps:
    1. Parse the options and read the input files.
    2. Compute allele frequencies and Reynolds distances.
    3. Get the population kinship, read from a file or estimated.
    4. Run the single-SNP FLK tests.
    5. When -K > 0, fit the fastPHASE LD model and compute the hapFLK
       statistic.
    6. Write every output file under ``--prefix``.
    """
    ## read options and init
    myparser = argparse.ArgumentParser(parents=[IO.io_parser],
                                       formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    populate_parser(myparser)
    if len(sys.argv) < 2:
        myparser.print_help()
        sys.exit(0)
    myopts = myparser.parse_args()
    if myopts.annot:
        myopts.reynolds = True
        myopts.kfrq = True
        myopts.eigen = True
    counter = Stepper()
    results = Results()

    ## read input
    counter.new('Reading Input Files')
    my_input = IO.parseInput(myopts)
    if my_input is None:
        myparser.print_help()
        sys.exit(0)
    data = my_input['dataset']
    carte = my_input['map']
    results.update(my_input)
    sorted_snps = carte.sort_loci(data.snp.keys())
    # dataset column of each SNP, listed in map order
    sorted_snps_idx = np.array([data.snpIdx[s] for s in sorted_snps])

    ## compute allele frequencies
    counter.new('Computing Allele Frequencies')
    results.update(data.compute_pop_frq())

    ## compute Reynolds distances on a random subset of about reysnps SNPs
    counter.new('Computing Reynolds distances')
    myopts.reysnps = min(myopts.reysnps, data.nsnp)
    pbinom = float(myopts.reysnps) / data.nsnp
    snp_subset = np.array(np.random.binomial(1, pbinom, data.nsnp), dtype=bool)
    reynolds_dist = popgen.reynolds(results["freqs"][:, snp_subset])
    heteroZ = popgen.heterozygosity(results["freqs"][:, snp_subset])
    results.update({"reynolds": reynolds_dist, "outgroup": myopts.outgroup, "hzy": heteroZ})

    ## Get population kinship
    if myopts.kinship:
        counter.new('Reading Kinship Matrix')
        if myopts.keepOG:
            kin_pops = list(results['pops'])
        else:
            kin_pops = [pop for pop in results['pops'] if pop != myopts.outgroup]
        kinship = popgen.popKinship_fromFile(myopts.kinship, kin_pops)
    else:
        ## estimate kinship
        counter.new("Computing Kinship Matrix")
        kinship = popgen.popKinship_new(results['reynolds'], results['pops'], myopts.outgroup,
                                        fprefix=myopts.prefix, keep_outgroup=myopts.keepOG,
                                        hzy=results['hzy'])
    results.update({'kinship': kinship})

    ## forget about the outgroup if we keep it: a sentinel name, presumably
    ## matching no population, so the outgroup filters below keep everyone
    if myopts.keepOG:
        myopts.outgroup = 'Iwantmyoutgroupback'
        results['outgroup'] = myopts.outgroup

    ## Compute single SNP FLK tests
    counter.new("Computing FLK tests")
    filter_outgroup = np.array([x != myopts.outgroup for x in results['pops']], dtype=bool)
    frq_test = results["freqs"][filter_outgroup, ]
    myFLK = popgen.FLK_test(results["kinship"])
    results.update({'eigvec': myFLK.Q, 'eigval': myFLK.D})
    myFLK_res = np.apply_along_axis(myFLK.eval_flk, 0, frq_test)
    results.update({'pzero': myFLK_res[0, ], 'FLK': myFLK_res[1, ], 'pval.FLK': myFLK_res[2, ],
                    'eigen.FLK': np.power(myFLK_res[3:, ], 2)})

    ### hapFLK calculations
    if myopts.K > 0:
        dataset = results['dataset']
        ## Estimate fastPhase model (individuals outside the outgroup, SNPs in map order)
        counter.new("Fitting LD model (this might take a while)")
        fph_params = []
        fastphase_model = myfph.fastphase(dataset.nsnp)
        for name, i in dataset.indivIdx.items():
            if dataset.indiv[name].pop == myopts.outgroup:
                continue
            _add_individual(fastphase_model, name, dataset.Data[i, sorted_snps_idx], myopts.inbred)
        for e in range(myopts.nfit):
            sys.stderr.write('\tEM %d / %d \r' % (e + 1, myopts.nfit))
            sys.stderr.flush()
            par = fastphase_model.fit(nClus=myopts.K, nthread=myopts.ncpu, verbose=myopts.debug)
            fph_params.append(par)
            if myopts.wparams:
                with open(myopts.prefix + 'fph_par_' + str(e), 'w') as fout:
                    par.write(fout)
        sys.stdout.write('\n')
        results.update({'fph.params': fph_params})

        ## Compute Cluster population frequencies
        counter.new("Computing Cluster Frequencies")
        ## [E][K][npop x nsnp] dataset
        pop_cluster_freq = np.zeros((myopts.nfit, myopts.K, np.sum(filter_outgroup), dataset.nsnp),
                                    dtype=float)
        ngeno = 0
        fastphase_model.flush()
        for ipop, popname in enumerate([x for x in results['pops'] if x != myopts.outgroup]):
            sys.stdout.write("\t %16s\r" % popname)
            sys.stdout.flush()
            pvec = dataset.populations[popname]
            n_pop_indiv = sum(pvec)
            for name, iind in dataset.indivIdx.items():
                if not pvec[iind]:
                    continue
                _add_individual(fastphase_model, name, dataset.Data[iind, sorted_snps_idx], myopts.inbred)
                ngeno += 1
                # impute in batches of ncpu individuals
                if ngeno == myopts.ncpu:
                    _accumulate_cluster_freqs(fastphase_model, fph_params, pop_cluster_freq, ipop,
                                              n_pop_indiv, myopts.inbred, myopts.ncpu)
                    ngeno = 0
            ## anyone left ?
            if ngeno > 0:
                _accumulate_cluster_freqs(fastphase_model, fph_params, pop_cluster_freq, ipop,
                                          n_pop_indiv, myopts.inbred, myopts.ncpu)
                ngeno = 0
        sys.stdout.write('\n')
        results.update({'cluster.freqs': pop_cluster_freq})

        ## Compute hapFLK: FLK averaged over fits, summed over clusters
        counter.new('Computing hapFLK')
        myFLK = popgen.FLK_test(results["kinship"], diallelic=False)
        hapflk = np.zeros(dataset.nsnp, dtype=float)
        hapflk_eigen = np.zeros((results['eigval'].shape[0], dataset.nsnp), dtype=float)
        for e in range(myopts.nfit):
            for k in range(myopts.K):
                myFLK_res = np.apply_along_axis(myFLK.eval_flk, 0, pop_cluster_freq[e, k, ])
                hapflk += myFLK_res[1, ]
                hapflk_eigen += np.power(myFLK_res[3:, ], 2)
        hapflk /= myopts.nfit
        hapflk_eigen /= myopts.nfit
        results.update({'hapflk': hapflk, 'eigen.hapflk': hapflk_eigen})

    ## Write out results and bye
    counter.new("Writing down results")
    results.write_allele_frequencies(myopts.prefix + '.frq')
    results.write_flk_results(myopts.prefix + '.flk')
    if myopts.reynolds:
        results.write_SNP_reynolds(myopts.prefix + '.rey')
    if myopts.eigen:
        results.write_eigen_decomposition(myopts.prefix + '.eig')
        results.write_eigen_flk(myopts.prefix + '.flk.eig')
    # cluster frequencies and hapFLK results only exist when the LD model was fitted
    if myopts.K > 0:
        if myopts.kfrq:
            results.write_cluster_frequencies(myopts.prefix + '.kfrq', outgroup=myopts.outgroup)
        results.write_hapflk_results(myopts.prefix + '.hapflk')
        if myopts.eigen:
            results.write_eigen_hapflk(myopts.prefix + '.hapflk.eig')
    counter.end()
    
    
    
if __name__ == '__main__':
    # run the command line program only when executed as a script, not on import
    main()
