#!/usr/bin/env python

# Command-line script that predicts IDRs for all sequences in a FASTA file.

# import stuff for making CLI
import os
import argparse
import protfasta


import metapredict as meta


if __name__ == "__main__":
    # Command-line entry point. Reads a FASTA file, predicts disordered
    # domains (IDRs) for every sequence with metapredict, and writes the IDRs
    # out either as a FASTA file or as a tab-separated domains file,
    # depending on --mode.

    # Parse command line arguments.
    parser = argparse.ArgumentParser(description='Predict IDRs for all sequences in a FASTA file.')

    parser.add_argument('data_file', help='Path to fasta file containing sequences to be predicted.')

    parser.add_argument('-o', '--output-file', help='Filename for where to save the outputfile. Default = idrs_shephard.tsv', default='idrs_shephard.tsv')

    parser.add_argument('-l', '--legacy', action='store_true', help='Optional. Use this flag to use the original legacy version of metapredict.')

    parser.add_argument('--invalid-sequence-action', help="For parsing FASTA file, defines how to deal with non-standard amino acids. See https://protfasta.readthedocs.io/en/latest/read_fasta.html for details. Default='convert' ", default='convert')

    parser.add_argument('--mode', help='Defines the mode in which IDRs are reported. By default this generates a FASTA file with header format that matches  the input file with an additional set of fields that are "IDR_START=$START   IDR_END=$END" where $START and $END are the starting and ending IDRs. If mode is set to shephard-domains than a SHEPHAD-compliant domains file is generated. If shephard-domains-uniprot the uniprot ID is extracted from the header assuming standard uniprot formatting. Default = fasta', default='fasta')

    parser.add_argument('--threshold', help='Defines the threshold used to define a region as disordered or not. Default=0.42 which is recommended.', default=0.42, type=float)
    parser.add_argument('--verbose', help='If included then prints out status updates', action='store_true')

    args = parser.parse_args()

    # Reject unknown modes up front; the message lists every accepted value.
    valid_modes = ('fasta', 'shephard-domains', 'shephard-domains-uniprot')
    if args.mode not in valid_modes:
        parser.error("--mode must be set to one of 'fasta', 'shephard-domains' or 'shephard-domains-uniprot'")

    # Work out which disorder threshold to hand to the predictor.
    use_legacy = args.legacy
    if use_legacy:
        # argparse stores --threshold as args.threshold (there is no
        # args.threshold_val attribute).
        threshold_val = args.threshold
    elif args.threshold == 0.42:
        # Not using legacy and the threshold is still the legacy default of
        # 0.42, so adjust it to 0.5. NOTE: an explicitly passed
        # `--threshold 0.42` cannot be told apart from the default here.
        threshold_val = 0.5
    else:
        # The user set their own threshold value that isn't 0.42; keep it.
        threshold_val = args.threshold

    # Stop here if the input is missing rather than letting the FASTA reader
    # fail later with a less helpful error.
    if not os.path.isfile(args.data_file):
        raise SystemExit('Error: Could not find passed fasta file [%s]'%(args.data_file))

    # read in sequences
    sequences = protfasta.read_fasta(args.data_file, invalid_sequence_action=args.invalid_sequence_action)
    if args.verbose:
        print('Read in FASTA file')

    # Maps each FASTA header to the list returned by the predictor for it.
    idrs = {}
    n_seqs = len(sequences)
    for c, s in enumerate(sequences, start=1):
        idrs[s] = meta.predict_disorder_domains(sequences[s], disorder_threshold=threshold_val, legacy=use_legacy, return_list=True)

        # Progress is only reported for large inputs, every 500 sequences.
        if args.verbose and n_seqs > 500 and c % 500 == 0:
            print('On %i of %i'%(c,n_seqs))

    # Element [2] of each prediction is iterated below as the IDR entries;
    # each entry d is used as d[0]=start, d[1]=end and (FASTA mode only)
    # d[2]=the value written as the sequence.
    # NOTE(review): confirm this layout against the metapredict docs.
    if args.mode == 'fasta':
        return_dictionary = {}
        for s, prediction in idrs.items():
            for d in prediction[2]:
                return_dictionary['%s IDR_START=%i IDR_END=%i' % (s, d[0], d[1])] = d[2]

        protfasta.write_fasta(return_dictionary, args.output_file)

    elif args.mode == 'shephard-domains':
        # `with` guarantees the output file is flushed and closed.
        with open(args.output_file, 'w') as fh:
            for s, prediction in idrs.items():
                for d in prediction[2]:
                    # note need +1 for shephard format
                    start = d[0]+1
                    end = d[1]
                    fh.write(f'{s}\t{start}\t{end}\tIDR\n')

    elif args.mode == 'shephard-domains-uniprot':
        with open(args.output_file, 'w') as fh:
            for s, prediction in idrs.items():
                domains = prediction[2]

                # Headers of sequences with no IDRs are never parsed.
                if not domains:
                    continue

                # The ID is taken as the second '|'-separated field of the
                # header; fail with a clear message if that field is absent.
                try:
                    uid = s.split('|')[1]
                except IndexError:
                    raise SystemExit('Error: Could not extract a uniprot ID from header [%s]' % (s)) from None

                for d in domains:
                    # note need +1 for shephard format
                    start = d[0]+1
                    end = d[1]
                    fh.write(f'{uid}\t{start}\t{end}\tIDR\n')
                
        

