#!/usr/bin/env python3
# coding: utf-8

import re
import sys
import torch
import sqlite3
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np


'''
    The ids of the runs that should appear on the plots have to be
    fetched the same way the rerun script does.
'''

def usage(name=sys.argv[0]):
    """Return the usage text handed to the argument parser.

    The text lists the three positional arguments, the run-selection
    flags (exactly one of which is required by the parser) and the
    optional plot size, using the same flag names the parser declares.

    Args:
        name: Program name placed at the start of the usage text. The
            default is ``sys.argv[0]`` as evaluated when the function is
            defined.

    Returns:
        The multi-line usage string.
    """
    return f'''{name} [database_path] [output_name] [runs_files_directory]
    [-h]
    [-id ID, RANGE, LIST | str]
    [-request SQL_REQUEST | str]
    [-param_name_value NAME VALUE | str str]
    [-param_name_value_low_threshold NAME VALUE | str str]
    [-param_name_value_high_threshold NAME VALUE | str str]

    [-plot_size SIZE | int]

    '''

def flag_parse(args, database_path):
    """Dispatch to the selector matching the first run-selection flag set.

    The flags are checked in this order: ``id``, ``param_name_value``,
    ``param_name_value_low_threshold``, ``param_name_value_high_threshold``,
    ``request``.

    Args:
        args: Mapping of parsed command-line options; every key listed
            above must be present, with ``None`` meaning "not given".
        database_path: Path of the SQLite database forwarded to the
            selected helper.

    Returns:
        Whatever the selected helper returns, or an empty dict when no
        selection flag is set.
    """
    if args['id'] is not None:
        return args_parse_id(args['id'], database_path)

    if args['param_name_value'] is not None:
        return args_parse_param_name_value(args['param_name_value'], database_path)

    if args['param_name_value_low_threshold'] is not None:
        return args_parse_param_low_threshold(args['param_name_value_low_threshold'], database_path)

    if args['param_name_value_high_threshold'] is not None:
        return args_parse_param_high_threshold(args['param_name_value_high_threshold'], database_path)

    if args['request'] is not None:
        # The request helper needs the database path like every other
        # selector; omitting it raised a TypeError on every -request call.
        return args_parse_request(args['request'], database_path)

    return {}

def args_parse_id(value, database_path):
    """Fetch the runs selected by an id specification.

    Args:
        value: Either ``'all'`` or a comma-separated list whose items are
            single ids (``'7'``) or inclusive ranges (``'3-9'``), e.g.
            ``'1-3,8,10-12'``.
        database_path: Path of the SQLite database to query.

    Returns:
        A list of ``(run_id, pipeline, nb_params)`` rows taken from
        ``runs`` joined with ``run_results``. The list is empty when
        ``value`` is malformed or selects no stored run.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    query = '''SELECT runs.run_id, runs.pipeline, run_results.nb_params
               FROM runs
                   INNER JOIN run_results ON runs.run_id = run_results.run_id'''
    bound_values = []

    if value != 'all':
        # Reject anything that is not a clean list of ids / ranges instead
        # of sending an incomplete WHERE clause to SQLite.
        if not re.fullmatch(r'\d+(-\d+)?(,\d+(-\d+)?)*', value):
            return []

        singles = []
        range_bounds = []
        for item in value.split(','):
            low, dash, high = item.partition('-')
            if dash:
                range_bounds += [int(low), int(high)]
            else:
                singles.append(int(low))

        # Ranges become BETWEEN clauses so a wide range costs two bound
        # values rather than one per id.
        conditions = []
        if singles:
            placeholders = ', '.join('?' * len(singles))
            conditions.append(f'runs.run_id IN ({placeholders})')
        conditions += ['runs.run_id BETWEEN ? AND ?'] * (len(range_bounds) // 2)

        query += ' WHERE ' + ' OR '.join(conditions)
        bound_values = singles + range_bounds

    con = sqlite3.connect(database_path)
    try:
        return con.execute(query, bound_values).fetchall()
    finally:
        con.close()

def args_parse_param_name_value(arg, database_path):
    """Fetch the runs whose parameter ``name`` equals ``value``.

    Args:
        arg: Pair ``(name, value)`` matched against ``params.param_name``
            and ``params.param_value``.
        database_path: Path of the SQLite database to query.

    Returns:
        A list of ``(run_id, pipeline, nb_params)`` rows from ``params``
        joined with ``runs`` and ``run_results``.

    Raises:
        SystemExit: With code 1 after printing the error, when the
            database cannot be opened or the query fails.
    """
    name, value = arg
    query = '''SELECT params.run_id, runs.pipeline, run_results.nb_params
               FROM params
                   INNER JOIN runs ON runs.run_id = params.run_id
                   INNER JOIN run_results ON run_results.run_id = params.run_id
               WHERE params.param_name=? AND params.param_value=?'''
    db_con = None
    try:
        db_con = sqlite3.connect(database_path)
        return db_con.execute(query, (name, value)).fetchall()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    finally:
        # Release the connection on success and on failure alike.
        if db_con is not None:
            db_con.close()

def args_parse_param_low_threshold(arg, database_path):
    """Fetch the runs whose parameter ``name`` is at most ``value``.

    Args:
        arg: Pair ``(name, value)``; rows are kept when
            ``param_name = name`` and ``param_value <= value``.
        database_path: Path of the SQLite database to query.

    Returns:
        A list of ``(run_id, pipeline, nb_params)`` rows from ``params``
        joined with ``runs`` and ``run_results``.

    Raises:
        SystemExit: With code 1 after printing the error, when the
            database cannot be opened or the query fails.
    """
    name, value = arg
    # NOTE(review): value is bound exactly as received (a string when it
    # comes from the command line); whether SQLite compares it numerically
    # or as text depends on the declared type of params.param_value --
    # confirm against the schema.
    query = '''SELECT params.run_id, runs.pipeline, run_results.nb_params
               FROM params
                   INNER JOIN runs ON runs.run_id = params.run_id
                   INNER JOIN run_results ON run_results.run_id = params.run_id
               WHERE param_name=? AND param_value<=?'''
    db_con = None
    try:
        db_con = sqlite3.connect(database_path)
        return db_con.execute(query, (name, value)).fetchall()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    finally:
        # Release the connection on success and on failure alike.
        if db_con is not None:
            db_con.close()

def args_parse_param_high_threshold(arg, database_path):
    """Fetch the runs whose parameter ``name`` is at least ``value``.

    Args:
        arg: Pair ``(name, value)``; rows are kept when
            ``param_name = name`` and ``param_value >= value``.
        database_path: Path of the SQLite database to query.

    Returns:
        A list of ``(run_id, pipeline, nb_params)`` rows from ``params``
        joined with ``runs`` and ``run_results``.

    Raises:
        SystemExit: With code 1 after printing the error, when the
            database cannot be opened or the query fails.
    """
    name, value = arg
    # NOTE(review): value is bound exactly as received (a string when it
    # comes from the command line); whether SQLite compares it numerically
    # or as text depends on the declared type of params.param_value --
    # confirm against the schema.
    query = '''SELECT params.run_id, runs.pipeline, run_results.nb_params
               FROM params
                   INNER JOIN runs ON runs.run_id = params.run_id
                   INNER JOIN run_results ON run_results.run_id = params.run_id
               WHERE param_name=? AND param_value>=?'''
    db_con = None
    try:
        db_con = sqlite3.connect(database_path)
        return db_con.execute(query, (name, value)).fetchall()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    finally:
        # Release the connection on success and on failure alike.
        if db_con is not None:
            db_con.close()

def args_parse_request(request, database_path):
    """Run a caller-supplied SQL request and return its rows.

    The request is executed verbatim, so it must only come from a
    trusted source such as the operator's own command line.

    Args:
        request: SQL statement to execute.
        database_path: Path of the SQLite database to query.

    Returns:
        The list of rows produced by the request.

    Raises:
        SystemExit: With code 1 after printing the error, when the
            database cannot be opened or the request fails.
    """
    try:
        db_con = sqlite3.connect(database_path)
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)

    try:
        # The rows must be returned: without this the caller always saw
        # None and reported that no ids were selected.
        return db_con.execute(request).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        print(e)
        print(f'Unable to request "{request}" to database.')
        sys.exit(1)
    finally:
        db_con.close()


'''
    Query the database for the run ids.
    The pipeline value stored in the runs table has to be fetched at the same time.
'''

def main():
    """Plot the loss histories of the selected runs into a multi-page PDF.

    Command-line arguments give the database, the output PDF name and the
    directory holding the ``history_<run_id>.pt`` files. Runs are selected
    through exactly one of the selection flags, then drawn on pages of
    ``plot_size x plot_size`` subplots. Each subplot shows the train and
    test loss curves plus three indicators (overfit, trainability, slope).
    A final page carrying the shared legend is appended when at least one
    run could be plotted.

    Runs whose history cannot be loaded or plotted are reported on stdout
    and leave their subplot empty; the page they belong to is still saved.
    """
    import traceback

    parser = ArgumentParser(usage=usage())
    parser.add_argument('database_path', type=str)
    parser.add_argument('output_name', type=str)
    parser.add_argument('runs_files_directory', type=str)

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-id', type=str, help='test', metavar='ID, RANGE, LIST')
    group.add_argument('-request', type=str, help='', metavar='\t')
    group.add_argument('-param_name_value', type=str, nargs=2, help='', metavar='\t')
    group.add_argument('-param_name_value_low_threshold', type=str, nargs=2, help='', metavar='\t')
    group.add_argument('-param_name_value_high_threshold', type=str, nargs=2, help='', metavar='\t')

    parser.add_argument('-plot_size', type=int, default=2)

    args = vars(parser.parse_args())

    plt_size = args['plot_size']
    if plt_size < 1:
        # A grid needs at least one cell; zero would divide by zero below.
        parser.error('-plot_size must be a positive integer.')
    per_page = plt_size ** 2

    data = flag_parse(args, args['database_path'])
    if not data:
        print('No ids selected.')
        sys.exit(0)

    run_dir = args['runs_files_directory']
    if not run_dir.endswith('/'):
        run_dir += '/'

    legend_fig = None

    # The context manager closes the PDF even if plotting raises.
    with PdfPages(args['output_name']) as pdf:
        for page_start in range(0, len(data), per_page):
            page_rows = data[page_start:page_start + per_page]
            # squeeze=False keeps ax two-dimensional, including for a 1x1 grid.
            fig, ax = plt.subplots(plt_size, plt_size, squeeze=False)

            for slot, (run_id, pipeline, nb_params) in enumerate(page_rows):
                cell = ax[slot // plt_size, slot % plt_size]
                history_path = f'{run_dir}history_{run_id}.pt'
                try:
                    # NOTE: torch.load unpickles the file; only point this
                    # script at history files from a trusted source.
                    history = torch.load(history_path)

                    hist_test = np.array(history['test_loss'])
                    hist_train = np.array(history['train_loss'])
                    nb_epochs = len(hist_train)

                    max_data = max(np.max(hist_test), np.max(hist_train))
                    min_data = min(np.min(hist_test), np.min(hist_train))
                    range_data = max_data - min_data

                    # Final train/test gap relative to the overall loss range.
                    overfit = np.abs(hist_test - hist_train)[-1] / range_data
                    # Mean test loss over the epochs.
                    trainability = np.sum(hist_test) / nb_epochs
                    # Mean epoch-to-epoch change of the train loss over the
                    # last 10% of epochs; at least one step is used so short
                    # histories do not produce NaN.
                    tail = max(1, nb_epochs // 10)
                    last_steps = np.diff(hist_train)[-tail:]
                    slope = np.mean(last_steps) if last_steps.size else float('nan')

                    l1, = cell.plot(range(len(history['train_loss'])), history['train_loss'], label='train')
                    l2, = cell.plot(range(len(history['test_loss'])), history['test_loss'], c='r', linestyle='--', label='test')
                    cell.set_title(str(run_id) + ':' + pipeline.replace(' | ', '\n') + ':' + str(nb_params), fontsize=8)
                    cell.text(0.4, 0.7, f"overfit {overfit:.2f}\ntrainability {trainability:.2f}\nslope {slope:.3E}", size=16 // plt_size, transform=cell.transAxes, alpha=0.5)
                except FileNotFoundError:
                    print(f'File {history_path} not found.')
                except Exception as exp:
                    # Best effort: report the run and keep plotting the others.
                    print(traceback.format_exc())
                    print(exp)
                    print(f'Unable to plot run {run_id} from {history_path}.')
                else:
                    if legend_fig is None:
                        # Build the legend while its source lines are live.
                        legend_fig = plt.figure()
                        legend_fig.legend([l1, l2], ['train', 'test'], loc='center', fontsize=20)

            # Drop the cells left over on a partially filled last page.
            for slot in range(len(page_rows), per_page):
                fig.delaxes(ax[slot // plt_size, slot % plt_size])

            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)

        if legend_fig is not None:
            pdf.savefig(legend_fig)
            plt.close(legend_fig)

    print(f'Plots saved in {args["output_name"]}')


if __name__ == '__main__':
    main()
