#!/usr/bin/env python

"""
# ==============================================================================
# Author:       Carlos A. Ruiz Perez
# Email:        cruizperez3@gatech.edu
# Institution:  Georgia Institute of Technology
# Version:      1.0.0
# Date:         Nov 13, 2020

# Description: MicrobeAnnotator parses protein fasta files and annotates them
# using several databases in an iterative fashion and summarizes the findings
# using KEGG modules based on KO numbers associated with best database matches.
# ==============================================================================
"""

# ==============================================================================
# Import modules
# ==============================================================================
from microbeannotator.pipeline import identifier_conversion as convert
from microbeannotator.pipeline import protein_search as search
from microbeannotator.pipeline import sqlite3_search
from microbeannotator.pipeline import ko_mapper
from microbeannotator.pipeline import hmmsearch

from microbeannotator.utilities import core_utils
from microbeannotator.utilities import fasta_filter_list as filterfasta
from microbeannotator.utilities.logging import setup_logger
from microbeannotator import version

from functools import partial
from shutil import copyfile
from shutil import rmtree
from pathlib import Path
from shutil import which
from typing import List

import multiprocessing
import pickle
import attr
# ==============================================================================


# ==============================================================================
# Initialize logger
# ==============================================================================
logger = setup_logger(__name__)
# ==============================================================================


# ==============================================================================
# Define classes
# ==============================================================================
# Class to setup and store databases
@attr.s()
class RefdataInitializer:
    """Locate and store paths to the MicrobeAnnotator reference data.

    Given the root database directory, the search method used to build the
    protein databases, and the light/eukaryote flags, resolves the paths to
    the SQLite databases, the protein search databases, and the KOfam HMM
    models to use.
    """
    # Attributes initialized at instantiation
    database: Path = attr.ib(converter=Path)  # root folder with MicrobeAnnotator databases
    eukaryote: bool = attr.ib()  # use eukaryote (True) or prokaryote (False) KOfam models
    method: str = attr.ib()  # search method: 'blast', 'diamond', or 'sword'
    light: bool = attr.ib()  # light mode: skip trembl and refseq databases
    # Attributes initialized by get_databases()
    microbeannotator_db: Path = attr.ib(init=False, converter=Path)
    conversion_db: Path = attr.ib(init=False, converter=Path)
    ko_list_db: Path = attr.ib(init=False, converter=Path)
    ko_profiles_db: Path = attr.ib(init=False, converter=Path)
    ko_hmm_models: List[Path] = attr.ib(init=False)
    swissprot_db: Path = attr.ib(init=False)
    trembl_db: Path = attr.ib(init=False)
    refseq_db: Path = attr.ib(init=False)

    # Filename suffix of the protein databases produced by each search method.
    # Plain class attribute (not attr.ib), so attrs ignores it.
    _protein_db_suffix = {'blast': '', 'diamond': '.dmnd', 'sword': '.fasta'}

    def _read_model_list(self, list_filename: str) -> List[Path]:
        # One HMM model filename per line, relative to ko_profiles_db.
        with open(self.ko_profiles_db / list_filename, 'r') as infile:
            return [self.ko_profiles_db / line.strip() for line in infile]

    def get_databases(self) -> None:
        """Resolve and store all reference database paths.

        Exits the program (status 1) if ``self.method`` is not one of
        'blast', 'diamond', or 'sword'.
        """
        # Get SQLite databases and kofam data locations
        self.microbeannotator_db = self.database / 'microbeannotator.db'
        self.conversion_db = self.database / 'conversion.db'
        self.ko_list_db = self.database / 'kofam_data/ko_list'
        self.ko_profiles_db = self.database / 'kofam_data/profiles'
        # Get protein databases; only the filename suffix depends on method
        if self.method not in self._protein_db_suffix:
            logger.error(
                f"Search method {self.method} not recognized. "
                f"Please select one of blast, diamond, or sword.")
            exit(1)
        suffix = self._protein_db_suffix[self.method]
        self.swissprot_db = str(
            self.database / f'protein_db/uniprot_sprot{suffix}')
        if not self.light:
            self.trembl_db = str(
                self.database / f'protein_db/uniprot_trembl{suffix}')
            self.refseq_db = str(
                self.database / f'protein_db/refseq_protein{suffix}')
        # Get KOfam models: common + independent, plus the taxon-specific set
        self.ko_hmm_models = []
        self.ko_hmm_models.extend(self._read_model_list('common.list'))
        self.ko_hmm_models.extend(self._read_model_list('independent.list'))
        taxon_list = 'eukaryote.list' if self.eukaryote else 'prokaryote.list'
        self.ko_hmm_models.extend(self._read_model_list(taxon_list))
# ==============================================================================
        
        
        
        
        
        
        
"""---1.0 Main Function---"""

def main():
    import argparse, sys, textwrap
    # Setup parser for arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
            description='''MicrobeAnnotator parses protein fasta files and annotates them\n'''
            '''using several databases in an iterative fashion and summarizes the findings\n'''
            '''using KEGG modules based on KO numbers associated with best database matches.\n'''
            '''Usage: ''' + sys.argv[0] + ''' -i [protein file] -o [output folder] -d [MicrobeAnnotator db folder]\n'''
            '''-m [search method]\n'''
            '''Global mandatory parameters: -i [protein file] -o [output folder] -d [MicrobeAnnotator db folder]\n'''
            '''-m [search method]\n'''
            '''Optional Database Parameters: See ''' + sys.argv[0] + ' -h')
    general_options = parser.add_argument_group('Mandatory i/o options.')
    general_options.add_argument('-i', '--input', dest='input_list', action='store', required=False, nargs='+',
                        help='Space-separated list of protein files to parse. Use -i OR -l.')
    general_options.add_argument('-l', '--list', dest='file_list', action='store', required=False,
                        help='File with list of inputs. Use -i OR -l.')
    general_options.add_argument('-o', '--outdir', dest='outdir', action='store', required=True,
                        help='Directory to store results.')
    general_options.add_argument('-d', '--database', dest='database', action='store', required=True,
                        help='Directory where MicrobeAnnotator databases are located.')
    search_options = parser.add_argument_group('Options for search process.')
    search_options.add_argument('-m', '--method', dest='method', action='store', required=True,
                        help='Method used to create databases and to perform seaches. One of "blast", "diamond" or "sword".')
    search_options.add_argument('--kofam_bin', dest='kofam_bin', action='store', required=False, default=None,
                        help='Directory where KOFamscan binaries are located. By default assumes it is in PATH.')
    search_options.add_argument('--method_bin', dest='method_bin', action='store', required=False, default=None,
                        help='Directory where search method binaries (blast, diamond, sword) are located. By default assumes it is in PATH.')
    search_options.add_argument('--id_perc', dest='id_perc', action='store', required=False, default=40, type=int,
                        help='Minimum identity percentage to retain a hit. By default 40.')
    search_options.add_argument('--bitscore', dest='bitscore', action='store', required=False, default=80, type=int,
                        help='Minimum bitscore to retain a hit. By default 80.')
    search_options.add_argument('--evalue', dest='evalue', action='store', required=False, default=0.01, type=float,
                        help='Maximum evalue to retain a hit. By default 0.01.')
    search_options.add_argument('--aln_percent', dest='aln_percent', action='store', required=False, default=70, type=int,
                        help='Minimum percentage of query covered by hit alignment. By default 70.')
    plot_options = parser.add_argument_group('Summary and plotting options.')
    plot_options.add_argument('--cluster', dest='cluster', action='store', required=False,
                        help=textwrap.dedent('''
                        Cluster genomes and/or modules. Select "cols" for genomes, "rows" for modules, or "both".
                        By default, no clustering
                        '''))
    plot_options.add_argument('--filename', dest='plot_filename', action='store', required=False, default='metabolic_summary_',
                        help='Prefix for output summary tables and plots. By default "metabolic_summary"')
    misc_options = parser.add_argument_group('Miscellaneous options.')
    misc_options.add_argument('-t', '--threads', dest='threads', action='store', required=False, default=1, type=int,
                        help='Threads to use per processed file, i.e. (per protein file). By default 1.')
    misc_options.add_argument('-p', '--processes', dest='processes', action='store', required=False, default=1, type=int,
                        help=textwrap.dedent('''
                        Number of processes to launch, i.e. number of protein files to process simultaneously.
                        Note this is different from threads. For more information see the README. By default 1.
                        '''))
    misc_options.add_argument('--light', dest='light', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Use only KOfamscan and swissprot databases. By default also uses refseq and
                        trembl (only use if you built both using "microbeannotator_db_builder").
                        '''))
    misc_options.add_argument('--eukaryote', dest='eukaryote', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Use only KOfamscan eukaryote models").
                        '''))
    misc_options.add_argument('--full', dest='full', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Do not perform the iterative annotation but search all proteins against all databases
                        (Increases computation time).
                        '''))
    misc_options.add_argument('--continue_run', dest='continue_run', action='store_true', required=False,
                        help=textwrap.dedent('''
                        If something went wrong when runnin MicrobeAnnotator, try to resume from the last completed step.
                        '''))
    misc_options.add_argument('--refine', dest='refine', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Complement the annotations by finding links to identifiers in other databases.
                        '''))
    misc_options.add_argument('--version', action='version',  version='MicrobeAnnotator v{}'.format(version),
                        help=textwrap.dedent('''Shows MicrobeAnnotator version'''))
    args = parser.parse_args()

    input_list = args.input_list
    file_list = args.file_list
    outdir = args.outdir
    outdir = Path(outdir)
    database = args.database
    method = args.method
    method = method.lower()
    kofam_bin = args.kofam_bin
    method_bin = args.method_bin
    id_perc = args.id_perc
    bitscore = args.bitscore
    evalue = args.evalue
    aln_percent = args.aln_percent
    threads = args.threads
    processes = args.processes
    light = args.light
    cluster = args.cluster
    if cluster != None:
        cluster = cluster.lower()
    plot_filename = args.plot_filename
    full = args.full
    refine = args.refine
    continue_run = args.continue_run
    eukaryote = args.eukaryote

    # ==========================================================================
    # Welcome message
    # ==========================================================================
    logger.info(f" --- This is MicrobeAnnotator v{version} --- ")
    # ==========================================================================

    # ==========================================================================
    # Check user input
    # ==========================================================================
    logger.info("Validating user inputs")
    core_utils.input_validator(
        method, method_bin, input_list, file_list,
        cluster, processes, threads)
    logger.info("Passed")
    if input_list != None:
        logger.info(
            f"Processing {len(input_list)} files. I will run {processes} "
            f"files in parallel with {threads} threads per file.")
    elif file_list != None:
        number_files = 0
        with open(file_list, 'r') as infile:
            for line in infile:
                number_files += 1
        logger.info(
            f"Processing {number_files} files. I will run {processes} "
            f"files in parallel with {threads} threads per file.")
    if light:
        logger.info(
            f"Running in light mode. Searching only KOfam and Swissprot "
            f"using {method}.")
    # ==========================================================================

    # ==========================================================================
    # Initialize logfile, refdata and output folders
    # ==========================================================================
    # Create log folder to track process
    process_step = 0
    process_log_folder = outdir / "process_log"
    process_log_folder.mkdir(parents=True, exist_ok=True)
    kofam_outdir = outdir / 'kofam_results'
    kofam_outdir.mkdir(parents=True, exist_ok=True)
    swissprot_outdir = outdir / 'swissprot_results'
    swissprot_outdir.mkdir(parents=True, exist_ok=True)
    trembl_outdir = outdir / 'trembl_results'
    trembl_outdir.mkdir(parents=True, exist_ok=True)
    refseq_outdir = outdir / 'refseq_results'
    refseq_outdir.mkdir(parents=True, exist_ok=True)
    annotation_outdir = outdir / 'annotation_results'
    annotation_outdir.mkdir(parents=True, exist_ok=True)
    temp_prot_folder = outdir / "temporal_proteins"
    temp_prot_folder.mkdir(parents=True, exist_ok=True)
    # Initialize reference data class
    refdata = RefdataInitializer(database, eukaryote, method, light)
    refdata.get_databases()
    # Create dictionary with information per protein file
    protein_file_info = {}
    starting_proteins = {}
    unannotated_proteins = {}
    # ==========================================================================

    # ==========================================================================
    # Check if run is starting from previous step.
    # ==========================================================================
    if continue_run:
        # Import the last step completed and the dictionary with information
        logger.info("Restarting from last checkpoint")
        with open(process_log_folder / "log.txt", 'r') as logfile:
            process_step = int(logfile.readline().strip())
            logger.info(f"Step {process_step}")
        with open(process_log_folder / "structure.pickle", 'rb') \
            as structure_file:
            protein_file_info = pickle.load(structure_file)
    # ==========================================================================
    
    # ==========================================================================
    # Initialize protein_file_info, starting and unannotated proteins dicts
    # ==========================================================================
    if file_list != None:
        input_list = []
        with open(file_list) as file_list_handler:
            for line in file_list_handler:
                input_list.append(line.strip())

    if process_step == 0:
        for protein_file in input_list:
            starting_filename = str(Path(protein_file).name)
            final_annotation = annotation_outdir / (starting_filename+'.annot')
            # Write headers to final annotation file
            with open(final_annotation, 'w') as output:
                output.write(
                    f"query_id\tprotein_id\tproduct\tko_number\tko_product\t"
                    f"taxonomy\tfunction_go\tcompartment_go\tprocess_go\t"
                    f"interpro\tpfam\tec_number\tdatabase\n")
            # Add initial information to protein_file_info
            # Structure per file:
            # protein_file : [protein_file_name, final_annotation_file]
            protein_file_info[protein_file] = [
                starting_filename, final_annotation]
            # Set initial starting and unannotated proteins
            starting_proteins[starting_filename] = []
            unannotated_proteins[starting_filename] = []
            with open(Path(protein_file), 'r') as proteins:
                for line in proteins:
                    if line.startswith(">"):
                        line = line.strip().split()[0].replace(">", "")
                        starting_proteins[starting_filename].append(line)
                        unannotated_proteins[starting_filename].append(line)
        process_step += 1
    # ==========================================================================

    # ==========================================================================
    # Perform KOfam searches
    # ==========================================================================
    if process_step == 1:
        # Search the initial dataset with KOFamscan
        logger.info("Searching proteins against KOfam profiles")
        # Set hmm profile information to global and parse them
        for protein_file in input_list:
            starting_filename = str(Path(protein_file).name)
            # Map proteins to hmm models
            protein_to_model = hmmsearch.protein_to_model_mapper(
                [protein_file], refdata.ko_hmm_models)
            # Set output
            outfile = kofam_outdir / (Path(protein_file).name + '.kofam')
            # Perform hmmsearch and filter results
            hmmsearch_results, annotated_proteins, \
                hypothetical_proteins = hmmsearch.hmmsearch_launcher(
                    protein_model_list=protein_to_model,
                    ko_list_db=refdata.ko_list_db,
                    output_path=outfile,
                    processes=processes, threads=threads)
            # Write filtered results to final annotation
            hmmsearch.write_hmmsearch_annotation(
                hmmsearch_results, protein_file_info[protein_file][1])
            # Set filename for next Fasta file and add to protein_file_info
            next_fasta = temp_prot_folder / f"{starting_filename}.step_2"
            protein_file_info[protein_file].append(next_fasta)
            # Remove annotated and hypothetical from unannotated list
            unannotated_proteins[starting_filename] = list(
                set(unannotated_proteins[starting_filename])
                - set(annotated_proteins))
            unannotated_proteins[starting_filename] = list(
                set(unannotated_proteins[starting_filename])
                - set(hypothetical_proteins))
            if full:
                copyfile(protein_file, next_fasta)
            else:
                # Remove annotated proteins from starting protein list
                starting_proteins[starting_filename] = list(
                    set(starting_proteins[starting_filename])
                    - set(annotated_proteins))
                filterfasta.fastA_filter_list(
                    protein_file, next_fasta,
                    starting_proteins[starting_filename], reverse=True)
        # When the process is complete, write step completed (+1) in the log folder
        # Also, export dictionary with information to be imported in case of continue
        process_step += 1
        with open(process_log_folder / "log.txt", 'w') as logfile:
            logfile.write("{}".format(process_step))
        with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
            pickle.dump(protein_file_info, structure_file)
    # ==========================================================================

    # ==========================================================================
    # Perform SwissProt searches
    # ==========================================================================
    if process_step == 2:
        logger.info("Searching proteins against Swissprot")
        # Get location of this step fasta files (.step_2)
        input_proteins = []
        for infor in protein_file_info.values():
            input_proteins.append(infor[-1])
        # Perform searches
        try:
            pool = multiprocessing.Pool(processes)
            arguments_to_pass = (
                outdir, 'swissprot', refdata.swissprot_db, method,
                threads, id_perc, bitscore, evalue, aln_percent, method_bin)
            search_results = pool.map(partial(search.similarity_search,
            multiple_arguments=arguments_to_pass), input_proteins)
            # Results as (protein_file, filtered_search_file)
        finally:
            pool.close()
            pool.join()
        # Add name of filtered_search_results to protein_file_info
        # New structure: 
        # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File]
        temp_dir = protein_file_info.copy()
        for result in search_results:
            for filename, info in temp_dir.items():
                if result[0] == info[2]:
                    protein_file_info[filename].append(result[1])
        del temp_dir

        # Search annotations in SQLite DB and append to the final annotation file
        print("Extracting Swissprot annotation data...")
        for original_file, information in protein_file_info.items():
            filtered_sword_results = information[3]
            final_annotation_file = information[1]
            significant_hits = []
            with open(filtered_sword_results) as swissprot_results:
                for line in swissprot_results:
                    line = line.strip().split()
                    significant_hits.append((line[0], line[1]))
            if len(significant_hits) == 0:
                # If no hits were found
                if light == True:
                    # Write all the proteins without annotation
                    with open(final_annotation_file, 'a') as final_annotation_fh:
                        for original_protein in unannotated_proteins[str(Path(original_file).name)]:
                            final_annotation_fh.write("{}\tNA\tNo match found\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n".format(original_protein))
                else:
                    # Copy the first iteration proteins to the second iteration file
                    second_it_outfile = str(temp_prot_folder / (protein_file_info[original_file][0] + ".step_3"))
                    copyfile(protein_file_info[original_file][2],second_it_outfile)
                    protein_file_info[original_file].append(second_it_outfile)
            else:
                # Extract the annotations
                annotation = sqlite3_search.search_ids_imported(refdata.microbeannotator_db, "swissprot", significant_hits)
                # Each annotation will have:
                # query_id gene_id accession product ko_number organism taxonomy function_GO compartment_GO process_GO interpro_id pfam_id EC_number
                if light == True:
                    # Write the annotations found in the annotation file
                    with open(final_annotation_file, 'a') as final_annotation_fh:
                        for match in annotation:
                            final_annotation_fh.write("{}\t{}\t{}\t{}\tNA\t{}\t{}\t{}\t{}\t{}\t{}\t{}\tswissprot\n".format(match[0],
                            match[1], match[3], match[4], match[6], match[7], match[8], match[9], match[10], match[11], match[12]))
                            if match[4] != "" and match[4] != "NA":
                                starting_proteins[str(Path(original_file).name)].remove(match[0])
                                if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                    unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                            elif match[3] != "":
                                if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                    unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                    # Check which proteins were not annotated and add information on those
                    # Extract annotated proteins
                    with open(final_annotation_file, 'a') as final_annotation_fh:
                        for original_protein in unannotated_proteins[str(Path(original_file).name)]:
                            final_annotation_fh.write("{}\tNA\tNo match found\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n".format(original_protein))
                else:
                    # Write the annotations found in the annotation file
                    with open(final_annotation_file, 'a') as final_annotation_fh:
                        for match in annotation:
                            final_annotation_fh.write("{}\t{}\t{}\t{}\tNA\t{}\t{}\t{}\t{}\t{}\t{}\t{}\tswissprot\n".format(match[0],
                            match[1], match[3], match[4], match[6], match[7], match[8], match[9], match[10], match[11], match[12]))
                            if full == False:
                                if match[4] != "" and match[4] != "NA":
                                    # Remove record with KO from those passing to the next step and the final list of unannotated
                                    starting_proteins[str(Path(original_file).name)].remove(match[0])
                                    if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                        unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                                elif match[3] != "":
                                    # If match is annotated remove from final list
                                    if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                        unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                            else: # Remove annotated from the final unannotated list
                                if match[3] != "":
                                    if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                        unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                    # Check which proteins were not annotated and filter for the next iteration
                    second_it_outfile = str(temp_prot_folder / (protein_file_info[original_file][0] + ".step_3"))
                    if full == True:
                        copyfile(protein_file_info[original_file][2],second_it_outfile)
                        protein_file_info[original_file].append(second_it_outfile)
                    else:
                        protein_file_info[original_file].append(second_it_outfile)
                        to_retain = starting_proteins[str(Path(original_file).name)].copy()
                        filterfasta.fastA_filter_list(protein_file_info[original_file][2], 
                        second_it_outfile, to_retain, reverse=False)
                        # starting_proteins[str(Path(original_file).name)]
                    # New structure: 
                    # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File,
                    # filtered_fasta_2it]
        # When the process is complete, write step completed (+1) in the log folder
        # Also, export dictionary with information to be imported in case of continue
        process_step += 1
        with open(process_log_folder / "log.txt", 'w') as logfile:
            logfile.write("{}".format(process_step))
        with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
            pickle.dump(protein_file_info, structure_file)
        # -----------------------
    
    
    
    # If running complete pipeline search against refseq
    if process_step == 3:
        if light == True:
            print("The light run of MicrobeAnnotator had finished succesfully!")
            print("Summarizing results...")
            process_step = 6
        else:
            print("Searching proteins against RefSeq...")
            # If we have a second iteration file then perform the searches againts refseq
            input_proteins = []
            for infor in protein_file_info.values():
                if len(infor) == 5:
                    input_proteins.append(infor[-1])
            if len(input_proteins) > 0:
                try:
                    pool = multiprocessing.Pool(processes)
                    # Arguments shared by every parallel search worker: output dir,
                    # database tag/path, search method, and its filtering thresholds.
                    arguments_to_pass = (outdir, 'refseq', refdata.refseq_db, method,
                                        threads, id_perc, bitscore, evalue, aln_percent, method_bin)
                    # One similarity search per 2nd-iteration fasta, run in parallel.
                    search_results = pool.map(partial(search.similarity_search,
                    multiple_arguments=arguments_to_pass), input_proteins)
                    # Results as (protein_file, filtered_search_file)
                finally:
                    # Always release the worker pool, even if the search raised.
                    pool.close()
                    pool.join()
            # Add name of filtered_search_results to protein_file_info
            # New structure: 
            # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File,
            # filtered_fasta_2it, RefSeq_Search_File]
            # Iterate over a shallow copy because protein_file_info is appended to
            # while matching each search result (result[0]) to its source fasta (info[4]).
            temp_dir = protein_file_info.copy()
            for result in search_results:
                for filename, info in temp_dir.items():
                    if len(info) == 5 and result[0] == info[4]:
                        protein_file_info[filename].append(result[1])
            del temp_dir
            # --------------------------

            # Search annotations in SQLite DB and append to the final annotation file
            print("Extracting RefSeq annotation data...\n")
            for original_file, information in protein_file_info.items():
                # Only entries that received a RefSeq search result (6 fields) proceed.
                if len(information) == 6:
                    filtered_sword_results = information[5]
                    final_annotation_file = information[1]
                    significant_hits = []
                    # NOTE(review): the handle is named "swissprot_results" but this file
                    # holds the filtered RefSeq search output; columns 0/1 are the query
                    # and subject ids of each significant hit.
                    with open(filtered_sword_results) as swissprot_results:
                        for line in swissprot_results:
                            line = line.strip().split()
                            significant_hits.append((line[0], line[1]))
                    # If no hits were found
                    if len(significant_hits) == 0:
                        # Nothing matched RefSeq: forward the entire 2nd-iteration fasta
                        # unchanged as the 3rd-iteration (Trembl) input.
                        third_it_outfile = str(temp_prot_folder / (protein_file_info[original_file][0] + ".3it"))
                        copyfile(protein_file_info[original_file][4],third_it_outfile)
                        protein_file_info[original_file].append(third_it_outfile)
                    else:
                        annotation = sqlite3_search.search_ids_imported(refdata.microbeannotator_db, "refseq", significant_hits)
                        # Each annotation will have:
                        # query_id gene_id product taxonomy ko_number ko_product
                        # Write the annotations found in the annotation file

                        with open(final_annotation_file, 'a') as final_annotation_fh:
                            for match in annotation:
                                final_annotation_fh.write("{}\t{}\t{}\t{}\t{}\t{}\tNA\tNA\tNA\tNA\tNA\t{}\trefseq\n".format(match[0],
                                match[1], match[2], match[5], match[6], match[3], match[4]))
                                if full == False:
                                    # Default mode: a protein that gained a KO number is
                                    # finished — drop it from the next iteration's input
                                    # and from the unannotated list.
                                    if match[5] != "" and match[5] != "NA":
                                        # Remove record with KO from those passing to the next step and the final list of unannotated
                                        starting_proteins[str(Path(original_file).name)].remove(match[0])
                                        if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                            unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                                    elif match[2] != "":
                                        # If match is annotated remove from final list
                                        if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                            unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                                else: # Remove annotated from the final unannotated list
                                    if match[2] != "":
                                        if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                            unannotated_proteins[str(Path(original_file).name)].remove(match[0])

                        # Check which proteins were not annotated and filter for the next iteration
                        third_it_outfile = str(temp_prot_folder / (protein_file_info[original_file][0] + ".3it"))
                        if full == True:
                            # Full mode: every protein continues to the Trembl iteration.
                            copyfile(protein_file_info[original_file][4],third_it_outfile)
                            protein_file_info[original_file].append(third_it_outfile)
                        else:
                            # Check if all annotations have ko numbers, if not, get ids to next round of iteration
                            protein_file_info[original_file].append(third_it_outfile)
                            to_retain = starting_proteins[str(Path(original_file).name)].copy()
                            filterfasta.fastA_filter_list(protein_file_info[original_file][2], 
                            third_it_outfile, to_retain, reverse=False)
                            # New structure: 
                            # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File,
                            # filtered_fasta_2it, RefSeq_Search_File, filtered_fasta_3it]
            # When the process is complete, write step completed (+1) in the log folder
            # Also, export dictionary with information to be imported in case of continue  
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
                pickle.dump(protein_file_info, structure_file)
        # --------------------------


        # Finally run remaining proteins against Trembl
    # Step 4: search remaining unannotated proteins against Trembl
    # (skipped entirely in light mode, which stops after RefSeq).
    if process_step == 4:
        if light == True:
            # NOTE(review): the message contains typos ("had"/"succesfully");
            # left byte-identical here because this is a documentation-only pass.
            print("The light run of MicrobeAnnotator had finished succesfully!")
            print("Summarizing those results...")
            process_step = 5
        else:
            print("Searching proteins against Trembl...\n")
            # Collect the 3rd-iteration fastas (entries with 7 fields) that still
            # contain proteins without a KO annotation.
            input_proteins = []
            for infor in protein_file_info.values():
                if len(infor) == 7:
                    input_proteins.append(infor[-1])
            if len(input_proteins) > 0:
                try:
                    pool = multiprocessing.Pool(processes)
                    # Shared worker arguments: output dir, database tag/path,
                    # search method, and its filtering thresholds.
                    arguments_to_pass = (outdir, 'trembl', refdata.trembl_db, method,
                                        threads, id_perc, bitscore, evalue, aln_percent, method_bin)
                    search_results = pool.map(partial(search.similarity_search,
                    multiple_arguments=arguments_to_pass), input_proteins)
                    # Results as (protein_file, filtered_search_file)
                finally:
                    # Always release the worker pool, even if the search raised.
                    pool.close()
                    pool.join()
            # Add name of filtered_search_results to protein_file_info
            # New structure: 
            # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File,
            # filtered_fasta_2it, RefSeq_Search_File, filtered_fasta_3it, Trembl_Search_File]
            # Iterate over a shallow copy because protein_file_info is appended to
            # while matching each search result (result[0]) to its source fasta (info[6]).
            temp_dir = protein_file_info.copy()
            for result in search_results:
                for filename, info in temp_dir.items():
                    if len(info) == 7 and result[0] == info[6]:
                        protein_file_info[filename].append(result[1])
            del temp_dir
            # --------------------------
            # Search annotations in SQLite DB and append to the final annotation file
            print("Extracting Trembl annotation data...\n")
            for original_file, information in protein_file_info.items():
                # Only entries that received a Trembl search result (8 fields) proceed.
                if len(information) == 8:
                    filtered_sword_results = information[7]
                    final_annotation_file = information[1]
                    significant_hits = []
                    # NOTE(review): the handle is named "swissprot_results" but this file
                    # holds the filtered Trembl search output (query id, subject id).
                    with open(filtered_sword_results) as swissprot_results:
                        for line in swissprot_results:
                            line = line.strip().split()
                            significant_hits.append((line[0], line[1]))
                    if len(significant_hits) == 0:
                        # No Trembl hits at all: record every remaining unannotated
                        # protein as "No match found".
                        with open(final_annotation_file, 'a') as final_annotation_fh:
                            for original_protein in unannotated_proteins[str(Path(original_file).name)]:
                                final_annotation_fh.write("{}\tNA\tNo match found\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n".format(original_protein))
                    else:
                        annotation = sqlite3_search.search_ids_imported(refdata.microbeannotator_db, "trembl", significant_hits)
                        # Each annotation will have:
                        # query_id gene_id product taxonomy ko_number ko_product
                        with open(final_annotation_file, 'a') as final_annotation_fh:
                            for match in annotation:
                                final_annotation_fh.write("{}\t{}\t{}\t{}\tNA\t{}\t{}\t{}\t{}\t{}\t{}\t{}\ttrembl\n".format(match[0],
                                match[1], match[3], match[4], match[6], match[7], match[8], match[9], match[10], match[11], match[12]))
                                if match[4] != "" and match[4] != "NA":
                                    # Remove record with KO from those passing to the next step and the final list of unannotated
                                    starting_proteins[str(Path(original_file).name)].remove(match[0])
                                    if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                        unannotated_proteins[str(Path(original_file).name)].remove(match[0])
                                elif match[3] != "":
                                    # If match is annotated remove from final list
                                    if match[0] in unannotated_proteins[str(Path(original_file).name)]:
                                        unannotated_proteins[str(Path(original_file).name)].remove(match[0])

                        # Anything still unannotated after the last database gets an
                        # explicit "No match found" row in the final table.
                        with open(final_annotation_file, 'a') as final_annotation_fh:
                            for original_protein in unannotated_proteins[str(Path(original_file).name)]:
                                final_annotation_fh.write("{}\tNA\tNo match found\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n".format(original_protein))
            # All searches are done: remove the temporary filtered-fasta folder.
            if temp_prot_folder.is_dir():
                rmtree(temp_prot_folder)
            # When the process is complete, write step completed (+1) in the log folder
            # Also, export dictionary with information to be imported in case of continue  
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
                pickle.dump(protein_file_info, structure_file)
    # ------------------------

    # Step 5: refine annotations by cross-referencing identifiers (EC -> KO,
    # RefSeq -> Uniprot) between the databases already searched.
    if process_step == 5:
        if refine == False:
            # Refinement disabled: just record this step as completed and
            # checkpoint the file-info structure for possible resumption.
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
                pickle.dump(protein_file_info, structure_file)
        else:
            import pandas as pd
            print("Improving annotations by searching matches in other databases...")
            # Structure:
            # protein_file : [protein_file_name, final_annotation_file, filtered_fasta_1it, Swissprot_Search_File,
            # filtered_fasta_2it, RefSeq_Search_File, filtered_fasta_3it, Trembl_Search_File]
            final_annotation_files = []
            for information in protein_file_info.values():
                final_annotation_files.append(str(information[1]))

            # Process each annotation file separately
            for annotation_file in final_annotation_files:
                positive_matches = {}
                # keep_default_na=False preserves the literal "NA" strings written by
                # the earlier annotation steps; with pandas defaults "NA" is parsed as
                # NaN and every `== "NA"` / `!= "NA"` filter below selects nothing.
                annotation_table = pd.read_csv(
                    annotation_file, sep="\t", header=0, index_col=0, keep_default_na=False)
                # Parse Uniprot (Swissprot) records missing a KO number
                swissprot_records = annotation_table.loc[(annotation_table['database'] == "swissprot") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without KO numbers but with an EC number
                swissprot_records_ko = swissprot_records.loc[(swissprot_records['ko_number'] == "NA") & (swissprot_records['ec_number'] != "NA"), ]
                swissprot_records_ko_ids = list(swissprot_records_ko['ec_number'])
                if len(swissprot_records_ko_ids) > 0:
                    # Map the EC numbers to KO numbers via the conversion database.
                    positive_matches = convert.convert_ko_to_ec(swissprot_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            annotation_table.loc[(annotation_table['ec_number'] == identifier) & (annotation_table['ko_number'] == "NA"), "ko_number"] = " ".join(matches)
                # Delete used sub-tables
                del swissprot_records
                del swissprot_records_ko

                # Parse Trembl records missing a KO number
                trembl_records = annotation_table.loc[(annotation_table['database'] == "trembl") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without KO numbers but with an EC number
                trembl_records_ko = trembl_records.loc[(trembl_records['ko_number'] == "NA") & (trembl_records['ec_number'] != "NA"), ]
                trembl_records_ko_ids = list(trembl_records_ko['ec_number'])
                if len(trembl_records_ko_ids) > 0:
                    positive_matches = convert.convert_ko_to_ec(trembl_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            annotation_table.loc[(annotation_table['ec_number'] == identifier) & (annotation_table['ko_number'] == "NA"), "ko_number"] = " ".join(matches)
                # Delete used sub-tables
                del trembl_records
                del trembl_records_ko

                # Parse RefSeq records missing a KO number
                refseq_records = annotation_table.loc[(annotation_table.database == "refseq") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without KO numbers but with an EC number
                refseq_records_ko = refseq_records.loc[(refseq_records['ko_number'] == "NA") & (refseq_records['ec_number'] != "NA"), ]
                refseq_records_ko_ids = list(refseq_records_ko['ec_number'])
                if len(refseq_records_ko_ids) > 0:
                    positive_matches = convert.convert_ko_to_ec(refseq_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            # Restrict the update to rows still lacking a KO, consistent
                            # with the Swissprot/Trembl passes above (the original updated
                            # every row sharing the EC number, clobbering existing KOs).
                            annotation_table.loc[(annotation_table['ec_number'] == identifier) & (annotation_table['ko_number'] == "NA"), "ko_number"] = " ".join(matches)
                # Extract refseq ids for records that still have no KO
                refseq_records = annotation_table.loc[(annotation_table.database == "refseq") & (annotation_table['ko_number'] == "NA"), ]
                refseq_records_ids = list(refseq_records['protein_id'])
                if len(refseq_records_ids) > 0:
                    refseq_records_indices = refseq_records.index
                    query_refseq_pair = list(zip(refseq_records_indices, refseq_records_ids))
                    # Convert refseq ids to uniprot accessions
                    refseq_uniprot = {}
                    positive_matches = convert.convert_refseq_to_uniprot(refseq_records_ids, refdata.conversion_db, False)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            refseq_uniprot[identifier] = matches[0]
                        # Pair each table index with its mapped Uniprot accession.
                        # (Bug fix: the original `pair[refseq_records_indices]` indexes a
                        # 2-tuple with a pandas Index and raises TypeError; the table
                        # index is pair[0].)
                        query_uniprot_pair = []
                        for pair in query_refseq_pair:
                            if pair[1] in refseq_uniprot:
                                query_uniprot_pair.append((pair[0], refseq_uniprot[pair[1]]))
                        # Now extract the uniprot annotations using the identifiers found
                        # in the previous step
                        uniprot_annotations = sqlite3_search.search_ids_imported(refdata.microbeannotator_db, "trembl", query_uniprot_pair)
                        if len(uniprot_annotations) > 0:
                            # Parse results and replace the corresponding table rows.
                            # NOTE(review): assumes each annotation row (after removing
                            # the leading index) has one value per table column — confirm
                            # against sqlite3_search.search_ids_imported.
                            for annotation in uniprot_annotations:
                                original_index = annotation.pop(0)
                                annotation_table.loc[original_index] = annotation
                # Save the refined table back to its original location.
                # (Bug fix: the original called to_csv without a path, discarding the
                # refined table entirely, and passed index=0, which would have dropped
                # the query_id index column and shifted every downstream column lookup.)
                annotation_table.to_csv(annotation_file, sep="\t", header=True)
            # When the process is complete, write step completed (+1) in the log folder
            # Also, export dictionary with information to be imported in case of continue
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
                pickle.dump(protein_file_info, structure_file)
            print("Done!")
    # ------------------------

    # Step 6: extract KO numbers from each final annotation table and summarize
    # them into KEGG-module completeness tables and plots.
    if process_step == 6:
        print("Extracting ko numbers and summarizing results...")
        annotation_files = []
        for information in protein_file_info.values():
            # Write one ".ko" file per input genome, listing every KO found in
            # column 3 (ko_number) of its final annotation table.
            annotation_folder = information[1].parent
            ko_numbers = str(annotation_folder / (information[0] + ".ko"))
            with open(information[1], 'r') as annotations, open(ko_numbers, 'w') as ko_present:
                for line in annotations:
                    line = line.strip().split("\t")
                    # If there are multiple KOs
                    # NOTE(review): the header row is not skipped, so its literal
                    # "ko_number" label may also be written to the .ko file —
                    # confirm ko_mapper tolerates non-KO tokens.
                    if line[3] != "":
                        if len(line[3].split()) > 1:
                            # Space-separated multi-KO field: write one KO per line.
                            for element in line[3].split():
                                ko_present.write("{}\n".format(element))
                        else:
                            ko_present.write("{}\n".format(line[3]))
            annotation_files.append(ko_numbers)
        # Map the collected KOs onto KEGG modules and emit the summary
        # matrices, clustered output files, and barplots.
        prefix = str(Path(outdir) / plot_filename)
        regular_modules, bifurcation_modules, structural_modules,  \
        module_information, metabolism_matrix, module_group_matrix = ko_mapper.module_information_importer(annotation_files)
        metabolic_annotation = ko_mapper.global_mapper(regular_modules, bifurcation_modules, structural_modules, annotation_files)
        metabolism_matrix_dropped_relabel, module_colors = ko_mapper.create_output_files(metabolic_annotation, metabolism_matrix, module_information, cluster, prefix)
        ko_mapper.plot_function_barplots(module_colors, module_group_matrix, metabolism_matrix_dropped_relabel, prefix)
        print("MicrobeAnnotator has finished succefully!")
    # ------------------------

# Script entry point: run the full MicrobeAnnotator annotation pipeline.
if __name__ == "__main__":
    main()

