#! /usr/bin/env python3
import argparse
import logging
import numpy as np
import os
import sys

from modelmatcher.models import RateMatrix
from modelmatcher.model_io import read_model

def get_lower_triangle(R):
    '''
    Return the strictly lower-triangular elements of the square matrix R
    (diagonal excluded) as a 1-D array, in row-major order:
    R[1,0], R[2,0], R[2,1], R[3,0], ...

    Works for any square size (e.g. a 20x20 amino-acid rate matrix).
    Raises ValueError if R is not 2-D.
    '''
    return R[np.tril_indices_from(R, k=-1)]

def get_upper_triangle(R):
    '''
    Return the strictly upper-triangular elements of the square matrix R
    (diagonal excluded) as a 1-D array, in row-major order:
    R[0,1], R[0,2], ..., R[0,n-1], R[1,2], ...

    Works for any square size (e.g. a 20x20 amino-acid rate matrix).
    Raises ValueError if R is not 2-D.
    '''
    return R[np.triu_indices_from(R, k=1)]

def perturb(R_elems, sd, min_element=0.001, *, rng=None):
    '''
    Multiply each R element by its own random factor drawn from N(1, sd).

    Factors smaller than min_element are raised to min_element, so no factor is
    below that floor. Note that the floor applies to the factor, not to the
    resulting R element.

    Parameters:
        R_elems: 1-D sequence/array of R matrix elements.
        sd: standard deviation of the normal distribution the factors come from.
        min_element: smallest factor accepted (default 0.001).
        rng: optional numpy Generator (e.g. np.random.default_rng(seed)) for
             reproducible draws. When None, numpy's global random state
             (np.random.normal) is used.

    Returns a new numpy array with the perturbed elements; R_elems itself is
    not modified.
    '''
    n = len(R_elems)
    if rng is None:
        factors = np.random.normal(loc=1.0, scale=sd, size=n)
    else:
        factors = rng.normal(loc=1.0, scale=sd, size=n)

    n_small_elements = np.count_nonzero(factors < min_element)
    if n_small_elements > 0:
        logging.info('#small elements:\t%d', n_small_elements)

    # Clamp tiny/negative factors to the floor so they never flip a rate's sign.
    factors = np.maximum(factors, min_element)
    return np.multiply(R_elems, factors)

def print_q(r_elems):
    '''
    Write the R elements to stdout on one line, separated by single spaces
    (the flat layout used here for seq-gen/PAML output).
    '''
    print(*r_elems, sep=' ')

def get_model(model_name, backup_dir='.'):
    '''
    Return an instance of the named rate model.

    If model_name matches the name of one of RateMatrix.get_all_models(), that
    built-in model is instantiated. Otherwise model_name is treated as a file
    name inside backup_dir (current directory by default; None is treated the
    same way) and parsed with read_model, using the file's base name as the
    model name ('UserDefined' if the base name is empty).

    If the file cannot be opened or parsed, the error is printed to stderr and
    the program exits via sys.exit with an explanatory message.
    '''
    if any(m.get_name() == model_name for m in RateMatrix.get_all_models()):
        # Standard model
        return RateMatrix.instantiate(model_name)

    # Assume model name is a file. An unset directory (None, e.g. an omitted
    # command-line option) falls back to the current directory.
    if backup_dir is None:
        backup_dir = '.'
    filename = os.path.join(backup_dir, model_name)
    basename = os.path.basename(model_name) or 'UserDefined'
    try:
        with open(filename) as h:
            return read_model(h, basename)  # Use the base of the filename as model name
    except Exception as e:
        # CLI boundary: report whatever went wrong (I/O or parse error) and exit.
        print(e, file=sys.stderr)
        sys.exit(f'Could not load model "{model_name}". It is neither a known model nor a file containing a model.')



def main():
    '''
    Command-line entry point.

    Parses the arguments and loads the requested model via get_model (built-in
    name or file). Then prints one of the following, depending on the chosen flag:
      --freqs      the model's amino-acid frequencies on one line;
      --paml       the perturbed lower triangle of R on one line;
      --seqgen     the perturbed upper triangle of R on one line;
      --indelible  the perturbed lower triangle as 19 rows (row i has i
                   entries), a blank line, then the model frequencies.
    Each R element is multiplied by a factor drawn from N(1, sd) (see perturb).
    '''
    ap = argparse.ArgumentParser(description='Given a choice of rate model (eg WAG, LG), generate a perturbed version of the model to be used with the -r option in seq-gen. The name is misleading: it is the R matrix, not Q, which is output and only the lower-triagular part.')
    ap.add_argument('model_name',
                    help='Specifify which standard sequence model should be used.')
    # Default to the current directory so get_model never receives None.
    ap.add_argument('-d', '--model-directory', default='.',
                    help='If model name is not a built-in, try loading matrix from this directory.')
    # NOTE: sd is a required positional, so argparse never applies this default.
    # Making it optional (nargs='?') could break invocations that place sd
    # after other options, so it is left required.
    ap.add_argument('sd', type=float, default=0.1,
                    help='Define the perturberation that should be used. Each entry in the R matrix will be multiplied with a factor taken from N(1, sd).')
    group = ap.add_mutually_exclusive_group(required=True)
    group.add_argument('--paml', action='store_true',
                       help='Output the lower triangular R matrix, as PAML needs.')
    group.add_argument('--seqgen', action='store_true',
                       help='Output the upper triangular R matrix, as seq-gen needs.')
    group.add_argument('--indelible', action='store_true',
                       help='Output the upper triangular R matrix, as indelible needs.')
    group.add_argument('--freqs', action='store_true',
                       help='Output the amino acid frequencies of the named model.')

    ap.add_argument('-l', '--log-file',
                    help='Give the name of a log file to use')

    args = ap.parse_args()

    # numpy's normal() rejects a negative scale; report it as a usage error
    # instead of a traceback deep inside perturb().
    if args.sd < 0:
        ap.error('sd must be non-negative')

    if args.log_file:
        logging.basicConfig(filename=args.log_file, level=logging.INFO)

    m = get_model(args.model_name, args.model_directory)
    sd = args.sd

    if args.freqs:
        print(' '.join(map(str, m.get_freq())))
    elif args.indelible:
        # NOTE(review): the --indelible help text says "upper triangular", but this
        # branch prints the lower triangle (row-major), row by row — confirm intent.
        r_new = perturb(get_lower_triangle(m.get_r()), sd)
        i = 0
        for row in range(1, 20):
            for _ in range(row):
                print(f'{r_new[i]:8.5}', end=' ')
                i += 1
            print()
        print()
        for f in m.get_freq():
            print(f'{f:8.5}', end=' ')
        print()
    else:
        if args.paml:
            r_part = get_lower_triangle(m.get_r())
        elif args.seqgen:
            r_part = get_upper_triangle(m.get_r())
        else:
            # Unreachable while the option group is required; kept as a safeguard.
            sys.exit('You need to choose between PAML and Seq-Gen output, sorry!')
        print_q(perturb(r_part, sd))

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
