#!/usr/bin/env python3
# coding: utf-8

import sqlite3
from argparse import ArgumentParser
import pandas as pd
import matplotlib.pyplot as plt

def load_results(con, param_criterion, dataset_key_criterion):
    """Fetch the rows to plot from an open SQLite connection.

    Args:
        con: open ``sqlite3`` connection to the results database.
        param_criterion: name of the ``run_results`` column to plot.
        dataset_key_criterion: only runs having at least one ``params`` row
            whose ``param_value`` equals this string are selected.

    Returns:
        A DataFrame with columns ``criterion``, ``nb_params``,
        ``dataset_name``, ``fitter_name`` and ``name``.

    Raises:
        ValueError: if ``run_results`` is missing or has no such column.
        sqlite3.Error: if the query itself fails.
    """
    # Column names cannot be bound as SQL parameters, so the requested name is
    # checked against the real schema before being placed in the query text.
    known_columns = {
        row[1].lower(): row[1]
        for row in con.execute('PRAGMA table_info(run_results)')
    }
    if not known_columns:
        raise ValueError('Table run_results not found in the database')
    column = known_columns.get(param_criterion.lower())
    if column is None:
        raise ValueError(
            f'Unknown criterion {param_criterion!r}; '
            f'available columns: {", ".join(sorted(known_columns.values()))}'
        )
    quoted_column = '"' + column.replace('"', '""') + '"'

    # Idea: prefix the key name with pipeline_scheme in the database.
    query = f"""SELECT DISTINCT p0.{quoted_column}, p0.nb_params, p1.param_value, p2.param_value, p3.pipeline
        FROM run_results p0
            INNER JOIN params p1 ON p0.run_id = p1.run_id
            INNER JOIN params p2 ON p0.run_id = p2.run_id
            INNER JOIN runs p3 ON p0.run_id = p3.run_id
        WHERE p0.run_id IN
        (SELECT run_id FROM params WHERE param_value = ?)
        AND (p1.param_name='dataset_key' AND p2.param_name LIKE 'fitter_key')"""
    rows = con.execute(query, (dataset_key_criterion,)).fetchall()
    return pd.DataFrame(
        rows,
        columns=['criterion', 'nb_params', 'dataset_name', 'fitter_name', 'name'],
    )


def mean_criterion_by_nb_params(frame):
    """Return the mean ``criterion`` for each ``nb_params`` value.

    Only the ``criterion`` column is averaged: the frame also carries string
    columns, and averaging those raises ``TypeError`` on pandas >= 2.0.
    The result has columns ``nb_params`` and ``criterion``, sorted by
    ``nb_params``.
    """
    return frame.groupby('nb_params', as_index=False)['criterion'].mean()


def plot_results(data, criterion_label, output_file):
    """Plot mean criterion against ``nb_params`` and save it to *output_file*.

    One subplot is drawn per dataset (stacked vertically) with one curve per
    fitter, so every dataset ends up in the saved file with its own legend.
    *data* must contain at least one row.
    """
    datasets = list(data.groupby('dataset_name'))
    base_width, base_height = plt.rcParams['figure.figsize']
    fig, axes = plt.subplots(
        len(datasets), 1,
        squeeze=False,  # always get a 2-D array, even for a single dataset
        figsize=(base_width, base_height * len(datasets)),
    )
    for ax, (dataset_name, dataset_frame) in zip(axes[:, 0], datasets):
        for fitter_name, fitter_frame in dataset_frame.groupby('fitter_name'):
            means = mean_criterion_by_nb_params(fitter_frame)
            ax.plot(means['nb_params'], means['criterion'], label=fitter_name)
        ax.set_xlabel('nb_params')
        ax.set_ylabel(criterion_label)
        ax.set_title(dataset_name)
        ax.legend()
    if len(datasets) > 1:
        # Keep stacked titles and axis labels from overlapping.
        fig.tight_layout()
    fig.savefig(output_file)
    plt.close(fig)


def main():
    """Parse the command line, query the database and write the plot."""
    parser = ArgumentParser()
    parser.add_argument('database_path', type=str)
    parser.add_argument('output_file', type=str)
    parser.add_argument('dataset_key_criterion', type=str)
    parser.add_argument('param_criterion', type=str)
    args = parser.parse_args()

    try:
        con = sqlite3.connect(args.database_path)
    except sqlite3.Error:
        raise SystemExit(f'Unable to connect to {args.database_path}')

    # sqlite3.connect succeeds on a missing file (it creates it), so schema and
    # query problems only show up here; report them instead of a traceback.
    try:
        data = load_results(con, args.param_criterion, args.dataset_key_criterion)
    except (sqlite3.Error, ValueError) as err:
        raise SystemExit(f'Unable to read results from {args.database_path}: {err}')
    finally:
        con.close()

    if data.empty:
        raise SystemExit(f'No run found for {args.dataset_key_criterion!r}; nothing to plot')

    plot_results(data, args.param_criterion, args.output_file)
    print(f'Plot saved in {args.output_file}')


if __name__ == '__main__':
    main()
