#!/usr/bin/env python3
import os
import time
import click
import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
from biom import load_table, Table
from biom.util import biom_open
from skbio import OrdinationResults
from skbio.stats.composition import clr, centralize, closure
from skbio.stats.composition import clr_inv as softmax
from scipy.stats import entropy, spearmanr
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import svds
import tensorflow as tf
from tensorflow.contrib.distributions import Multinomial, Normal
from mmvec.multimodal import MMvec
from mmvec.util import split_tables, format_params
import matplotlib.pyplot as plt

@click.group()
def mmvec():
    # Root command group: it performs no work itself and only serves as the
    # attachment point for sub-commands registered via ``@mmvec.command()``.
    return None


@mmvec.command()
@click.option('--microbe-file',
              help='Input microbial abundances')
@click.option('--metabolite-file',
              help='Input metabolite abundances')
@click.option('--metadata-file', default=None,
              help='Input sample metadata file')
@click.option('--training-column',
              help=('Column in the sample metadata specifying which '
                    'samples are for training and testing.'),
              default=None)
@click.option('--num-testing-examples',
              help=('Number of samples to randomly select for testing'),
              default=10)
@click.option('--min-feature-count',
              help=('Minimum number of samples a microbe needs to be observed '
                    'in order to not filter out'),
              default=10)
@click.option('--epochs',
              help='Number of epochs to train', default=10)
@click.option('--batch-size',
              help='Size of mini-batch', default=32)
@click.option('--latent-dim',
              help=('Dimensionality of shared latent space. '
                    'This is analogous to the number of PC axes.'),
              default=3)
@click.option('--input-prior',
              help=('Width of normal prior for input embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--output-prior',
              help=('Width of normal prior for output embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--arm-the-gpu', is_flag=True,
              help=('Enables GPU support'),
              default=False)
@click.option('--learning-rate',
              help=('Gradient descent decay rate.'),
              default=1e-1)
@click.option('--beta1',
              help=('Gradient decay rate for first Adam momentum estimates'),
              default=0.9)
@click.option('--beta2',
              help=('Gradient decay rate for second Adam momentum estimates'),
              default=0.95)
@click.option('--clipnorm',
              help=('Gradient clipping size.'),
              default=10.)
@click.option('--checkpoint-interval',
              help=('Number of seconds before storing a checkpoint.'),
              default=1000)
@click.option('--summary-interval',
              help=('Number of seconds before storing a summary.'),
              default=1000)
@click.option('--summary-dir', default='summarydir',
              help='Summary directory to save cross validation results.')
@click.option('--embeddings-file', default=None,
              help=('Path to save the embeddings learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
@click.option('--ranks-file', default=None,
              help=('Path to save the ranks learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
@click.option('--ordination-file', default=None,
              help=('Path to save the ordination learned from the model. '
                    'If this is not specified, then this will be saved under '
                    '`--summary-dir`.'))
def paired_omics(microbe_file, metabolite_file,
                 metadata_file, training_column,
                 num_testing_examples, min_feature_count,
                 epochs, batch_size, latent_dim,
                 input_prior, output_prior, arm_the_gpu,
                 learning_rate, beta1, beta2, clipnorm,
                 checkpoint_interval, summary_interval,
                 summary_dir, embeddings_file, ranks_file, ordination_file):
    """Train an MMvec model on paired microbe and metabolite tables.

    After fitting, three tab-separated outputs are written: the learned
    embeddings, the row-centered microbe-by-metabolite ranks, and an
    ordination biplot derived from a truncated SVD of those ranks.
    """
    # NOTE: click renders the docstring above as this command's ``--help``
    # description, so per-option documentation lives in the
    # ``@click.option`` help strings rather than in a parameter list here.
    # NOTE(review): ``batch_size`` is accepted but never forwarded to the
    # model in this function -- confirm whether MMvec should receive it.

    microbes = load_table(microbe_file)
    metabolites = load_table(metabolite_file)

    metadata = None
    if metadata_file is not None:
        metadata = pd.read_table(metadata_file, index_col=0)

    # filter out low abundance microbes and split table
    res = split_tables(
        microbes, metabolites,
        metadata=metadata, training_column=training_column,
        num_test=num_testing_examples,
        min_samples=min_feature_count)

    (train_microbes_df, test_microbes_df,
     train_metabolites_df, test_metabolites_df) = res

    # Run prefix encoding the main hyper-parameters; it is used both as the
    # model's save path and as the stem of any default output file name.
    sname = ('latent_dim_%s_input_prior_%.2f_output_prior_%.2f'
             '_beta1_%.2f_beta2_%.2f' % (latent_dim, input_prior,
                                         output_prior, beta1, beta2))
    sname = os.path.join(summary_dir, sname)

    # Default output paths live directly under ``summary_dir``; make sure it
    # exists so the writes at the end of the run cannot fail on a missing
    # directory after training has already completed.
    output_paths = (embeddings_file, ranks_file, ordination_file)
    if summary_dir and any(path is None for path in output_paths):
        os.makedirs(summary_dir, exist_ok=True)

    if embeddings_file is None:
        embeddings_file = sname + "_embedding.txt"
    if ranks_file is None:
        ranks_file = sname + "_ranks.txt"
    if ordination_file is None:
        ordination_file = sname + "_ordination.txt"

    train_microbes_coo = coo_matrix(train_microbes_df.values)
    test_microbes_coo = coo_matrix(test_microbes_df.values)

    # pick out the first GPU when requested, otherwise stay on the CPU
    device_name = '/device:GPU:0' if arm_the_gpu else '/cpu:0'

    config = tf.ConfigProto()
    with tf.Graph().as_default(), tf.Session(config=config) as session:
        model = MMvec(
            latent_dim=latent_dim,
            u_scale=input_prior, v_scale=output_prior,
            learning_rate=learning_rate,
            beta_1=beta1, beta_2=beta2,
            device_name=device_name,
            clipnorm=clipnorm, save_path=sname)

        model(session,
              train_microbes_coo, train_metabolites_df.values,
              test_microbes_coo, test_metabolites_df.values)

        # The return value of fit() is not used by this command.
        model.fit(epoch=epochs, summary_interval=summary_interval,
                  checkpoint_interval=checkpoint_interval)

        microbe_ids = train_microbes_df.columns
        metabolite_ids = train_metabolites_df.columns

        # model.V and model.Vbias are padded with a leading zero column /
        # entry before being labelled with the metabolite ids (presumably
        # the first metabolite is the reference category -- confirm against
        # MMvec).  V is then transposed so that rows index metabolites.
        pc_ids = list(range(latent_dim))
        vdim = model.V.shape[0]
        V = np.hstack((np.zeros((vdim, 1)), model.V))
        V = V.T
        Vbias = np.hstack((np.zeros(1), model.Vbias.ravel()))

        # Save to an embeddings file
        Uparam = format_params(model.U, pc_ids, list(microbe_ids), 'microbe')
        Vparam = format_params(V, pc_ids, list(metabolite_ids), 'metabolite')
        df = pd.concat(
            (
                Uparam, Vparam,
                format_params(model.Ubias, ['bias'], microbe_ids, 'microbe'),
                format_params(Vbias, ['bias'], metabolite_ids, 'metabolite')
            ), axis=0)

        df.to_csv(embeddings_file, sep='\t')

        # Save to a ranks file (each microbe row is centered on zero)
        ranks = pd.DataFrame(model.U @ V.T, index=microbe_ids,
                             columns=metabolite_ids)
        ranks = ranks - ranks.mean(axis=1).values.reshape(-1, 1)
        ranks.index.name = 'featureid'
        ranks.to_csv(ranks_file, sep='\t')

        # Save to an ordination file
        centered = ranks - ranks.mean(axis=0)
        u, s, v = svds(centered.values, k=latent_dim)
        # svds does not hand back components ordered by decreasing singular
        # value, so sort explicitly: PC0 must be the dominant axis and the
        # eigenvalues / explained proportions must line up with the PC labels.
        order = np.argsort(s)[::-1]
        u, s, v = u[:, order], s[order], v[order, :]

        microbe_embed = u @ np.diag(s)
        metabolite_embed = v.T
        pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]
        features = pd.DataFrame(
            microbe_embed, columns=pc_ids,
            index=microbe_ids)
        samples = pd.DataFrame(
            metabolite_embed, columns=pc_ids,
            index=metabolite_ids)
        short_method_name = 'mmvec biplot'
        long_method_name = 'Multiomics mmvec biplot'
        eigvals = pd.Series(s, index=pc_ids)
        proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids)
        biplot = OrdinationResults(
            short_method_name, long_method_name, eigvals,
            samples=samples, features=features,
            proportion_explained=proportion_explained)
        biplot.write(ordination_file)


# Run the click command group only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    mmvec()
