#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
from tm_vec.embed_structure_model import (trans_basic_block,
                                          trans_basic_block_Config)
from tm_vec.tm_vec_utils import (featurize_prottrans, embed_tm_vec,
                                 format_ids, encode, load_database, query)
from tm_vec.fasta import Indexer
from deepblast.dataset.utils import states2alignment
from transformers import T5EncoderModel, T5Tokenizer
import re
import gc
from Bio import SeqIO
import faiss
from pathlib import Path
import argparse
import pickle


# Command-line interface of the search script.
#
# Required options: --query, --database, --database-fasta, --tm-vec-model,
# --tm-vec-config and --output.  Every other option is optional and falls
# back to the default given below (None when no default is listed).
parser = argparse.ArgumentParser(
    description='Process TM-Vec arguments', add_help=True)

# --- Inputs ---------------------------------------------------------------
parser.add_argument(
    "--query",
    type=Path,
    required=True,
    help="Input fasta-formatted data to query against database.",
)

parser.add_argument(
    "--database",
    type=Path,
    required=True,
    help="Database to query",
)

parser.add_argument(
    "--database-fasta",
    type=Path,
    required=True,
    help=("Database that contains the corresponding "
          "protein sequences in fasta format."),
)

parser.add_argument(
    "--database-faidx",
    type=Path,
    required=False,
    help="Fasta index for database fasta format.",
)

parser.add_argument(
    "--metadata",
    type=Path,
    help="Metadata for queried database",
)

# --- Models ---------------------------------------------------------------
parser.add_argument(
    "--tm-vec-model",
    type=Path,
    required=True,
    help="Model path for embedding",
)

parser.add_argument(
    "--tm-vec-config",
    type=Path,
    required=True,
    help="Config path for embedding",
)

parser.add_argument(
    "--deepblast-model",
    type=Path,
    help="DeepBLAST model path",
)

parser.add_argument(
    "--protrans-model",
    type=Path,
    default=None,
    required=False,
    help=("Model path for the ProTrans embedding model. "
          "If this is not specified, then the model will "
          "automatically be downloaded."),
)

# NOTE: the help text must be ONE string.  A comma between the adjacent
# literals turns it into a tuple, which argparse cannot format (``--help``
# fails, and recent Python versions reject it already in add_argument).
parser.add_argument(
    "--device",
    type=str,
    default=None,
    required=False,
    help=("The device id to load the model onto. "
          "This will specify whether or not a GPU "
          "will be utilized.  If `gpu` is specified "
          "then the first gpu device will be used."),
)

# --- Search / alignment behaviour -------------------------------------------
# The help text spells the second mode exactly like the default value so
# that users can copy it verbatim.
parser.add_argument(
    "--alignment-mode",
    type=str,
    default='needleman-wunsch',
    required=False,
    help="`smith-waterman` or `needleman-wunsch`.",
)

parser.add_argument(
    "--k-nearest-neighbors",
    type=int,
    default=5,
    help="Number of nearest neighbhors to return for each query.",
)

# --- Outputs --------------------------------------------------------------
parser.add_argument(
    "--output-format",
    type=str,
    default='tabular',
    required=False,
    help=("Options include `tabular`, `alignment` "
          "or `embedding`."),
)

parser.add_argument(
    "--output",
    type=Path,
    required=True,
    help="Output search results.",
)

parser.add_argument(
    "--output-embeddings",
    type=Path,
    default=None,
    required=False,
    help="Output encodings of query proteins (npy format).",
)

# Parse the command line
args = parser.parse_args()

# Read every record of the query FASTA file, keeping its identifier and its
# sequence as a plain string (same order as in the file).
with open(args.query) as fasta_handle:
    records = [(rec.id, str(rec.seq))
               for rec in SeqIO.parse(fasta_handle, "fasta")]

headers = [rec_id for rec_id, _ in records]
seqs = [residues for _, residues in records]

# Independent copy of the sequence list; this is what gets embedded/aligned
flat_seqs = seqs[:]
print("Sequences inputed")

# Load metadata if it exists.
# --metadata is optional, so np.load() must not be called with None.
if args.metadata is not None:
    metadata_database = np.load(args.metadata)
else:
    class _IndexPassthrough:
        """Stand-in for the metadata array that maps an index to itself.

        Indexing with a scalar or an index array returns that same object,
        so the database indices themselves are reported downstream.
        """

        def __getitem__(self, idx):
            return idx

    # Without metadata the raw database indices are used as identifiers.
    # NOTE(review): the alignment output looks sequences up in the database
    # fasta by these identifiers -- confirm integer indices are acceptable
    # there, otherwise --metadata is effectively required for that mode.
    metadata_database = _IndexPassthrough()

# Set device.
# The CPU is used when no --device was given, when `cpu` was requested
# explicitly, or when CUDA is not available (any --device value is then
# ignored).  Otherwise `gpu` selects the first CUDA device and any other
# value is interpreted as an integer CUDA device id.
if (args.device is None or args.device == 'cpu'
        or not torch.cuda.is_available()):
    print('Models will be loaded on CPU.')
    device = torch.device('cpu')
elif args.device == 'gpu':
    device = torch.device('cuda:0')
else:
    # Raises ValueError if the value is not an integer device id
    device = torch.device(f'cuda:{int(args.device)}')

# Load the ProtTrans model and ProtTrans tokenizer, either from the
# user-supplied path or, when none was given, by its published name
# "Rostlab/prot_t5_xl_uniref50".
if args.protrans_model is not None:
    protrans_source = args.protrans_model
else:
    protrans_source = "Rostlab/prot_t5_xl_uniref50"

tokenizer = T5Tokenizer.from_pretrained(protrans_source, do_lower_case=False)
model = T5EncoderModel.from_pretrained(protrans_source)

# Collect garbage before moving the model to the target device, then
# switch it to evaluation mode
gc.collect()
model = model.to(device).eval()
print("ProtTrans model downloaded")


# Load the TM-Vec model: read its configuration from the JSON file, restore
# the weights from the checkpoint, move it to the selected device and put
# it in evaluation mode.
tm_vec_model_config = trans_basic_block_Config.from_json(args.tm_vec_config)
model_deep = (
    trans_basic_block
    .load_from_checkpoint(args.tm_vec_model, config=tm_vec_model_config)
    .to(device)
    .eval()
)
print("TM-Vec model loaded")



# Embed all query sequences.
# The resulting array is searched against the database below and is also
# what --output-embeddings saves.
queries = encode(flat_seqs, model_deep, model, tokenizer, device)

# Build an indexed database.
# load_database() is given the database path directly, so the file is not
# additionally read with np.load() here: that copy was never used and only
# kept a second instance of the database in memory.
index = load_database(args.database)

# Return the nearest neighbors: for each query, D holds the scores
# (written out as `tm-score`) and I the database indices of its k hits.
k = args.k_nearest_neighbors
D, I = query(index, queries, k)

# Translate the neighbour indices of every query (one row of I per query)
# into the corresponding metadata entries.
print("meta data loaded")

near_ids = np.array(
    [list(metadata_database[neighbor_idx]) for neighbor_idx in I]
)

# need to clear space for DeepBLAST aligner
del model_deep

# Alignment section
# When a DeepBLAST model is supplied and the `alignment` output format is
# requested, align every query against each of its nearest neighbors and
# write the aligned sequence pairs to --output.
if args.deepblast_model is not None and args.output_format == 'alignment':
    # moving imports here because it may take a very long time
    # NOTE(review): DeepBLAST and pack_sequences are not referenced below;
    # they are kept in case importing them matters for load_model -- confirm.
    from deepblast.trainer import DeepBLAST
    from deepblast.dataset.utils import pack_sequences
    from deepblast.dataset.utils import states2alignment
    from deepblast.utils import load_model

    align_model = load_model(
        args.deepblast_model, lm=model, tokenizer=tokenizer,
        alignment_mode=args.alignment_mode,
        device=device)

    align_model = align_model.to(device)
    seq_db = Indexer(args.database_fasta, args.database_faidx)
    seq_db.load()

    with open(args.output, 'w') as fh:
        for i in range(I.shape[0]):
            x = flat_seqs[i]
            for j in range(I.shape[1]):
                # Identifier of the j-th neighbor and its sequence
                seq_i = metadata_database[I[i, j]]
                y = seq_db[seq_i]
                try:
                    pred_alignment = align_model.align(x, y)
                    # Note : there is an edge case that smith-waterman will
                    # throw errors, but needleman-wunsch won't.
                    x_aligned, y_aligned = states2alignment(
                        pred_alignment, x, y)
                    # TODO : not clear how to get the sequence IDS
                    ix, iy = format_ids(headers[i], seq_i)
                except Exception as err:
                    # Best effort: skip this pair but report why.  Unlike a
                    # bare `except`, this lets KeyboardInterrupt and
                    # SystemExit stop the run.
                    print(f'No valid alignments found for {headers[i]} '
                          f'{seq_i} ({type(err).__name__}: {err})')
                    continue
                # Write out the alignment; I/O errors are not swallowed
                fh.write(ix + ' ' + x_aligned + '\n')
                fh.write(iy + ' ' + y_aligned + '\n\n')


# Outputting results
# Write out the nearest neighbors as a tab-separated table with one row per
# (query, rank) pair and the columns query_id, rank, database_id, tm-score.
if args.output_format == 'tabular':
    # Long-format table of neighbor identifiers: one row per query and rank
    neighbor_table = (
        pd.DataFrame(near_ids, index=headers)
        .rename_axis('query_id')
        .reset_index()
        .melt(id_vars='query_id', var_name='rank', value_name='database_id')
    )
    # Scores melted in the same column-by-column order, so the rows line up
    # positionally with neighbor_table; only the score column is used
    score_table = pd.DataFrame(D, index=headers).melt(
        var_name='query_id', value_name='tm-score')
    combined = pd.concat(
        (neighbor_table, score_table[['tm-score']]), axis=1)
    combined = combined.sort_values(['query_id', 'rank'])
    combined.to_csv(args.output, sep='\t', index=None)


# Save the query embeddings when an --output-embeddings path was given
embeddings_path = args.output_embeddings
if embeddings_path is not None:
    np.save(embeddings_path, queries)
