#!/usr/bin/env python2.7
import sys
import argparse
import bz2
import datetime as dt
from hapflk import InputOutput as IO
from hapflk import popgen
from hapflk import missing
import os
##from fastphase import fastphaseCythonMT as myfph
import fastphase
import numpy as np

def formatTD(td):
    """Format a ``datetime.timedelta`` as ``HH:MM:SS``.

    The hour field counts *all* elapsed hours (it may exceed 24 and use more
    than two digits), so durations longer than a day are not wrapped around.
    Fractions of a second are dropped.

    :param td: the elapsed time, a ``datetime.timedelta``.
    :returns: the formatted string, zero-padded to at least two digits per field.
    """
    # td.seconds alone ignores td.days; total_seconds() covers the full span.
    total = int(td.total_seconds())
    minutes, seconds = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    return '%02d:%02d:%02d' % (hours, minutes, seconds)

class Stepper(object):
    """Numbered progress logger that prints elapsed time on standard output.

    Each call to :meth:`new` prints one line of the form
    ``<prefix> <n> . [ HH:MM:SS ]  <msg>`` where ``n`` is the running step
    count and the time is measured from the creation of the instance.
    """

    def __init__(self, prefix=''):
        """Start the clock and print a 'Start @ <timestamp>' banner.

        :param prefix: text printed in front of every step line.
        """
        self.start = dt.datetime.now()
        self.ncalls = 0
        self.prefix = prefix
        self.write(msg="Start @ " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    def new(self, msg='Step'):
        """Announce a new step and return the time elapsed since the start.

        :param msg: description of the step.
        :returns: ``datetime.timedelta`` since the instance was created; this
            is the same value that is shown in the printed line.
        """
        self.ncalls += 1
        # Read the clock once so the logged and the returned values agree.
        elapsed = dt.datetime.now() - self.start
        # Layout matches "print prefix, ncalls, '. [', time, '] ', msg".
        sys.stdout.write('%s %s . [ %s ]  %s\n'
                         % (self.prefix, self.ncalls, formatTD(elapsed), msg))
        return elapsed

    def write(self, msg):
        """Print ``msg`` indented by 10 columns plus the prefix length."""
        sys.stdout.write((10 + len(self.prefix)) * ' ' + msg + '\n')

    def end(self):
        """Print a final 'The End @ <timestamp>' step; return the total elapsed time."""
        return self.new(msg="The End @ " + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

class Results(object):
    """Dictionary-like store of analysis results, with text writers.

    Results are kept in ``self.namespace`` under string keys.  The writers
    below read the keys ``'pops'``, ``'reynolds'``, ``'freqs'``, ``'map'``,
    ``'dataset'``, ``'cluster.freqs'``, ``'hapflk'``, ``'pzero'``, ``'FLK'``,
    ``'pval.FLK'``, ``'eigen.FLK'``, ``'eigen.hapflk'``, ``'eigvec'``,
    ``'eigval'`` and ``'outgroup'``; a missing key raises ``KeyError``.

    Writers that take ``filename=None`` write to standard output in that case.
    Files opened by a writer are always closed by it; standard output is
    never closed.
    """

    def __init__(self):
        self.namespace = {}

    def __getitem__(self, item):
        return self.namespace[item]

    def __setitem__(self, item, value):
        self.namespace[item] = value

    def update(self, new_item):
        """Merge the mapping ``new_item`` into the stored results."""
        self.namespace.update(new_item)

    def show_contents(self):
        """Print the names of all stored results on standard output."""
        text = 'Available results :\n'
        text += '\n'.join(['\t -- ' + str(key) for key in self.namespace.keys()])
        sys.stdout.write(text + '\n')

    # ---- private helpers -------------------------------------------------

    @staticmethod
    def _write_lines(filename, lines):
        """Write each string of ``lines`` followed by a newline.

        Output goes to ``filename`` (created/truncated) or to standard output
        when ``filename`` is None.  Only a file opened here is closed.
        """
        fout = sys.stdout if filename is None else open(filename, 'w')
        try:
            for line in lines:
                fout.write(line + '\n')
        finally:
            if filename is not None:
                fout.close()

    def _sorted_snps(self):
        """Return the dataset's SNP names ordered by ``self['map'].sort_loci``."""
        return self['map'].sort_loci(self['dataset'].snp.keys())

    def _locus_fields(self, snp):
        """Return ``[name, chr, pos]`` for ``snp``.

        ``chr`` and ``pos`` are elements 0 and 2 of ``self['map'].position(snp)``.
        """
        spos = self['map'].position(snp)
        return [snp, str(spos[0]), str(spos[2])]

    # ---- writers ---------------------------------------------------------

    def write_SNP_reynolds(self, filename=None):
        """Write the ``'reynolds'`` matrix, one row per population.

        Each line is the population name followed by that row of the matrix.
        """
        popnames = self['pops']
        D = self['reynolds']

        def rows():
            for i in range(D.shape[0]):
                yield ' '.join([popnames[i]] + [str(D[i, j]) for j in range(D.shape[1])])

        self._write_lines(filename, rows())

    def write_allele_frequencies(self, filename=None):
        """Write per-population allele frequencies, one line per SNP in map order.

        Columns: ``rs chr pos all_ref all_alt`` then one column per population.
        ``'freqs'`` is indexed as ``[population, snpIdx]``.
        """
        frq = self['freqs']
        popnames = self['pops']
        data = self['dataset']

        def rows():
            yield ' '.join(['rs', 'chr', 'pos', 'all_ref', 'all_alt', ' '.join(popnames)])
            for s in self._sorted_snps():
                sidx = data.snpIdx[s]
                alleles = data.snp[s].alleles
                fields = self._locus_fields(s) + [str(alleles[1]), str(alleles[0])]
                fields.extend(str(frq[ip, sidx]) for ip in range(len(popnames)))
                yield ' '.join(fields)

        self._write_lines(filename, rows())

    def write_cluster_frequencies(self, filename, outgroup):
        """Write cluster frequencies, one bz2 file per model fit.

        File ``<filename>.fit_<i>.bz2`` holds the columns
        ``pop locus position cluster prob`` for fit ``i``.  ``'cluster.freqs'``
        has shape (nfit, nclus, npop, nsnp); its last axis is read in
        sorted-locus order and its population axis follows ``self['pops']``
        with ``outgroup`` removed.

        NOTE(review): lines are written as native ``str`` (Python 2 semantics
        of ``bz2.BZ2File``) -- confirm before running under Python 3.
        """
        snp_map = self['map']
        sorted_snps = self._sorted_snps()
        Kprob = self['cluster.freqs']
        nfit, nclus, npop, nsnp = Kprob.shape
        pops = [x for x in self['pops'] if x != outgroup]
        # Positions do not depend on fit or population: look them up once.
        positions = [snp_map.position(s)[2] for s in sorted_snps]
        for ifit in range(nfit):
            fout = bz2.BZ2File(filename + '.fit_' + str(ifit) + '.bz2', 'w')
            try:
                fout.write('pop locus position cluster prob\n')
                for ipop in range(npop):
                    for i, s in enumerate(sorted_snps):
                        for ik in range(nclus):
                            fout.write('%s %s %s %s %s\n' % (
                                pops[ipop], s, positions[i], ik, Kprob[ifit, ik, ipop, i]))
            finally:
                fout.close()

    def write_hapflk_results(self, filename=None):
        """Write the hapFLK statistic per SNP (columns ``rs chr pos hapflk``).

        ``'hapflk'`` is read in sorted-locus order (position in the sorted
        SNP list), not through ``snpIdx``.
        """
        hapflk = self['hapflk']

        def rows():
            yield 'rs chr pos hapflk'
            for i, s in enumerate(self._sorted_snps()):
                fields = self._locus_fields(s) + [hapflk[i]]
                yield ' '.join([str(x) for x in fields])

        self._write_lines(filename, rows())

    def write_flk_results(self, filename=None):
        """Write single-SNP FLK results (columns ``rs chr pos pzero flk pvalue``).

        ``'pzero'``, ``'FLK'`` and ``'pval.FLK'`` are indexed by ``snpIdx``.
        """
        dataset = self['dataset']
        pzero, flk, pval = self['pzero'], self['FLK'], self['pval.FLK']

        def rows():
            yield 'rs chr pos pzero flk pvalue'
            for s in self._sorted_snps():
                sidx = dataset.snpIdx[s]
                fields = self._locus_fields(s) + [pzero[sidx], flk[sidx], pval[sidx]]
                yield ' '.join([str(x) for x in fields])

        self._write_lines(filename, rows())

    def write_eigen_flk(self, filename=None):
        """Write ``'eigen.FLK'`` (one ``PCi`` column per row of the array).

        The array is indexed as ``[component, snpIdx]``.
        """
        eigen = self['eigen.FLK']
        dataset = self['dataset']
        ncomp = eigen.shape[0]

        def rows():
            yield ' '.join(['rs', 'chr', 'pos',
                            ' '.join(['PC' + str(i + 1) for i in range(ncomp)])])
            for s in self._sorted_snps():
                sidx = dataset.snpIdx[s]
                yield ' '.join(self._locus_fields(s)
                               + [str(eigen[i, sidx]) for i in range(ncomp)])

        self._write_lines(filename, rows())

    def write_eigen_hapflk(self, filename=None):
        """Write ``'eigen.hapflk'`` (one ``PCi`` column per row of the array).

        The array is indexed as ``[component, rank in the sorted SNP list]``,
        the same column convention as ``'hapflk'`` in
        :meth:`write_hapflk_results`.
        """
        eigen = self['eigen.hapflk']
        ncomp = eigen.shape[0]

        def rows():
            yield ' '.join(['rs', 'chr', 'pos',
                            ' '.join(['PC' + str(i + 1) for i in range(ncomp)])])
            # Columns follow the sorted locus order, so use the enumeration
            # index rather than dataset.snpIdx.
            for col, s in enumerate(self._sorted_snps()):
                yield ' '.join(self._locus_fields(s)
                               + [str(eigen[i, col]) for i in range(ncomp)])

        self._write_lines(filename, rows())

    def write_eigen_decomposition(self, filename=None):
        """Write eigenvalues and eigenvectors, one line per principal component.

        Columns: ``PC Lambda`` then one column per population of
        ``self['pops']`` other than ``self['outgroup']``.  Component ``i``
        is column ``i`` of ``'eigvec'`` and entry ``i`` of ``'eigval'``.
        """
        outgroup = self['outgroup']
        popnames = [pop for pop in self['pops'] if pop != outgroup]
        eigvec = self['eigvec']
        eigval = self['eigval']

        def rows():
            yield ' '.join(['PC', 'Lambda', ' '.join(popnames)])
            for i in range(eigvec.shape[1]):
                yield ' '.join([str(i + 1), str(eigval[i])]
                               + [str(eigvec[j, i]) for j in range(eigvec.shape[0])])

        self._write_lines(filename, rows())

def populate_parser(parser):
    """Register the hapFLK command-line options on ``parser``.

    Adds general options directly to ``parser`` and two argument groups
    ('Population kinship ' and 'hapFLK and LD model').  Options whose help is
    ``argparse.SUPPRESS`` are accepted but hidden from the help output.

    :param parser: an ``argparse.ArgumentParser`` (modified in place).
    """
    ## general options
    parser.add_argument('-p', '--prefix', dest='prefix',
                        help='prefix for output files', default='hapflk')
    parser.add_argument('--ncpu', metavar='N', default=1, type=int,
                        help='Use N processors when possible')
    parser.add_argument('--eigen', default=False, action='store_true',
                        help='Perform eigen decomposition of tests')
    parser.add_argument('--reynolds', default=False, action='store_true',
                        help='Force writing down Reynolds distances')
    ## for testing future release
    parser.add_argument('--future', help=argparse.SUPPRESS, default=False, action="store_true")
    ## for debug purpose
    parser.add_argument('--debug', help=argparse.SUPPRESS, default=False, action="store_true")

    ## population kinship options
    flk_opts = parser.add_argument_group(
        'Population kinship ', 'Set parameters for getting the population kinship matrix')
    flk_opts.add_argument('--kinship', metavar='FILE', default=None,
                          help='Read population kinship from file (if None, kinship is estimated)')
    flk_opts.add_argument('--reynolds-snps', dest='reysnps', type=int, default=100000, metavar='L',
                          help='Number of SNPs to use to estimate Reynolds distances')
    flk_opts.add_argument('--outgroup', default=None, metavar="POP",
                          help='Use population POP as outgroup for tree rooting (if None, use midpoint rooting)')
    flk_opts.add_argument('--keep-outgroup', dest='keepOG', default=False, action="store_true",
                          help='Keep outgroup in population set')

    ## hapFLK / LD model options
    LD_opts = parser.add_argument_group(
        'hapFLK and LD model', 'Switch on hapFLK calculations and set parameters of the LD model ')
    # main() runs hapFLK only when K>0, so K=0 switches it off as well.
    LD_opts.add_argument('-K', default=-1, type=int,
                         help='Set the number of clusters to K. hapFLK calculations switched off if K<=0')
    LD_opts.add_argument('--nfit', type=int, default=20,
                         help='Set the number of model fit to use')
    LD_opts.add_argument('--phased', '--inbred', dest='inbred', action="store_true", default=False,
                         help='Haplotype data provided')
    LD_opts.add_argument('--kfrq', dest='kfrq', action="store_true", default=False,
                         help='Write Cluster frequencies (Big files)')
    LD_opts.add_argument('--write-params', dest='wparams', help=argparse.SUPPRESS,
                         default=False, action='store_true')

    ## registered last, on the parser itself (not in a group)
    parser.add_argument('--annot', default=False, action='store_true',
                        help='Shortcut for --eigen --reynolds --kfrq')

def _tohap(x):
    if x==1:
        return missing
    else:
        return x/2

tohap=np.vectorize(_tohap)

def _add_individual(model, dataset, name, row, snp_order, inbred):
    """Feed one individual's genotype row (in ``snp_order``) to the LD model.

    With ``inbred`` the row is converted through ``tohap`` and added as a
    haplotype; otherwise it is added as a genotype.
    """
    genotypes = np.array(dataset.Data[row, snp_order], dtype=int)
    if inbred:
        model.addHaplotype(name, np.array(tohap(genotypes), dtype=int))
    else:
        model.addGenotype(name, genotypes)


def _accumulate_cluster_freqs(model, fits, cluster_freq, ipop, n_pop_indiv, inbred, ncpu):
    """Impute the individuals queued in ``model`` and add them to ``cluster_freq``.

    For every fitted parameter set in ``fits`` the queued individuals are
    imputed and their cluster probabilities, divided by ``n_pop_indiv``, are
    added in place to ``cluster_freq[ifit, :, ipop, :]``.  The model queue is
    flushed afterwards.
    """
    for ifit, par in enumerate(fits):
        imputations = model.impute([par], nthread=ncpu)
        for dat in imputations.values():
            probZ = dat[1]
            if inbred:
                contrib = probZ[0] / n_pop_indiv
            else:
                # average the two marginals of the cluster-pair probabilities
                contrib = (0.5 / n_pop_indiv) * (np.sum(probZ[0], axis=1) + np.sum(probZ[0], axis=2))
            cluster_freq[ifit, :, ipop, :] += np.transpose(contrib)
    model.flush()


def main():
    """Command-line entry point: run the FLK (and optionally hapFLK) analysis.

    Steps, in order: parse options, read the input, compute allele
    frequencies and Reynolds distances, obtain the population kinship matrix
    (read from file or estimated), compute single-SNP FLK tests, and -- when
    ``-K`` is positive -- fit the LD model ``--nfit`` times, derive
    population cluster frequencies and compute hapFLK.  Result files are
    written with the ``--prefix`` given on the command line.

    Exits with status 0 after printing the help when no argument is given or
    when the input parser returns None.
    """
    ## read options and init
    myparser = argparse.ArgumentParser(parents=[IO.io_parser],
                                       formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    populate_parser(myparser)
    if len(sys.argv) < 2:
        myparser.print_help()
        sys.exit(0)
    myopts = myparser.parse_args()
    if myopts.annot:
        myopts.reynolds = True
        myopts.kfrq = True
        myopts.eigen = True
    counter = Stepper()
    results = Results()

    ## read input
    counter.new('Reading Input Files')
    my_input = IO.parseInput(myopts)
    if my_input is None:
        myparser.print_help()
        sys.exit(0)
    data = my_input['dataset']
    carte = my_input['map']
    results.update(my_input)
    sorted_snps = carte.sort_loci(data.snp.keys())
    ## dataset column of each SNP, listed in map order
    sorted_snps_idx = np.array([data.snpIdx[s] for s in sorted_snps])

    ## compute allele frequencies
    counter.new('Computing Allele Frequencies')
    res = data.compute_pop_frq()
    results.update(res)

    ## compute Reynolds distances
    counter.new('Computing Reynolds distances')
    if myopts.reysnps > data.nsnp:
        myopts.reysnps = data.nsnp
    ## each SNP is kept independently with probability reysnps/nsnp
    pbinom = float(myopts.reysnps) / data.nsnp
    snp_subset = np.array(np.random.binomial(1, pbinom, data.nsnp), dtype=bool)
    reynolds_dist = popgen.reynolds(results["freqs"][:, snp_subset])
    heteroZ = popgen.heterozygosity(results["freqs"][:, snp_subset])
    results.update({"reynolds": reynolds_dist})
    results.update({"outgroup": myopts.outgroup})
    results.update({"hzy": heteroZ})

    ## Get population kinship
    if myopts.kinship:
        counter.new('Reading Kinship Matrix')
        if not myopts.keepOG:
            kinship_pops = [pop for pop in results['pops'] if pop != myopts.outgroup]
        else:
            kinship_pops = list(results['pops'])
        kinship = popgen.popKinship_fromFile(myopts.kinship, kinship_pops)
    else:
        ## estimate kinship
        counter.new("Computing Kinship Matrix")
        kinship = popgen.popKinship_new(results['reynolds'], results['pops'], myopts.outgroup,
                                        fprefix=myopts.prefix, keep_outgroup=myopts.keepOG,
                                        hzy=results['hzy'])
    results.update({'kinship': kinship})

    ## forget about the outgroup if we keep it: from here on the outgroup
    ## name is replaced by a placeholder, so the filters below exclude
    ## nothing unless a population carries that exact name
    if myopts.keepOG:
        myopts.outgroup = 'Iwantmyoutgroupback'
        results['outgroup'] = myopts.outgroup

    ## Compute single SNP FLK tests
    counter.new("Computing FLK tests")
    filter_outgroup = np.array([x != myopts.outgroup for x in results['pops']], dtype=bool)
    frq_test = results["freqs"][filter_outgroup, ]
    myFLK = popgen.FLK_test(results["kinship"])
    results.update({'eigvec': myFLK.Q, 'eigval': myFLK.D})
    myFLK_res = np.apply_along_axis(myFLK.eval_flk, 0, frq_test)
    results.update({'pzero': myFLK_res[0, ], 'FLK': myFLK_res[1, ], 'pval.FLK': myFLK_res[2, ],
                    'eigen.FLK': np.power(myFLK_res[3:, ], 2)})

    ### hapFLK calculations
    if myopts.K > 0:
        dataset = results['dataset']
        myfph = fastphase.fastphaseCythonMT
        ## Estimate fastPhase model
        counter.new("Fitting LD model (this might take a while)")
        fph_params = []
        fastphase_model = myfph.fastphase(dataset.nsnp)
        for name, i in dataset.indivIdx.items():
            if dataset.indiv[name].pop == myopts.outgroup:
                continue
            _add_individual(fastphase_model, dataset, name, i, sorted_snps_idx, myopts.inbred)
        for e in range(myopts.nfit):
            sys.stderr.write('\tEM %d / %d \r' % (e + 1, myopts.nfit))
            sys.stderr.flush()
            par = fastphase_model.fit(nClus=myopts.K, nthread=myopts.ncpu, verbose=myopts.debug)
            fph_params.append(par)
            if myopts.wparams:
                with open(myopts.prefix + 'fph_par_' + str(e), 'w') as fout:
                    par.write(fout)
        sys.stdout.write('\n')
        results.update({'fph.params': fph_params})

        ## Compute Cluster population frequencies
        counter.new("Computing Cluster Frequencies")
        ## [E][K][npop x nsnp] dataset
        pop_cluster_freq = np.zeros((myopts.nfit, myopts.K, np.sum(filter_outgroup), dataset.nsnp),
                                    dtype=float)
        ngeno = 0
        fastphase_model.flush()
        for ipop, popname in enumerate([x for x in results['pops'] if x != myopts.outgroup]):
            sys.stdout.write("\t %16s\r" % popname)
            sys.stdout.flush()
            pvec = dataset.populations[popname]
            n_pop_indiv = sum(pvec)
            for name, iind in dataset.indivIdx.items():
                if not pvec[iind]:
                    continue
                _add_individual(fastphase_model, dataset, name, iind, sorted_snps_idx, myopts.inbred)
                ngeno += 1
                ## impute in batches of ncpu individuals
                if ngeno == myopts.ncpu:
                    _accumulate_cluster_freqs(fastphase_model, results['fph.params'], pop_cluster_freq,
                                              ipop, n_pop_indiv, myopts.inbred, myopts.ncpu)
                    ngeno = 0
            ## anyone left ?
            if ngeno > 0:
                _accumulate_cluster_freqs(fastphase_model, results['fph.params'], pop_cluster_freq,
                                          ipop, n_pop_indiv, myopts.inbred, myopts.ncpu)
                ngeno = 0
        sys.stdout.write('\n')
        results.update({'cluster.freqs': pop_cluster_freq})

        ## Compute hapFLK
        counter.new('Computing hapFLK')
        myFLK = popgen.FLK_test(results["kinship"], diallelic=False)
        hapflk = np.zeros(dataset.nsnp, dtype=float)
        hapflk_eigen = np.zeros((results['eigval'].shape[0], dataset.nsnp), dtype=float)
        for e in range(myopts.nfit):
            for k in range(myopts.K):
                myFLK_res = np.apply_along_axis(myFLK.eval_flk, 0, results['cluster.freqs'][e, k, ])
                hapflk += myFLK_res[1, ]
                hapflk_eigen += np.power(myFLK_res[3:, ], 2)
        ## sum over clusters, average over fits
        hapflk /= myopts.nfit
        hapflk_eigen /= myopts.nfit
        results.update({'hapflk': hapflk, 'eigen.hapflk': hapflk_eigen})

    ## Write out results and bye
    counter.new("Writing down results")
    ##results.show_contents()
    results.write_allele_frequencies(myopts.prefix + '.frq')
    results.write_flk_results(myopts.prefix + '.flk')
    if myopts.reynolds:
        results.write_SNP_reynolds(myopts.prefix + '.rey')
    if myopts.eigen:
        results.write_eigen_decomposition(myopts.prefix + '.eig')
        results.write_eigen_flk(myopts.prefix + '.flk.eig')
    ## cluster frequencies only exist when the LD model was fitted (K>0)
    if myopts.kfrq and myopts.K > 0:
        try:
            results.write_cluster_frequencies(myopts.prefix + '.kfrq', outgroup=myopts.outgroup)
        except KeyError as err:
            ## stay non-fatal, but do not hide the problem
            sys.stderr.write('Warning: cluster frequencies not written (missing result %s)\n' % err)
    if myopts.K > 0:
        try:
            results.write_hapflk_results(myopts.prefix + '.hapflk')
            if myopts.eigen:
                results.write_eigen_hapflk(myopts.prefix + '.hapflk.eig')
        except KeyError as err:
            sys.stderr.write('Warning: hapFLK results not fully written (missing result %s)\n' % err)
    counter.end()
    
    
    
## Run the analysis only when this file is executed as a script (not on import).
if __name__=='__main__':
    main()
