#!/usr/bin/env python

# Argument handling

import argparse
import importlib.metadata
import sys

# Keyword values shared by more than one sub-command below.
_format_choices = ["guess", "pdb", "mmcif", "mmtf", "coords"]
_device_option = dict(default="cpu", help="device to run on, default is \"cpu\"")

# Top-level parser; its description embeds the installed progres version.
parser = argparse.ArgumentParser(
    description=(
        "Fast protein structure searching using structure graph embeddings. "
        "See https://github.com/greener-group/progres for documentation and citation information. "
        f"This is version {importlib.metadata.version('progres')} of the software."
    ),
)
# The chosen sub-command name is stored on the namespace as "mode".
subparsers = parser.add_subparsers(
    dest="mode",
    help="the mode to run progres in, run \"progres {mode} -h\" to see help for each",
)

# "search" mode: one or more queries against a pre-embedded database.
parser_search = subparsers.add_parser(
    "search",
    description="Search one or more queries against a pre-embedded database.",
    help="search one or more queries against a pre-embedded database",
)
parser_search.add_argument(
    "-q", "--querystructure",
    help="query structure file in PDB/mmCIF/MMTF/coordinate format",
)
parser_search.add_argument(
    "-l", "--querylist",
    help="text file with one query file path per line",
)
parser_search.add_argument(
    "-t", "--targetdb",
    required=True,
    help=("pre-embedded database to search against, either "
          "\"scope95\", \"scope40\", \"cath40\", \"ecod70\", \"af21org\", \"afted\" or a file path"),
)
parser_search.add_argument(
    "-f", "--fileformat",
    choices=_format_choices,
    default="guess",
    help="file format of the query structure(s), by default guessed from the file extension",
)
parser_search.add_argument(
    "-s", "--minsimilarity",
    type=float,
    default=0.8,
    help="similarity threshold above which to return hits, default 0.8",
)
parser_search.add_argument(
    "-m", "--maxhits",
    type=int,
    default=100,
    help="maximum number of hits to return, default 100",
)
parser_search.add_argument("-d", "--device", **_device_option)

# "score" mode: two positional structures, each with its own format option.
parser_score = subparsers.add_parser(
    "score",
    description=(
        "Calculate the Progres score between two protein domains. "
        "The order of the domains does not affect the score. "
        "A score of 0.8 or higher indicates the same fold."
    ),
    help="calculate the Progres score between two protein domains",
)
parser_score.add_argument(
    "structure1",
    help="first structure file in PDB/mmCIF/MMTF/coordinate format",
)
parser_score.add_argument(
    "structure2",
    help="second structure file in PDB/mmCIF/MMTF/coordinate format",
)
parser_score.add_argument(
    "-f", "--fileformat1",
    choices=_format_choices,
    default="guess",
    help="file format of the first structure, by default guessed from the file extension",
)
parser_score.add_argument(
    "-g", "--fileformat2",
    choices=_format_choices,
    default="guess",
    help="file format of the second structure, by default guessed from the file extension",
)
parser_score.add_argument("-d", "--device", **_device_option)

# "embed" mode: embed a list of structures into an output file.
parser_embed = subparsers.add_parser(
    "embed",
    description="Embed a dataset of structures to allow it to be searched against.",
    help="embed a dataset of structures to allow it to be searched against",
)
parser_embed.add_argument(
    "-l", "--structurelist",
    required=True,
    help="text file with file path, domain name and optional note per line",
)
parser_embed.add_argument(
    "-o", "--outputfile",
    required=True,
    help="output file path for the PyTorch file containing the embeddings",
)
parser_embed.add_argument(
    "-f", "--fileformat",
    choices=_format_choices,
    default="guess",
    help="file format of the structures, by default guessed from the file extension",
)
parser_embed.add_argument("-d", "--device", **_device_option)

# Parsed at module level; main() reads this namespace.
args = parser.parse_args()

def main():
    """Run the progres mode selected on the command line.

    Reads the module-level ``args`` namespace and dispatches to the matching
    function of the ``progres`` package. The package is imported lazily, and
    only after the arguments of the selected mode have been checked, so that
    invalid invocations fail without paying for the import.

    Raises:
        SystemExit: with status 2, after the usage and an error message have
            been written to stderr by ``ArgumentParser.error``, when
            ``minsimilarity`` is outside [0, 1], ``maxhits`` is below 1,
            neither ``-q`` nor ``-l`` is given in search mode, or no mode
            was selected.
    """
    if args.mode == "search":
        # The chained comparison also rejects NaN, which is not in [0, 1].
        if not 0 <= args.minsimilarity <= 1:
            parser_search.error("minsimilarity must be between 0 and 1")
        if args.maxhits < 1:
            parser_search.error("maxhits must be a positive integer")
        if not (args.querystructure or args.querylist):
            parser_search.error("One of -q and -l must be given for structure searching")

        # A single query structure takes precedence over a query list.
        if args.querystructure:
            query = {"querystructure": args.querystructure}
        else:
            query = {"querylist": args.querylist}

        from progres import progres_search_print
        progres_search_print(
            **query,
            targetdb=args.targetdb,
            fileformat=args.fileformat,
            minsimilarity=args.minsimilarity,
            maxhits=args.maxhits,
            device=args.device,
        )
    elif args.mode == "score":
        from progres import progres_score_print
        progres_score_print(
            structure1=args.structure1,
            structure2=args.structure2,
            fileformat1=args.fileformat1,
            fileformat2=args.fileformat2,
            device=args.device,
        )
    elif args.mode == "embed":
        from progres import progres_embed
        progres_embed(
            structurelist=args.structurelist,
            outputfile=args.outputfile,
            fileformat=args.fileformat,
            device=args.device,
        )
    else:
        # The sub-command is optional for argparse, so "mode" may be unset.
        parser.error("No mode selected, run \"progres -h\" to see help")

# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
