#!/usr/bin/env python

"""
# ==============================================================================
# Author:       Carlos A. Ruiz Perez
# Email:        cruizperez3@gatech.edu
# Institution:  Georgia Institute of Technology
# Version:      1.0.0
# Date:         Nov 13, 2020

# Description: MicrobeAnnotator parses protein fasta files and annotates them
# using several databases in an iterative fashion and summarizes the findings
# using KEGG modules based on KO numbers associated with best database matches.
# ==============================================================================
"""

# ==============================================================================
# Import modules
# ==============================================================================
from microbeannotator.pipeline import identifier_conversion as convert
from microbeannotator.pipeline import protein_search as search
from microbeannotator.pipeline import sqlite3_search
from microbeannotator.pipeline import ko_mapper
from microbeannotator.pipeline import hmmsearch

from microbeannotator.utilities import core_utils
from microbeannotator.utilities import fasta_filter_list as filterfasta
from microbeannotator.utilities.logging import setup_logger
from microbeannotator import version

from functools import partial
from shutil import copyfile
from shutil import rmtree
from pathlib import Path
from shutil import which
from typing import List

import multiprocessing
import pickle
import attr
# ==============================================================================


# ==============================================================================
# Initialize logger
# ==============================================================================
logger = setup_logger(__name__)
# ==============================================================================


# ==============================================================================
# Define classes
# ==============================================================================
# Class to setup and store databases
@attr.s()
class RefdataInitializer:
    """Resolve and store the locations of the reference data files.

    The instance is created with the database folder and the run options;
    calling :meth:`get_databases` then fills in every path attribute that
    is declared with ``init=False`` below.

    Attributes set at instantiation:
        database: Folder holding the reference data (converted to ``Path``).
        eukaryote: If true, the ``eukaryote.list`` KOfam model subset is
            loaded; otherwise ``prokaryote.list`` is loaded.
        method: Search method name; one of ``blast``, ``diamond``, ``sword``.
        light: If true, only ``swissprot_db`` is set among the protein
            databases; ``trembl_db`` and ``refseq_db`` are left unset.
    """
    # Attributes initialized at instantiation
    database: Path = attr.ib(converter=Path)
    eukaryote: bool = attr.ib()
    method: str = attr.ib()
    light: bool = attr.ib()
    # Attributes initialized by methods
    microbeannotator_db: Path = attr.ib(init=False, converter=Path)
    conversion_db: Path = attr.ib(init=False, converter=Path)
    ko_list_db: Path = attr.ib(init=False, converter=Path)
    ko_profiles_db: Path = attr.ib(init=False, converter=Path)
    ko_hmm_models: List[Path] = attr.ib(init=False)
    # The protein database locations are stored as plain strings
    # (get_databases wraps each of them in str()).
    swissprot_db: str = attr.ib(init=False)
    trembl_db: str = attr.ib(init=False)
    refseq_db: str = attr.ib(init=False)

    def _read_model_list(self, list_name: str) -> List[Path]:
        """Return the model paths named in ``ko_profiles_db / list_name``.

        Each non-empty line of the list file is stripped and joined to
        ``ko_profiles_db``. Blank lines are skipped because joining an empty
        name would yield the profiles folder itself instead of a model file.
        """
        with open(self.ko_profiles_db / list_name, 'r') as infile:
            return [
                self.ko_profiles_db / model_name
                for model_name in map(str.strip, infile)
                if model_name]

    def get_databases(self) -> None:
        """Populate all database path attributes from ``self.database``.

        Sets the SQLite and KOfam locations, then the protein database
        locations for the selected search method, and finally builds
        ``ko_hmm_models`` from the ``common``, ``independent`` and
        eukaryote/prokaryote list files found in ``ko_profiles_db``.

        Raises:
            SystemExit: With status 1 (after logging an error) when
                ``self.method`` is not blast, diamond or sword.
            OSError: If one of the model list files cannot be opened.
        """
        # File-name suffix appended to each protein database per method
        method_suffix = {'blast': '', 'diamond': '.dmnd', 'sword': '.fasta'}
        # Get SQLite databases and kofam data locations
        self.microbeannotator_db = self.database / 'microbeannotator.db'
        self.conversion_db = self.database / 'conversion.db'
        self.ko_list_db = self.database / 'kofam_data/ko_list'
        self.ko_profiles_db = self.database / 'kofam_data/profiles'
        # Get protein databases
        if self.method not in method_suffix:
            logger.error(
                f"Search method {self.method} not recognized. "
                f"Please select one of blast, diamond, or sword.")
            # SystemExit is raised directly: the exit() helper is injected by
            # the site module and is not guaranteed to exist in every runtime.
            raise SystemExit(1)
        suffix = method_suffix[self.method]
        protein_dir = self.database / 'protein_db'
        self.swissprot_db = str(protein_dir / f"uniprot_sprot{suffix}")
        if not self.light:
            self.trembl_db = str(protein_dir / f"uniprot_trembl{suffix}")
            self.refseq_db = str(protein_dir / f"refseq_protein{suffix}")
        # Get KOfam models based on subset selected
        domain_list = 'eukaryote.list' if self.eukaryote else 'prokaryote.list'
        self.ko_hmm_models = []
        for list_name in ('common.list', 'independent.list', domain_list):
            self.ko_hmm_models.extend(self._read_model_list(list_name))
# ==============================================================================
        
        
        
        
        
        
        
"""---1.0 Main Function---"""

def main():
    import argparse, sys, textwrap
    # Setup parser for arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
            description='''MicrobeAnnotator parses protein fasta files and annotates them\n'''
            '''using several databases in an iterative fashion and summarizes the findings\n'''
            '''using KEGG modules based on KO numbers associated with best database matches.\n'''
            '''Usage: ''' + sys.argv[0] + ''' -i [protein file] -o [output folder] -d [MicrobeAnnotator db folder]\n'''
            '''-m [search method]\n'''
            '''Global mandatory parameters: -i [protein file] -o [output folder] -d [MicrobeAnnotator db folder]\n'''
            '''-m [search method]\n'''
            '''Optional Database Parameters: See ''' + sys.argv[0] + ' -h')
    general_options = parser.add_argument_group('Mandatory i/o options.')
    general_options.add_argument('-i', '--input', dest='input_list', action='store', required=False, nargs='+',
                        help='Space-separated list of protein files to parse. Use -i OR -l.')
    general_options.add_argument('-l', '--list', dest='file_list', action='store', required=False,
                        help='File with list of inputs. Use -i OR -l.')
    general_options.add_argument('-o', '--outdir', dest='outdir', action='store', required=True,
                        help='Directory to store results.')
    general_options.add_argument('-d', '--database', dest='database', action='store', required=True,
                        help='Directory where MicrobeAnnotator databases are located.')
    search_options = parser.add_argument_group('Options for search process.')
    search_options.add_argument('-m', '--method', dest='method', action='store', required=True,
                        help='Method used to create databases and to perform seaches. One of "blast", "diamond" or "sword".')
    search_options.add_argument('--method_bin', dest='method_bin', action='store', required=False, default=None,
                        help='Directory where search method binaries (blast, diamond, sword) are located. By default assumes it is in PATH.')
    search_options.add_argument('--id_perc', dest='id_perc', action='store', required=False, default=40, type=int,
                        help='Minimum identity percentage to retain a hit. By default 40.')
    search_options.add_argument('--bitscore', dest='bitscore', action='store', required=False, default=80, type=int,
                        help='Minimum bitscore to retain a hit. By default 80.')
    search_options.add_argument('--evalue', dest='evalue', action='store', required=False, default=0.01, type=float,
                        help='Maximum evalue to retain a hit. By default 0.01.')
    search_options.add_argument('--aln_percent', dest='aln_percent', action='store', required=False, default=70, type=int,
                        help='Minimum percentage of query covered by hit alignment. By default 70.')
    plot_options = parser.add_argument_group('Summary and plotting options.')
    plot_options.add_argument('--cluster', dest='cluster', action='store', required=False,
                        help=textwrap.dedent('''
                        Cluster genomes and/or modules. Select "cols" for genomes, "rows" for modules, or "both".
                        By default, no clustering
                        '''))
    plot_options.add_argument('--filename', dest='plot_filename', action='store', required=False, default='metabolic_summary_',
                        help='Prefix for output summary tables and plots. By default "metabolic_summary"')
    misc_options = parser.add_argument_group('Miscellaneous options.')
    misc_options.add_argument('-t', '--threads', dest='threads', action='store', required=False, default=1, type=int,
                        help='Threads to use per processed file, i.e. (per protein file). By default 1.')
    misc_options.add_argument('-p', '--processes', dest='processes', action='store', required=False, default=1, type=int,
                        help=textwrap.dedent('''
                        Number of processes to launch, i.e. number of protein files to process simultaneously.
                        Note this is different from threads. For more information see the README. By default 1.
                        '''))
    misc_options.add_argument('--light', dest='light', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Use only KOfamscan and swissprot databases. By default also uses refseq and
                        trembl (only use if you built both using "microbeannotator_db_builder").
                        '''))
    misc_options.add_argument('--eukaryote', dest='eukaryote', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Use only KOfamscan eukaryote models").
                        '''))
    misc_options.add_argument('--full', dest='full', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Do not perform the iterative annotation but search all proteins against all databases
                        (Increases computation time).
                        '''))
    misc_options.add_argument('--continue_run', dest='continue_run', action='store_true', required=False,
                        help=textwrap.dedent('''
                        If something went wrong when runnin MicrobeAnnotator, try to resume from the last completed step.
                        '''))
    misc_options.add_argument('--refine', dest='refine', action='store_true', required=False,
                        help=textwrap.dedent('''
                        Complement the annotations by finding links to identifiers in other databases.
                        '''))
    misc_options.add_argument('--version', action='version',  version='MicrobeAnnotator v{}'.format(version),
                        help=textwrap.dedent('''Shows MicrobeAnnotator version'''))
    args = parser.parse_args()

    input_list = args.input_list
    file_list = args.file_list
    outdir = args.outdir
    outdir = Path(outdir)
    database = args.database
    method = args.method
    method = method.lower()
    method_bin = args.method_bin
    id_perc = args.id_perc
    bitscore = args.bitscore
    evalue = args.evalue
    aln_percent = args.aln_percent
    threads = args.threads
    processes = args.processes
    light = args.light
    cluster = args.cluster
    if cluster != None:
        cluster = cluster.lower()
    plot_filename = args.plot_filename
    full = args.full
    refine = args.refine
    continue_run = args.continue_run
    eukaryote = args.eukaryote

    # ==========================================================================
    # Welcome message
    # ==========================================================================
    logger.info(f" ---- This is MicrobeAnnotator v{version} ---- ")
    # ==========================================================================

    # ==========================================================================
    # Check user input
    # ==========================================================================
    logger.info("Validating user inputs")
    core_utils.input_validator(
        method, method_bin, input_list, file_list,
        cluster, processes, threads)
    logger.info("Passed")
    if input_list != None:
        logger.info(
            f"Processing {len(input_list)} files. I will run {processes} "
            f"files in parallel with {threads} threads per file.")
    elif file_list != None:
        number_files = 0
        with open(file_list, 'r') as infile:
            for line in infile:
                number_files += 1
        logger.info(
            f"Processing {number_files} files. I will run {processes} "
            f"files in parallel with {threads} threads per file.")
    if light:
        logger.info(
            f"Running in light mode. Searching only KOfam and Swissprot "
            f"using {method}.")
    # ==========================================================================

    # ==========================================================================
    # Initialize logfile, refdata and output folders
    # ==========================================================================
    # Create log folder to track process
    process_step = 0
    process_log_folder = outdir / "process_log"
    process_log_folder.mkdir(parents=True, exist_ok=True)
    kofam_outdir = outdir / 'kofam_results'
    kofam_outdir.mkdir(parents=True, exist_ok=True)
    swissprot_outdir = outdir / 'swissprot_results'
    swissprot_outdir.mkdir(parents=True, exist_ok=True)
    trembl_outdir = outdir / 'trembl_results'
    trembl_outdir.mkdir(parents=True, exist_ok=True)
    refseq_outdir = outdir / 'refseq_results'
    refseq_outdir.mkdir(parents=True, exist_ok=True)
    annotation_outdir = outdir / 'annotation_results'
    annotation_outdir.mkdir(parents=True, exist_ok=True)
    temp_prot_folder = outdir / "temporal_proteins"
    temp_prot_folder.mkdir(parents=True, exist_ok=True)
    # Initialize reference data class
    refdata = RefdataInitializer(database, eukaryote, method, light)
    refdata.get_databases()
    # Create dictionary with information per protein file
    protein_file_info = {}
    starting_proteins = {}
    unannotated_proteins = {}
    # ==========================================================================

    # ==========================================================================
    # Check if run is starting from previous step.
    # ==========================================================================
    if continue_run:
        # Import the last step completed and the dictionary with information
        logger.info("Restarting from last checkpoint")
        with open(process_log_folder / "log.txt", 'r') as logfile:
            process_step = int(logfile.readline().strip())
            logger.info(f"Step {process_step}")
        with open(process_log_folder / "structure.pickle", 'rb') \
            as structure_file:
            protein_file_info = pickle.load(structure_file)
    # ==========================================================================
    
    # ==========================================================================
    # Initialize protein_file_info, starting and unannotated proteins dicts
    # ==========================================================================
    if file_list != None:
        input_list = []
        with open(file_list) as file_list_handler:
            for line in file_list_handler:
                input_list.append(line.strip())

    if process_step == 0:
        for protein_file in input_list:
            starting_filename = str(Path(protein_file).name)
            final_annotation = annotation_outdir / (starting_filename+'.annot')
            # Write headers to final annotation file
            with open(final_annotation, 'w') as output:
                output.write(
                    f"query_id\tprotein_id\tproduct\tko_number\tko_product\t"
                    f"taxonomy\tfunction_go\tcompartment_go\tprocess_go\t"
                    f"interpro\tpfam\tec_number\tdatabase\n")
            # Add initial information to protein_file_info
            # Structure per file:
            # protein_file : [protein_file_name, final_annotation_file]
            protein_file_info[protein_file] = [
                starting_filename, final_annotation]
            # Set initial starting and unannotated proteins
            starting_proteins[starting_filename] = []
            unannotated_proteins[starting_filename] = []
            with open(Path(protein_file), 'r') as proteins:
                for line in proteins:
                    if line.startswith(">"):
                        line = line.strip().split()[0].replace(">", "")
                        starting_proteins[starting_filename].append(line)
                        unannotated_proteins[starting_filename].append(line)
        process_step += 1
    # ==========================================================================

    # ==========================================================================
    # Perform KOfam searches
    # ==========================================================================
    if process_step == 1:
        # Search the initial dataset with KOFamscan
        logger.info("Searching proteins against KOfam profiles")
        # Set hmm profile information to global and parse them
        for protein_file in input_list:
            starting_filename = str(Path(protein_file).name)
            # Map proteins to hmm models
            protein_to_model = hmmsearch.protein_to_model_mapper(
                [protein_file], refdata.ko_hmm_models)
            # Set output and append to structure dictionary
            outfile = kofam_outdir / (Path(protein_file).name + '.kofam')
            # Structure per file:
            # protein_file : [protein_file_name, final_annotation_file,
            # kofam_results]
            protein_file_info[protein_file].append(outfile)
            # Perform hmmsearch and filter results
            hmmsearch_results, annotated_proteins, \
                hypothetical_proteins = hmmsearch.hmmsearch_launcher(
                    protein_model_list=protein_to_model,
                    ko_list_db=refdata.ko_list_db,
                    output_path=outfile,
                    processes=processes, threads=threads)
            # Write filtered results to final annotation
            hmmsearch.write_hmmsearch_annotation(
                hmmsearch_results, protein_file_info[protein_file][1])
            # Set filename for next Fasta file and add to protein_file_info
            next_fasta = temp_prot_folder / f"{starting_filename}.step_2"
            # Remove annotated and hypothetical from unannotated list
            unannotated_proteins[starting_filename] = list(
                set(unannotated_proteins[starting_filename])
                - set(annotated_proteins))
            unannotated_proteins[starting_filename] = list(
                set(unannotated_proteins[starting_filename])
                - set(hypothetical_proteins))
            if full:
                copyfile(protein_file, next_fasta)
            else:
                # Remove annotated proteins from starting protein list
                starting_proteins[starting_filename] = list(
                    set(starting_proteins[starting_filename])
                    - set(annotated_proteins))
                # If there are proteins without KO in starting_proteins
                if len(starting_proteins[starting_filename]) > 0:
                    # Filter fasta file and store into next file
                    to_retain = starting_proteins[starting_filename].copy()
                    filterfasta.fastA_filter_list(
                        protein_file, next_fasta,
                        to_retain, reverse=False)
                    # Structure per file:
                # protein_file : [protein_file_name, final_annotation_file,
                # kofam_results, fasta_step_2]
                protein_file_info[protein_file].append(next_fasta)
            
            
        # When the process is complete, write step completed (+1)
        # in the log folder. Also, export dictionary with information
        # to be imported in case of continue
        process_step += 1
        with open(process_log_folder / "log.txt", 'w') as logfile:
            logfile.write(f"{process_step}")
        with open(process_log_folder / "structure.pickle", 'wb') \
            as structure_file:
            pickle.dump(protein_file_info, structure_file)
    # ==========================================================================

    # ==========================================================================
    # Perform SwissProt searches
    # ==========================================================================
    if process_step == 2:
        logger.info("Searching proteins against Swissprot")
        # Get location of this step fasta files (.step_2)
        input_proteins = []
        for information in protein_file_info.values():
            if len(information) == 4:
                input_proteins.append(information[-1])
        if len(input_proteins) > 0:
            # Perform searches
            try:
                pool = multiprocessing.Pool(processes)
                arguments_to_pass = (
                    outdir, 'swissprot', refdata.swissprot_db, method,
                    threads, id_perc, bitscore, evalue, aln_percent, method_bin)
                search_results = pool.map(partial(search.similarity_search,
                multiple_arguments=arguments_to_pass), input_proteins)
                # Results as (protein_file, filtered_search_file)
            finally:
                pool.close() 
                pool.join()
            # Add name of filtered_search_results to protein_file_info
            # New structure:
            # protein_file : [protein_file_name, final_annotation_file,
            # kofam_results, fasta_step_2, Swissprot_Search_File]
            temp_dir = protein_file_info.copy()
            for result in search_results:
                for filename, info in temp_dir.items():
                    if result[0] == info[3]:
                        protein_file_info[filename].append(result[1])
            del temp_dir

        # Search annotations in SQLite DB and
        # append to the final annotation file
        logger.info("Extracting Swissprot annotation data")
        for original_file, information in protein_file_info.items():
            if len(information) != 5:
                continue
            starting_filename = information[0]
            filtered_results = information[4]
            final_annotation_file = information[1]
            significant_hits = []
            with open(filtered_results) as swissprot_results:
                for line in swissprot_results:
                    line = line.strip().split()
                    significant_hits.append((line[0], line[1]))
            if len(significant_hits) == 0:
                # If no hits were found
                if light:
                    # Write all the proteins without annotation
                    with open(final_annotation_file, 'a') as annotation_fh:
                        for original_protein in \
                            unannotated_proteins[starting_filename]:
                            annotation_fh.write(
                                f"{original_protein}\tNA\tNo match found\t"
                                f"NA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n")
                else:
                    # Copy the first iteration proteins to the second iteration file
                    next_fasta = temp_prot_folder / f"{starting_filename}.step_3"
                    copyfile(protein_file_info[original_file][3], next_fasta)
                    # protein_file : [protein_file_name, final_annotation_file,
                    # kofam_results, fasta_step_2, Swissprot_Search_File,
                    # fasta_step_3]
                    protein_file_info[original_file].append(next_fasta)
            else:
                # Extract the annotations
                annotation = sqlite3_search.search_ids_imported(
                    refdata.microbeannotator_db, "swissprot", significant_hits)
                # Each annotation will have:
                # query_id gene_id accession product ko_number organism taxonomy
                # function_GO compartment_GO process_GO interpro_id pfam_id
                # EC_number
                if light:
                    # Write the annotations found in the annotation file
                    with open(final_annotation_file, 'a') as annotation_fh:
                        for match in annotation:
                            annotation_fh.write(
                                f"{match[0]}\t{match[1]}\t{match[3]}\t"
                                f"{match[4]}\tNA\t{match[6]}\t{match[7]}\t"
                                f"{match[8]}\t{match[9]}\t{match[10]}\t"
                                f"{match[11]}\t{match[12]}\tswissprot\n")
                            # If match has KO number remove from starting
                            # and unannotated proteins
                            if match[4] != "" and match[4] != "NA":
                                try:
                                    starting_proteins[
                                        starting_filename].remove(match[0])
                                except ValueError:
                                    pass
                                try:
                                    unannotated_proteins[
                                        starting_filename].remove(match[0])
                                except ValueError:
                                    pass
                            # If it has annotation but not KO, remove
                            # from unannotated proteins only
                            elif match[3] != "":
                                try:
                                    unannotated_proteins[
                                        starting_filename].remove(match[0])
                                except ValueError:
                                    pass
                    # Check which proteins were not annotated and add information on those
                    # Extract annotated proteins
                    with open(final_annotation_file, 'a') as annotation_fh:
                        for original_protein in \
                            unannotated_proteins[starting_filename]:
                            annotation_fh.write(
                                f"{original_protein}\tNA\tNo match found\t"
                                f"NA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n")
                else:
                    # Write the annotations found in the annotation file
                    with open(final_annotation_file, 'a') as annotation_fh:
                        for match in annotation:
                            annotation_fh.write(
                                f"{match[0]}\t{match[1]}\t{match[3]}\t"
                                f"{match[4]}\tNA\t{match[6]}\t{match[7]}\t"
                                f"{match[8]}\t{match[9]}\t{match[10]}\t"
                                f"{match[11]}\t{match[12]}\tswissprot\n")
                            if full:
                            # Remove annotated from the final unannotated list
                                if match[3] != "":
                                    try:
                                        unannotated_proteins[
                                            starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                            else:
                                # Remove annotations with KO from starting and
                                # unannotated proteins
                                if match[4] != "" and match[4] != "NA":
                                    try:
                                        starting_proteins[
                                            starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                                    try:
                                        unannotated_proteins[
                                            starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                                # If it has annotation but not KO, remove
                                # from unannotated proteins only
                                elif match[3] != "":
                                    try:
                                        unannotated_proteins[
                                            starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                    # Check which proteins were not annotated and filter
                    # for the next iteration
                    next_fasta = temp_prot_folder / (
                        f"{starting_filename}.step_3")
                    if full:
                        copyfile(protein_file_info[original_file][3], next_fasta)
                        # protein_file : [protein_file_name, final_annotation_file,
                        # kofam_results, fasta_step_2, Swissprot_Search_File,
                        # fasta_step_3]
                        protein_file_info[original_file].append(next_fasta)
                    else:
                        if len(starting_proteins[starting_filename]) > 0:
                            to_retain = starting_proteins[
                                starting_filename].copy()
                            filterfasta.fastA_filter_list(
                                protein_file_info[original_file][3],
                                next_fasta, to_retain, reverse=False)
                            # protein_file : [protein_file_name, final_annotation_file,
                            # kofam_results, fasta_step_2, Swissprot_Search_File,
                            # fasta_step_3]
                            protein_file_info[original_file].append(next_fasta)
        # When the process is complete, write step completed (+1)
        # in the log folder. Also, export dictionary with information
        # to be imported in case of continue
        process_step += 1
        with open(process_log_folder / "log.txt", 'w') as logfile:
            logfile.write(f"{process_step}")
        with open(process_log_folder / "structure.pickle", 'wb') \
            as structure_file:
            pickle.dump(protein_file_info, structure_file)
    # ==========================================================================
    
    # ==========================================================================
    # Perform RefSeq searches
    # ==========================================================================
    # If running complete pipeline search against refseq
    if process_step == 3:
        if light:
            if refine:
                process_step = 5
            else:
                logger.info(
                    "MicrobeAnnotator annotation has finished succesfully!")
                logger.info("Summarizing results")
                if temp_prot_folder.is_dir():
                    rmtree(temp_prot_folder)
                process_step = 6
        else:
            logger.info("Searching proteins against RefSeq")
            # If we have a second iteration file then perform the searches
            # against refseq
            input_proteins = []
            for infor in protein_file_info.values():
                if len(infor) == 6:
                    input_proteins.append(infor[-1])
            if len(input_proteins) > 0:
                try:
                    pool = multiprocessing.Pool(processes)
                    arguments_to_pass = (
                        outdir, 'refseq', refdata.refseq_db, method, threads,
                        id_perc, bitscore, evalue, aln_percent, method_bin)
                    search_results = pool.map(partial(search.similarity_search,
                    multiple_arguments=arguments_to_pass), input_proteins)
                    # Results as (protein_file, filtered_search_file)
                finally:
                    pool.close()
                    pool.join()
            # Add name of filtered_search_results to protein_file_info
            # New structure:
            # protein_file : [protein_file_name, final_annotation_file,
            # kofam_results, fasta_step_2, Swissprot_Search_File,
            # fasta_step_3, RefSeq_Search_File]
            temp_dir = protein_file_info.copy()
            for result in search_results:
                for filename, info in temp_dir.items():
                    if len(info) == 6 and result[0] == info[5]:
                        protein_file_info[filename].append(result[1])
            del temp_dir
            # Search annotations in SQLite DB and
            # append to the final annotation file
            logger.info("Extracting RefSeq annotation data")
            for original_file, information in protein_file_info.items():
                if len(information) != 7:
                    continue
                starting_filename = information[0]
                filtered_results = information[6]
                final_annotation_file = information[1]
                significant_hits = []
                with open(filtered_results) as refseq_results:
                    for line in refseq_results:
                        line = line.strip().split()
                        significant_hits.append((line[0], line[1]))
                # If no hits were found
                if len(significant_hits) == 0:
                    next_fasta = temp_prot_folder / (
                        f"{starting_filename}.step_4")
                    copyfile(protein_file_info[original_file][4],next_fasta)
                    # protein_file : [protein_file_name, final_annotation_file,
                    # kofam_results, fasta_step_2, Swissprot_Search_File,
                    # fasta_step_3, RefSeq_Search_File, fasta_step_4]
                    protein_file_info[original_file].append(next_fasta)
                else:
                    # Extract the annotations
                    annotation = sqlite3_search.search_ids_imported(
                        refdata.microbeannotator_db, "refseq", significant_hits)
                    # Each annotation will have:
                    # query_id gene_id product taxonomy ko_number ko_product
                    # Write the annotations found in the annotation file
                    with open(final_annotation_file, 'a') as annotation_fh:
                        for match in annotation:
                            annotation_fh.write(
                                f"{match[0]}\t{match[1]}\t{match[2]}\t"
                                f"{match[5]}\t{match[6]}\t{match[3]}\t"
                                f"NA\tNA\tNA\tNA\tNA\t{match[4]}\trefseq\n")
                            if full:
                                if match[2] != "":
                                    try:
                                        unannotated_proteins[
                                        starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                            else:
                                # Remove annotations with KO from starting and
                                # unannotated proteins
                                if match[4] != "" and match[4] != "NA":
                                    try:
                                        starting_proteins[
                                        starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                                    try:
                                        unannotated_proteins[
                                            starting_filename].remove(
                                                match[0])
                                    except ValueError:
                                        continue
                                # If it has annotation but not KO, remove
                                # from unannotated proteins only
                                elif match[2] != "":
                                    try:
                                        unannotated_proteins[
                                        starting_filename].remove(match[0])
                                    except ValueError:
                                        continue
                    # Check which proteins were not annotated and filter
                    # for the next iteration
                    next_fasta = temp_prot_folder / (
                        f"{starting_filename}.step_4")
                    if full == True:
                        copyfile(
                            protein_file_info[original_file][5],
                            next_fasta)
                        # protein_file : [protein_file_name, final_annotation_file,
                        # kofam_results, fasta_step_2, Swissprot_Search_File,
                        # fasta_step_3, RefSeq_Search_File, fasta_step_4]
                        protein_file_info[original_file].append(next_fasta)  
                    else:
                        if len(starting_proteins[starting_filename]) > 0:
                            to_retain = starting_proteins[
                                starting_filename].copy()
                            filterfasta.fastA_filter_list(
                                protein_file_info[original_file][5], 
                                next_fasta, to_retain, reverse=False)
                            protein_file_info[original_file].append(next_fasta)
            # When the process is complete, write step completed (+1)
            # in the log folder. Also, export dictionary with information
            # to be imported in case of continue  
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') \
                as structure_file:
                pickle.dump(protein_file_info, structure_file)
    # ==========================================================================


    # ==========================================================================
    # Perform TrEMBL searches
    # ==========================================================================
    # Finally, run the remaining proteins against TrEMBL
    if process_step == 4:
        logger.info("Searching proteins against Trembl")
        input_proteins = []
        for infor in protein_file_info.values():
            if len(infor) == 8:
                input_proteins.append(infor[-1])
        if len(input_proteins) > 0:
            try:
                pool = multiprocessing.Pool(processes)
                arguments_to_pass = (
                    outdir, 'trembl', refdata.trembl_db, method, threads,\
                    id_perc, bitscore, evalue, aln_percent, method_bin)
                search_results = pool.map(partial(search.similarity_search,
                multiple_arguments=arguments_to_pass), input_proteins)
                # Results as (protein_file, filtered_search_file)
            finally:
                pool.close()
                pool.join()
        # Add name of filtered_search_results to protein_file_info
        # New structure:
        # protein_file : [protein_file_name, final_annotation_file,
        # kofam_results, fasta_step_2, Swissprot_Search_File,
        # fasta_step_3, RefSeq_Search_File, fasta_step_4, Trembl_Search_File]
        temp_dir = protein_file_info.copy()
        for result in search_results:
            for filename, info in temp_dir.items():
                if len(info) == 8 and result[0] == info[6]:
                    protein_file_info[filename].append(result[1])
        del temp_dir
        # Search annotations in SQLite DB and
        # append to the final annotation file
        logger.info("Extracting Trembl annotation data")
        for original_file, information in protein_file_info.items():
            if len(information) != 9:
                continue
            starting_filename = information[0]
            filtered_results = information[8]
            final_annotation_file = information[1]
            significant_hits = []
            with open(filtered_results) as trembl_results:
                for line in trembl_results:
                    line = line.strip().split()
                    significant_hits.append((line[0], line[1]))
            # If no hits were found
            if len(significant_hits) == 0:
                with open(final_annotation_file, 'a') as annotation_fh:
                    for original_protein in unannotated_proteins[
                        starting_filename]:
                        annotation_fh.write(
                            f"{original_protein}\tNA\tNo match found\tNA\t"
                            f"NA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n")
            else:
                annotation = sqlite3_search.search_ids_imported(
                    refdata.microbeannotator_db, "trembl", significant_hits)
                # Each annotation will have:
                # query_id gene_id accession product ko_number organism taxonomy
                # function_GO compartment_GO process_GO interpro_id pfam_id
                # EC_number
                with open(final_annotation_file, 'a') as annotation_fh:
                    for match in annotation:
                        annotation_fh.write(
                            f"{match[0]}\t{match[1]}\t{match[3]}\t{match[4]}\t"
                            f"NA\t{match[6]}\t{match[7]}\t{match[8]}\t"
                            f"{match[9]}\t{match[10]}\t{match[11]}\t"
                            f"{match[12]}\ttrembl\n")
                        if match[4] != "" and match[4] != "NA":
                            # If match has KO number remove from starting
                            # and unannotated proteins
                            try:
                                starting_proteins[
                                    starting_filename].remove(match[0])
                            except ValueError:
                                continue
                            try:
                                unannotated_proteins[
                                    starting_filename].remove(match[0])
                            except ValueError:
                                continue
                        # If it has annotation but not KO, remove
                        # from unannotated proteins only
                        elif match[3] != "":
                            try:
                                unannotated_proteins[
                                    starting_filename].remove(match[0])
                            except ValueError:
                                continue
                with open(final_annotation_file, 'a') as annotation_fh:
                    for original_protein in unannotated_proteins[
                        starting_filename]:
                        annotation_fh.write(
                            f"{original_protein}\tNA\tNo match found\tNA\t"
                            f"NA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\tNA\n")
        if temp_prot_folder.is_dir():
            rmtree(temp_prot_folder)
        # When the process is complete, write step completed (+1) in the log folder
        # Also, export dictionary with information to be imported in case of continue  
        process_step += 1
        with open(process_log_folder / "log.txt", 'w') as logfile:
            logfile.write("{}".format(process_step))
        with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
            pickle.dump(protein_file_info, structure_file)
    # ==========================================================================


    # ==========================================================================
    # Refine annotations with ids from other databases
    # ==========================================================================
    if process_step == 5:
        if refine == False:
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') \
                as structure_file:
                pickle.dump(protein_file_info, structure_file)
        else:
            import pandas as pd
            logger.info(
                f"Improving annotations by searching matches "
                f"in other databases")
            # Structure: 
            # protein_file : [protein_file_name, final_annotation_file,
            # kofam_results, fasta_step_2, Swissprot_Search_File,
            # fasta_step_3, RefSeq_Search_File, fasta_step_4, Trembl_Search_File]
            final_annotation_files = []
            for information in protein_file_info.values():
                final_annotation_files.append(str(information[1]))
            
            # Process each annotation file separately
            for annotation_file in final_annotation_files:
                positive_matches = {}
                annotation_table = pd.read_csv(annotation_file, sep="\t", header=0, index_col=0)
                # Parse Uniprot Records
                swissprot_records = annotation_table.loc[(annotation_table['database'] == "swissprot") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without ko numbers but with ec number
                swissprot_records_ko = swissprot_records.loc[(swissprot_records['ko_number'] == "NA") & (swissprot_records['ec_number'] != "NA"), ]
                swissprot_records_ko_ids = list(swissprot_records_ko['ec_number'])
                if len(swissprot_records_ko_ids) > 0:
                    positive_matches = convert.convert_ko_to_ec(swissprot_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            annotation_table.loc[(annotation_table['ec_number'] == identifier) & (annotation_table['ko_number'] == "NA"), "ko_number"] = " ".join(matches)
                # Delete used sub-tables
                del swissprot_records
                del swissprot_records_ko
                
                # Parse Trembl Records
                trembl_records = annotation_table.loc[(annotation_table['database'] == "trembl") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without ko numbers but with ec number
                trembl_records_ko = trembl_records.loc[(trembl_records['ko_number'] == "NA") & (trembl_records['ec_number'] != "NA"), ]
                trembl_records_ko_ids = list(trembl_records_ko['ec_number'])
                if len(trembl_records_ko_ids) > 0:
                    positive_matches = convert.convert_ko_to_ec(trembl_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            annotation_table.loc[(annotation_table['ec_number'] == identifier) & (annotation_table['ko_number'] == "NA"), "ko_number"] = " ".join(matches)
                # Delete used sub-tables
                del trembl_records
                del trembl_records_ko

                # Parse RefSeq Records
                refseq_records = annotation_table.loc[(annotation_table.database == "refseq") & (annotation_table['ko_number'] == "NA"), ]
                # Extract records without ko numbers but with ec number
                refseq_records_ko = refseq_records.loc[(refseq_records['ko_number'] == "NA") & (refseq_records['ec_number'] != "NA"), ]
                refseq_records_ko_ids = list(refseq_records_ko['ec_number'])
                if len(refseq_records_ko_ids) > 0:
                    positive_matches = convert.convert_ko_to_ec(refseq_records_ko_ids, refdata.conversion_db, True)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            annotation_table.loc[(annotation_table['ec_number'] == identifier), "ko_number"] = " ".join(matches)
                # Extract refseq ids for records with no ko
                refseq_records = annotation_table.loc[(annotation_table.database == "refseq") & (annotation_table['ko_number'] == "NA"), ]
                refseq_records_ids = list(refseq_records['protein_id'])
                if len(refseq_records_ids) > 0:
                    refseq_records_indices = refseq_records.index
                    query_refseq_pair = list(zip(refseq_records_indices, refseq_records_ids))
                    # Convert refseq to uniprot
                    refseq_uniprot = {}
                    positive_matches = convert.convert_refseq_to_uniprot(refseq_records_ids, refdata.conversion_db, False)
                    if len(positive_matches) > 0:
                        for identifier, matches in positive_matches.items():
                            refseq_uniprot[identifier] = matches[0]
                        query_uniprot_pair = []
                        for pair in query_refseq_pair:
                            if pair[1] in refseq_uniprot:
                                query_uniprot_pair.append((pair[refseq_records_indices], refseq_uniprot[pair[1]]))
                        # Now extract the UniProt annotations using the identifiers found in the previous step
                        uniprot_annotations = sqlite3_search.search_ids_imported(refdata.microbeannotator_db, "trembl", query_uniprot_pair)
                        if len(uniprot_annotations) > 0:
                            # Parse results and modify the original table
                            for annotation in uniprot_annotations:
                                original_index = annotation.pop(0)
                                annotation_table.loc[original_index] = annotation
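                # Save the refined table in the original location
                # NOTE(review): no output path is passed to to_csv below, so the
                # table appears to be returned as a string rather than written
                # back to annotation_file -- confirm intent.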
                pd.DataFrame.to_csv(annotation_table, sep="\t", index=0, header=True)
            # When the process is complete, write step completed (+1) in the log folder
            # Also, export dictionary with information to be imported in case of continue
            process_step += 1
            with open(process_log_folder / "log.txt", 'w') as logfile:
                logfile.write("{}".format(process_step))
            with open(process_log_folder / "structure.pickle", 'wb') as structure_file:
                pickle.dump(protein_file_info, structure_file)
    # ==========================================================================


    # ==========================================================================
    # Parse annotation files and summarize them
    # ==========================================================================
    if process_step == 6:
        logger.info("Extracting ko numbers and summarizing results")
        annotation_files = []
        for information in protein_file_info.values():
            annotation_folder = information[1].parent
            ko_numbers = str(annotation_folder / (information[0] + ".ko"))
            with open(information[1], 'r') as annotations, open(ko_numbers, 'w') as ko_present:
                for line in annotations:
                    line = line.strip().split("\t")
                    # Skip empty KO fields; a field may list several
                    # whitespace-separated KOs, written one per line
                    if line[3] != "":
                        if len(line[3].split()) > 1:
                            for element in line[3].split():
                                ko_present.write("{}\n".format(element))
                        else:
                            ko_present.write("{}\n".format(line[3]))
            annotation_files.append(ko_numbers)
        prefix = str(Path(outdir) / plot_filename)
        regular_modules, bifurcation_modules, structural_modules,  \
        module_information, metabolism_matrix, module_group_matrix = ko_mapper.module_information_importer(annotation_files)
        metabolic_annotation = ko_mapper.global_mapper(regular_modules, bifurcation_modules, structural_modules, annotation_files)
        metabolism_matrix_dropped_relabel, module_colors = ko_mapper.create_output_files(metabolic_annotation, metabolism_matrix, module_information, cluster, prefix)
        ko_mapper.plot_function_barplots(module_colors, module_group_matrix, metabolism_matrix_dropped_relabel, prefix)
        logger.info("MicrobeAnnotator has finished succefully!")
    # ==========================================================================


# ==============================================================================
# Run main function
# ==============================================================================
if __name__ == "__main__":
    # Run the pipeline only when this file is executed directly as a script,
    # not when it is imported as a module.
    main()
# ==============================================================================
