#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
        
               _ (.".) _    
              '-'/. .\'-'   
    pcalf      ( o o )     
                 `"-"`  jgs    


        pcalf stands for Python CALcyanin Finder.
        Calcyanin proteins will be searched for and annotated within your input file(s) in three steps:
            1) the glycine triplication specific to calcyanin is searched using a dedicated HMM profile.
            2) for sequences with a glycine triplication, each glycine zipper is annotated individually using specific HMM profiles.
            3) the N-ter extremity of sequences with a GlyX3 is annotated using blastp with known N-ter sequences as the subject database.
"""


import argparse
import gzip
import os
import sys
import shutil
import logging
import multiprocessing
from importlib import resources
import pandas as pd


import pcalf.datas
from pcalf.core import search,bioseq,biohmm,log


# Root logger defaults: INFO level and an explicitly empty handler list.
# main() attaches the console/file handlers once the CLI has been parsed.
logging.basicConfig(
    level=logging.INFO,
    handlers=[],
    format="%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s",
)

# Filesystem path of the pcalf.datas package; the default MSA and N-ter
# database paths offered by get_args() are built from it.
DATASDIR = os.fspath(resources.files(pcalf.datas))

def get_args(argv=None):
    """Build the pcalf command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the historical
            behaviour of this function.

    Returns:
        argparse.Namespace: the parsed options. ``input`` and ``res_dir``
        are mandatory; every other option has a default.

    Raises:
        SystemExit: raised by argparse on ``-h`` or on invalid/missing
            arguments.
    """
    parser = argparse.ArgumentParser(
        description="""pcalf
                        
    pcalf stand for Python CALcyanin Finder

    The input fasta file(s) will be scanned following three steps:
    - First a specific HMM profile of the glycine triplication is searched using HMMSEARCH and only sequence with a hit above coverage and E-value threshold are retained.
    - Secondly, those sequences are more precisely annotated using specifics HMM profiles for each glycine zipper in order to define their modular organization for their C-ter extremity.
    - Finally, a set of known N-ter is blasted against them and the nearest neighbor is retained as (potential) N-ter type.
    
    """,
        formatter_class=argparse.RawTextHelpFormatter
    )

    # --- Input / output -----------------------------------------------------
    # The input is mandatory: main() dereferences it unconditionally, so a
    # missing value must be reported as a usage error rather than crash later.
    parser.add_argument('-i', '--input', dest='input',
                        type=str, required=True,
                        help="Either a fasta file [gzip supported] or a tabular file with one file \"designation\tfile path\" per line.")

    parser.add_argument('-o', dest='res_dir', type=str, required=True,
                        help='Output directory. pcalf will ouput several files including updated HMMs and MSAs, a feature table and a summary.')

    # --- Iteration control --------------------------------------------------
    parser.add_argument('--max-iteration', type=int, default=1,
                        help="Number of iteration.")

    parser.add_argument('--iterative-update', action="store_true",
                        help="If set, HMM profiles will be updated by aligning one sequence by one sequence. ")

    # --- GlyX3 (glycine triplication) search --------------------------------
    parser.add_argument('--glyx3-msa', dest='glyx3_msa',
                        default=DATASDIR + "/GlyX3.msa.fa",
                        help='Path to GlyX3 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--domz', dest='domz', type=int, default=10000,
                        help='Sequence space size for hmmsearch (default: %(default)s).')

    parser.add_argument('--glyx3-coverage', dest='glyx3_coverage_threshold',
                        type=float, default=0.6,
                        help="Minimal coverage (default: %(default)s).")

    parser.add_argument('--glyx3-evalue', dest='glyx3_evalue_threshold',
                        type=float, default=1e-21,
                        help="Glyx3 E-value when considering Glyx3 Hits (default: %(default)s).")

    # --- Individual glycine zippers -----------------------------------------
    parser.add_argument('--gly1-msa', dest='gly1_msa',
                        default=DATASDIR + "/Gly1.msa.fa",
                        help='Path to GlyZip1 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--gly2-msa', dest='gly2_msa',
                        default=DATASDIR + "/Gly2.msa.fa",
                        help='path to GlyZip2 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--gly3-msa', dest='gly3_msa',
                        default=DATASDIR + "/Gly3.msa.fa",
                        help='path to GlyZip3 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--glyzip-E-value', dest='glyzip_e_evalue',
                        type=float, default=3.6e-4,
                        help="Glyzip E-evalue when considering Gly(1|2|3) Hits (default: %(default)s).")
    parser.add_argument('--glyzip-coverage', dest='glyzip_coverage',
                        type=float, default=0.7,
                        help="Glyzip coverage threshold when considering Gly(1|2|3) Hits (default: %(default)s).")

    # --- N-ter annotation ---------------------------------------------------
    parser.add_argument('--nterdb', dest='nterdb',
                        default=DATASDIR + "/nterdb.ref.tsv",
                        help='Path to nterdb tabular file (default: %(default)s). The file must have three column: N-ter types, N-ter accessions and N-ter amino acid sequences.')

    parser.add_argument('--nter-coverage', dest='nter_coverage',
                        type=float, default=0.80,
                        help="Nter minimal coverage (default: %(default)s).")

    parser.add_argument('--nter-evalue', dest='nter_evalue',
                        type=float, default=1e-07,
                        help="Nter maximal E-value threshold (default: %(default)s).")

    # --- Logging and runtime ------------------------------------------------
    parser.add_argument('--log', default=None, type=str)

    parser.add_argument('-q', '--quiet', action="store_true", help="Silent stdout logging")

    parser.add_argument('--debug', action="store_true")

    parser.add_argument('--force', action="store_true", help="overwrite output directory")

    parser.add_argument('--threads', type=int, default=multiprocessing.cpu_count(),
                        help="(default: %(default)s)")

    return parser.parse_args(argv)


def _configure_logging(args):
    """Attach the console and/or file handlers requested on the command line.

    The root logger's own level is set together with the handlers' level:
    a handler only receives records that the logger lets through, so leaving
    the root logger at INFO would silently discard every DEBUG record even
    when ``--debug`` is given.
    """
    level = logging.DEBUG if args.debug else logging.INFO
    root = logging.getLogger('')
    root.setLevel(level)

    if not args.quiet:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(log.CustomFormatter())
        root.addHandler(console)

    if args.log:
        # Create the parent directory of the log file if it does not exist.
        os.makedirs(os.path.dirname(os.path.abspath(args.log)), exist_ok=True)
        fhandler = logging.FileHandler(args.log)
        fhandler.setLevel(level)
        fhandler.setFormatter(log.CustomFormatter())
        root.addHandler(fhandler)


def _collect_inputs(input_path):
    """Resolve the ``--input`` value into fasta paths and their designations.

    A plain-text file is read line by line: every non-blank line that does
    not start with ``>`` is expected to hold two whitespace-separated fields
    (designation, fasta path); the first line starting with ``>`` means the
    input itself is a fasta file, which is then registered under its base
    name and reading stops. A ``.gz`` input must be a fasta file.

    Returns:
        tuple[list[str], list[str]]: ``(fastas, names)``, index-aligned.

    Exits the program (status -1) when the file cannot be read, when a
    gzipped input is not a fasta file, or when a listing line is malformed.
    """
    fastas = []
    names = []
    is_gzipped = input_path.endswith(".gz")
    opener = gzip.open if is_gzipped else open
    try:
        with opener(input_path, "rt") as stream:
            # Iterate lazily: a fasta input is recognised from its first
            # non-blank line, there is no need to load the whole file.
            for lineno, line in enumerate(stream, start=1):
                if not line.strip():
                    continue
                if line.startswith(">"):
                    fastas.append(input_path)
                    names.append(os.path.basename(input_path))
                    break
                if is_gzipped:
                    logging.error("{} is not a fasta file.".format(input_path))
                    sys.exit(-1)
                fields = line.split()
                if len(fields) != 2:
                    logging.error(
                        "{} line {}: expected 'designation<TAB>file path', got {!r}.".format(
                            input_path, lineno, line.rstrip("\n")
                        )
                    )
                    sys.exit(-1)
                designation, fasta_path = fields
                fastas.append(fasta_path)
                names.append(designation)
    except OSError as err:
        logging.error("Unable to read {} : {}".format(input_path, err))
        sys.exit(-1)
    return fastas, names


def _dump_profiles(res_dir, profiles):
    """Write the HMM and the MSA of every ``(label, profile)`` pair under *res_dir*.

    HMMs go to ``<res_dir>/HMM/<label>.hmm`` and MSAs to
    ``<res_dir>/MSA/<label>.msa.fa``. Every file handle is closed as soon as
    the corresponding object has been written.
    """
    logging.info("Dumping HMMs.")
    hmmdir = os.path.join(res_dir, "HMM")
    os.makedirs(hmmdir, exist_ok=True)
    for label, profile in profiles:
        with open(os.path.join(hmmdir, label + ".hmm"), "wb") as handle:
            profile.hmm.write(handle)

    logging.info("Dumping MSAs.")
    msadir = os.path.join(res_dir, "MSA")
    os.makedirs(msadir, exist_ok=True)
    for label, profile in profiles:
        with open(os.path.join(msadir, label + ".msa.fa"), "wb") as handle:
            profile.msa.write(handle, format="afa")


def _build_summary(features, calseq, seq_per_iteration):
    """Collapse the feature table into one summary row per sequence.

    Returns a DataFrame indexed by ``sequence_accession``; it has the same
    columns whether or not any sequence was found.
    """
    columns = [
        "sequence_accession", "sequence_src", "flag", "nter",
        "nter_neighbor", "cter", "sequence", "iteration",
    ]
    records = []
    for sequence, seq_features_df in features.groupby("sequence_id"):
        # C-ter organisation: glycine zippers ordered by start position.
        zippers = seq_features_df[seq_features_df.feature_id.isin(["Gly1", "Gly2", "Gly3"])]
        cterom = ",".join(list(zippers.sort_values("feature_start").feature_id))

        nter_type = None
        nter_neighbor = None
        nter_rows = seq_features_df[seq_features_df.feature_id == "N-ter"]
        if not nter_rows.empty:
            # feature_src of an N-ter feature is split on "||" into
            # (N-ter type, nearest neighbor).
            nter_type, nter_neighbor = "".join(set(nter_rows.feature_src)).split("||")

        records.append(
            {
                "sequence_accession": sequence,
                "sequence_src": "".join(set(seq_features_df.sequence_src.unique())),
                "flag": search.decision_tree(nter_type, cterom),
                "nter": nter_type,
                "nter_neighbor": nter_neighbor,
                "cter": cterom,
                "sequence": str(calseq.get_seq_by_id(sequence).seq),
                "iteration": seq_per_iteration[sequence],
            }
        )
    return pd.DataFrame(records, columns=columns).set_index("sequence_accession")


def main():
    """Command-line entry point of pcalf.

    Parses the arguments, configures logging, checks that ``blastp`` is
    available, resolves the input file(s), runs ``search.search`` and writes
    the updated HMMs/MSAs, the N-ter table, the feature table and the
    per-sequence summary into the output directory.

    Exits with status -1 on a missing ``blastp``, an unusable input or an
    already existing output directory (unless ``--force`` is given).
    """
    args = get_args()
    _configure_logging(args)

    blastp = shutil.which("blastp")
    if not blastp:
        logging.error("blast not found, please, considere installing it using conda install -c bioconda blast.")
        sys.exit(-1)
    logging.debug("blastp found : {}".format(blastp))

    # Validate the input before touching the output directory so that a bad
    # input cannot cost the user an existing result directory (--force).
    if not args.input:
        logging.error("No input provided, please use -i/--input.")
        sys.exit(-1)
    fastas, names = _collect_inputs(args.input)
    if not fastas:
        logging.warning("No fasta file found in {}.".format(args.input))

    res_dir = os.path.abspath(args.res_dir)
    if os.path.exists(res_dir):
        if not args.force:
            logging.error("Output directory already exist , Bye !")
            sys.exit(-1)
        logging.debug("{} directory have been removed because of the --force flag.".format(res_dir))
        shutil.rmtree(res_dir)

    logging.info("Init HMMs from MSAs.")
    glyx3 = biohmm.Hmm("Glyx3", args.glyx3_msa)
    gly1 = biohmm.Hmm("Gly1", args.gly1_msa)
    gly2 = biohmm.Hmm("Gly2", args.gly2_msa)
    gly3 = biohmm.Hmm("Gly3", args.gly3_msa)

    logging.info("Increase Glycine weight.")
    glyx3.hmm = search.increase_glycine_weight(glyx3, 0.2)
    gly1.hmm = search.increase_glycine_weight(gly1, 0.2)
    gly2.hmm = search.increase_glycine_weight(gly2, 0.2)
    gly3.hmm = search.increase_glycine_weight(gly3, 0.2)

    logging.info("Init N-Ter DB.")
    nterdb = search.parse_nterdb(args.nterdb)

    # Initial number of sequences per profile, used to report the gain below.
    _glyx3_nseq = glyx3.hmm.nseq
    _gly1_nseq = gly1.hmm.nseq
    _gly2_nseq = gly2.hmm.nseq
    _gly3_nseq = gly3.hmm.nseq

    # NOTE(review): args.threads is parsed by get_args() but is not forwarded
    # to search.search() here - confirm whether it should be.
    logging.info("Start search.")
    calseq, u_glyx3, u_gly1, u_gly2, u_gly3, u_nter, seq_per_iteration = search.search(
        res_dir,
        fastas,
        names,
        glyx3,
        gly1,
        gly2,
        gly3,
        nterdb,
        domz=args.domz,
        glyx3_evalue_threshold=args.glyx3_evalue_threshold,
        glyx3_coverage_threshold=args.glyx3_coverage_threshold,
        glyzip_evalue_threshold=args.glyzip_e_evalue,
        glyzip_coverage_threshold=args.glyzip_coverage,
        nter_coverage_threshold=args.nter_coverage,
        nter_evalue_threshold=args.nter_evalue,
        max_iteration=args.max_iteration,
        is_update_iterative=args.iterative_update,
    )

    logging.info("The search is over.")

    logging.info("# Number of calcyanin detected : {}".format(len(calseq.sequences)))
    logging.info("# N_seqs within Glyx3 HMM : {} [+{}]".format(u_glyx3.hmm.nseq, u_glyx3.hmm.nseq - _glyx3_nseq))
    logging.info("# N_seqs within Gly1 HMM : {} [+{}]".format(u_gly1.hmm.nseq, u_gly1.hmm.nseq - _gly1_nseq))
    logging.info("# N_seqs within Gly2 HMM : {} [+{}]".format(u_gly2.hmm.nseq, u_gly2.hmm.nseq - _gly2_nseq))
    logging.info("# N_seqs within Gly3 HMM : {} [+{}]".format(u_gly3.hmm.nseq, u_gly3.hmm.nseq - _gly3_nseq))

    _dump_profiles(
        res_dir,
        [("Glyx3", u_glyx3), ("Gly1", u_gly1), ("Gly2", u_gly2), ("Gly3", u_gly3)],
    )

    logging.info("Dumping N-ter table.")
    with open(os.path.join(res_dir, "N-ter-DB.tsv"), "w") as nterout:
        for nterid, n in u_nter.items():
            nterout.write("{}\t{}\t{}\n".format(n[0], nterid, n[1]))

    logging.info("Dumping feature table.")
    features = calseq.to_feature_table()
    features.to_csv(os.path.join(res_dir, "pcalf.features.tsv"), sep="\t", header=True, index=True)

    logging.info("Make summary.")
    df = _build_summary(features, calseq, seq_per_iteration)
    df.to_csv(os.path.join(res_dir, "pcalf.summary.tsv"), sep="\t", index=True, header=True)
if __name__ == "__main__":
    # Script entry point: run the CLI and use main()'s return value as the
    # process exit status.
    raise SystemExit(main())
