#!/usr/bin/env python3

import argparse
import sys
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from random import randint

from modelmatcher.models import RateMatrix
import modelmatcher.models as ms
from modelmatcher.model_io import read_model

def pca_and_figure(ax, models, names, title, special_positioning=False):
    '''
    Standardize the model vectors, project them onto their first two
    principal components, and draw a labelled scatter plot on `ax`.

    ax       -- plotting target; its scatter, annotate and set_title
                methods are used
    models   -- one feature vector per model (the caller in this file
                passes the lower-triangle R elements as np arrays)
    names    -- labels for the points, in the same order as `models`
    title    -- title set on `ax`
    special_positioning -- if true, labels of points with x < 0 and y < -1
                are given successively shifted pixel offsets instead of
                the default offset, so that they do not all land on the
                same spot

    Besides drawing, some diagnostics and the projected coordinates of
    every model are printed to stdout. Nothing is returned.
    '''
    standardized = StandardScaler().fit_transform(models)

    reducer = PCA(n_components=2)
    reducer.fit(standardized)
    coords = reducer.transform(standardized)

    print('#models projected:', len(models))
    print('Transform size:', np.shape(coords))
    print('Explained variance:', reducer.explained_variance_)

    xs = coords[:, 0]
    ys = coords[:, 1]
    ax.scatter(xs, ys)

    # Counts the labels that have been moved so far; each moved label
    # gets an offset further along a diagonal than the previous one.
    n_shifted = 0
    for label, px, py in zip(names, xs, ys):
        print(f'{label:10} {px:.6} {py:.6}')
        in_shift_region = special_positioning and px < 0 and py < -1
        if in_shift_region:
            offset = (-15 + 10*n_shifted, 30 - 11*n_shifted)
            n_shifted += 1
        else:
            offset = (5, 0)
        ax.annotate(label, xy=(px, py),
                    textcoords='offset pixels', xytext=offset)
    ax.set_title(title)

def plot_models(ax, m1, m2):
    '''
    Plot the below-diagonal R elements of m1 (x axis) against the
    corresponding elements of m2 (y axis) as dots on `ax`, and label the
    axes with the model names.

    ax     -- plotting target; its plot, set_xlabel and set_ylabel
              methods are used
    m1, m2 -- models offering get_name() and get_r(); the result of
              get_r() is indexed with the strictly-lower-triangle
              indices of a 20x20 matrix
    '''
    below_diagonal = np.tril_indices(20, k=-1)

    x_label, y_label = m1.get_name(), m2.get_name()
    x_values = m1.get_r()[below_diagonal]
    y_values = m2.get_r()[below_diagonal]

    ax.plot(x_values, y_values, '.')
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

def main(q=None):
    '''
    Show two figures comparing replacement models.

    The first figure is a 2x2 grid of scatter plots, each comparing the
    below-diagonal R elements of two models (see plot_models). The second
    figure holds two PCA plots (see pca_and_figure): one over every model
    returned by RateMatrix.get_all_models(), and one over a fixed
    selection of them.

    q -- optional extra model offering get_r() and get_name(). When
         given, it is added to the "all models" PCA and is compared with
         WAG in the lower-right scatter plot; otherwise that panel
         compares HIVb with HIVw.
    '''
    # Strictly-lower-triangle positions of a 20x20 matrix; the same index
    # tuple is used for every model, so build it once.
    lower_triangle = np.tril_indices(20, k=-1)
    # Models that also go into the "selected models" PCA.
    core_model_names = ('WAG', 'JTT_DCMut', 'LG', 'VT', 'HIVb', 'HIVw',
                        'BLOSUM62', 'DCMUT', 'FLU')

    models = []
    names = []
    core_models = []
    core_names = []
    for m in RateMatrix.get_all_models():
        elems = m.get_r()[lower_triangle]
        name = m.get_name()
        models.append(elems)
        names.append(name)

        if name in core_model_names:
            core_models.append(elems)
            core_names.append(name)

    # Test against None explicitly: whether an extra model was supplied
    # must not depend on how the model object evaluates as a boolean.
    if q is not None:
        models.append(q.get_r()[lower_triangle])
        names.append(q.get_name())

    fig, axs = plt.subplots(2, 2, constrained_layout=True)
    plot_models(axs[1,0], ms.WAG(), ms.LG())
    plot_models(axs[0,0], ms.WAG(), ms.VT())
    plot_models(axs[0,1], ms.WAG(), ms.HIVw())
    if q is not None:
        plot_models(axs[1,1], ms.WAG(), q)
    else:
        plot_models(axs[1,1], ms.HIVb(), ms.HIVw())
    fig.suptitle('Comparing R elements of some models')
    plt.show()

    _, (left, right) = plt.subplots(1, 2, constrained_layout=True, figsize=(8,4))
    pca_and_figure(left, models, names, 'PCA for all models', True)
    pca_and_figure(right, core_models, core_names, 'PCA for selected models')
    plt.show()

if __name__ == '__main__':
    # Command line: [name [infile]]. Both positionals are optional.
    # NOTE(review): since `name` is declared first, a lone argument is
    # presumably taken as the name and not as the infile -- confirm that
    # this is the intended usage.
    ap = argparse.ArgumentParser(description='Make plots comparing replacement models. The PCA is on the values from the R components (not Q matrices)')
    ap.add_argument('name', default='*', nargs='?',
                    help='Provide a tentative name for the input model. Nonsensical if no input file is given.')
    ap.add_argument('infile', nargs='?',
                    help='Optional infile in PAML format')
    args = ap.parse_args()

    q = None
    if args.infile:
        # Only the parsing needs the file: read the model and let the
        # handle close before main() starts showing figures, instead of
        # keeping it open for the whole plotting session.
        # NOTE(review): assumes read_model() has consumed everything it
        # needs from the handle by the time it returns -- confirm.
        with open(args.infile) as h:
            q = read_model(h, args.name)
    main(q)
