#! /usr/bin/env python3
import argparse
import numpy as np

from modelmatcher.models import RateMatrix

def get_lower_triangle(R: np.ndarray) -> np.ndarray:
    """Return the strictly-lower-triangular elements of a square matrix.

    The diagonal is excluded (``k=-1``).  Elements come back as a flat
    array in row-major order: (1,0), (2,0), (2,1), (3,0), ... so row ``i``
    contributes ``i`` consecutive values.

    The matrix size is taken from ``R.shape[0]`` instead of being fixed at
    20, so the function works for any square matrix while giving the same
    result as before for 20x20 input.
    """
    n_states = R.shape[0]
    return R[np.tril_indices(n_states, k=-1)]

def get_upper_triangle(R: np.ndarray) -> np.ndarray:
    """Return the strictly-upper-triangular elements of a square matrix.

    The diagonal is excluded (``k=1``).  Elements come back as a flat
    array in row-major order: (0,1), (0,2), ..., (1,2), (1,3), ...

    The matrix size is taken from ``R.shape[0]`` instead of being fixed at
    20, so the function works for any square matrix while giving the same
    result as before for 20x20 input.
    """
    n_states = R.shape[0]
    return R[np.triu_indices(n_states, k=1)]


def _format_values(values):
    """Join *values* with single spaces, each formatted with the ``.6`` spec."""
    return ' '.join(f'{value:.6}' for value in values)


def main():
    """Command-line entry point: combine two named models and print the result.

    Takes two positional arguments, each restricted to the names reported by
    ``RateMatrix.get_all_models()``.  Both models are instantiated and passed
    to ``RateMatrix.combine_models``.  The following is then written to
    stdout:

    * rows 1..19 of the strict lower triangle of the combined model's
      ``get_r()`` matrix, where row ``i`` holds ``i`` space-separated values;
    * a blank line;
    * the combined model's ``freq`` values on one line;
    * a blank line and two ``#`` comment lines naming the source models.
    """
    # Build the list of valid model names once and share it between both
    # positional arguments.
    model_names = [model.get_name() for model in RateMatrix.get_all_models()]

    ap = argparse.ArgumentParser(description='Compute the average of two seq evol models and print it out as a PAML-formatted matrix.')
    ap.add_argument('model1', choices=model_names,
                    help='Specify first standard sequence model.')
    ap.add_argument('model2', choices=model_names,
                    help='Specify second standard sequence model.')

    args = ap.parse_args()

    model1 = RateMatrix.instantiate(args.model1)
    model2 = RateMatrix.instantiate(args.model2)

    new_model = RateMatrix.combine_models(model1, model2)

    # Flat, row-major list of the below-diagonal entries.
    r_part = get_lower_triangle(new_model.get_r())

    # Row i of the lower triangle occupies the next i entries of r_part.
    left_index = 0
    for row_length in range(1, 20):
        right_index = left_index + row_length
        print(_format_values(r_part[left_index:right_index]))
        left_index = right_index
    print()
    print(_format_values(new_model.freq))
    print('\n# Created by combine_q in the modelmatcher package')
    print(f'# as the average of rate matrices for {args.model1} and {args.model2}.')

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
