#!/usr/bin/env python3
# coding: utf-8

import sqlite3
import seaborn as sns
from argparse import ArgumentParser
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Metrics read from the run_results table; each one gets its own PDF page per pipeline.
CRITERIONS = ['test_loss', 'train_loss', 'overfit', 'trainability', 'duration(s)', 'nb_params']


def build_query(criterions):
	"""Return the SELECT statement fetching `criterions` and the pipeline name.

	Each criterion is read from `run_results` (aliased p0) and double-quoted so
	that names such as "duration(s)" are valid SQL identifiers. The pipeline name
	comes from `runs` (aliased p1), joined on `run_id`.

	NOTE(review): `run_id` is not selected, so DISTINCT collapses runs of the
	same pipeline whose selected metrics are all identical - confirm this is
	the intended behavior.
	"""
	columns = [f'p0."{crit}"' for crit in criterions] + ['p1.pipeline']
	return f"""SELECT DISTINCT {', '.join(columns)}
		FROM run_results p0
			INNER JOIN runs p1 ON p0.run_id = p1.run_id
		"""


def load_results(con, criterions):
	"""Run the results query on `con` and return the rows as a DataFrame.

	The returned frame has one column per entry of `criterions`, in order,
	followed by a final 'pipeline' column.
	"""
	rows = con.execute(build_query(criterions)).fetchall()
	return pd.DataFrame(rows, columns=list(criterions) + ['pipeline'])


def mean_by_nb_params(df):
	"""Return the mean of every numeric column of `df` for each 'nb_params' value.

	`numeric_only=True` is required because `df` still carries the string
	'pipeline' column, on which pandas >= 2.0 raises TypeError instead of
	silently dropping it.
	"""
	return df.groupby(['nb_params']).mean(numeric_only=True).reset_index()


def plot_criterion(pdf, pipeline, df, df_mean, crit):
	"""Append one page to `pdf` showing `crit` against 'nb_params' for `pipeline`.

	The left axis plots the per-nb_params mean taken from `df_mean`; the right
	axis is a violin plot of the raw values in `df`.
	"""
	fig, (mean_ax, violin_ax) = plt.subplots(1, 2)
	try:
		sns.violinplot(ax=violin_ax, data=df, y=crit, x='nb_params', cut=0)
		mean_ax.plot(df_mean['nb_params'], df_mean[crit])
		for ax in (mean_ax, violin_ax):
			ax.set_xlabel('nb_params')
			ax.set_ylabel(crit)
		fig.suptitle(pipeline.replace(' | ', '\n'))
		fig.tight_layout()
		pdf.savefig(fig)
	finally:
		# pyplot keeps every figure alive until it is closed explicitly; without
		# this, one figure per (pipeline, criterion) accumulates in memory.
		plt.close(fig)


def main():
	"""Read run results from an SQLite database and plot them into a PDF file.

	Command-line arguments: `database_path` (SQLite file to read) and
	`output_file` (PDF to write). Exits with status 1 if the connection to the
	database cannot be opened.
	"""
	parser = ArgumentParser()
	parser.add_argument('database_path', type=str)
	parser.add_argument('output_file', type=str)
	args = vars(parser.parse_args())

	try:
		con = sqlite3.connect(args['database_path'])
	except sqlite3.Error:
		print(f'Unable to connect to {args["database_path"]}')
		raise SystemExit(1)

	# The connection is only needed for the single query below.
	try:
		data = load_results(con, CRITERIONS)
	finally:
		con.close()

	dfs = {pipeline: group for pipeline, group in data.groupby('pipeline')}
	print(dfs)

	# The context manager finalizes the PDF even if a plot raises.
	with PdfPages(args['output_file']) as pdf:
		for pipeline, df in dfs.items():
			print(pipeline)
			# The means do not depend on the criterion: compute them once per pipeline.
			df_mean = mean_by_nb_params(df)
			for crit in CRITERIONS:
				plot_criterion(pdf, pipeline, df, df_mean, crit)

	print(f'Plot saved in {args["output_file"]}')


if __name__ == '__main__':
	main()