#!/usr/bin/env python3
# coding: utf-8

from argparse import ArgumentParser
from configparser import ConfigParser
import json
import pandas as pd
import numpy as np
import sqlite3
import re
import matplotlib.pyplot as plt
import mosaic.lib.plt_params
from seaborn import violinplot

class Perf_Analysis():
	"""Draw plots of run results stored in an SQLite database.

	Every section of the configuration file except ``global`` describes one
	plot. Constructing the object reads the file and immediately draws all
	of them (see ``scheduler``).

	The queries issued here read two tables: ``params`` (columns ``run_id``,
	``param_name``, ``param_value``) and ``run_results`` (``run_id`` plus one
	column per name in ``RESULT_COLUMNS``).
	"""

	# Names looked up as columns of ``run_results``; any other name is
	# searched for in ``params``.
	RESULT_COLUMNS = frozenset([
		'test_loss', 'train_loss', 'test_acc', 'train_acc', 'epochs',
		'nb_params', 'duration(s)', 'overfit', 'trainability', 'slope',
	])

	# Keys of a pipeline element that are not shown as arguments in the
	# label built by ``_format_pipeline``.
	_LABEL_HIDDEN_KEYS = frozenset(['name', 'type', 'path_to_class', 'class', 'key'])

	# Number of run ids bound per statement, kept well below the limit on
	# bound parameters of older SQLite builds.
	_RUN_ID_CHUNK = 500

	# Patterns used to split a raw configuration value into a list of lists.
	_GROUP_PATTERN = re.compile(r'\[(.+?)\]|(\w+)')
	_ITEM_PATTERN = re.compile(r'[\w\.\/\:{}-]+')

	def __init__(self, config_file, database_file):
		"""Read ``config_file`` and draw every plot it describes.

		Args:
			config_file: path of the INI configuration file.
			database_file: path of the SQLite database to query.

		Raises:
			FileNotFoundError: if the configuration file could not be read.
		"""
		self.config_file = config_file
		self.database_file = database_file
		self.conf_parser = ConfigParser()
		# ConfigParser.read() skips unreadable files silently and returns the
		# list of files it parsed; an empty list means nothing was loaded.
		if not self.conf_parser.read(config_file):
			raise FileNotFoundError(f'Could not read configuration file: {config_file}')
		self.scheduler()

	def _format_pipeline(self, pipe):
		"""Turn a JSON pipeline into a ``Class(v1,v2) | Class()`` label.

		Args:
			pipe: JSON text of a list of objects, each having a ``class`` key.

		Returns:
			One ``class(values)`` item per element joined with ``' | '``; the
			values are the element's entries whose key is not in
			``_LABEL_HIDDEN_KEYS``, in their original order.
		"""
		labels = []
		for elem in json.loads(pipe):
			shown = [str(value) for key, value in elem.items()
					 if key not in self._LABEL_HIDDEN_KEYS]
			labels.append(elem['class'] + '(' + ','.join(shown) + ')')
		return ' | '.join(labels)

	def replace_list_config_section(self, config_section):
		"""Parse every value of ``config_section`` into a list of lists.

		A bracketed group such as ``[a, b]`` becomes one inner list and a
		bare word becomes a single-item inner list.

		Args:
			config_section: dict of raw configuration values; updated in place.

		Returns:
			The same dict, with every value replaced by its parsed form.
		"""
		for key, raw_value in config_section.items():
			# Exactly one of the two regex groups is non-empty per match.
			groups = [''.join(match) for match in self._GROUP_PATTERN.findall(str(raw_value))]
			config_section[key] = [self._ITEM_PATTERN.findall(group) for group in groups]
		return config_section

	def find_pipelines_in_database(self):
		"""Select the runs that satisfy every include filter.

		Opens a connection to ``self.database_file`` (kept in ``self.con`` and
		``self.cur`` for later queries) and pairs ``self.include_keys`` with
		``self.include_values``. A run is kept when, for each pair, it has a
		parameter whose name ends with the key and whose value ends with one
		of the accepted values.

		Returns:
			DataFrame with columns ``run_id`` and ``pipeline`` (the value of
			the run's parameter whose name ends with ``pipeline``).
		"""
		# Release the connection left by a previous call before replacing it.
		previous = getattr(self, 'con', None)
		if previous is not None:
			previous.close()
		self.con = sqlite3.connect(self.database_file)
		self.cur = self.con.cursor()
		self.includes = list(zip(self.include_keys, self.include_values))

		# None means "no filter applied yet". An empty set is a real result
		# and must stay empty through the remaining filters.
		matching = None
		for key, values in self.includes:
			value_clause = ' OR '.join(['p0.param_value LIKE ?'] * len(values))
			query = (
				'SELECT DISTINCT p0.run_id, p1.param_value '
				'FROM params p0 '
				'INNER JOIN params p1 ON p0.run_id = p1.run_id '
				f'WHERE p0.param_name LIKE ? AND ({value_clause}) '
				'AND p1.param_name LIKE ?'
			)
			bindings = [f'%{key[0]}', *(f'%{value}' for value in values), '%pipeline']
			selected = set(self.cur.execute(query, bindings).fetchall())
			matching = selected if matching is None else matching & selected

		# DataFrame() does not accept a set; sorting on repr gives a
		# reproducible row order whatever the column types are.
		rows = sorted(matching, key=repr) if matching else []
		return pd.DataFrame(rows, columns=['run_id', 'pipeline'])

	def crop_pipeline(self, pipeline):
		"""Drop the keys listed in ``self.excludes`` from every pipeline element.

		Args:
			pipeline: JSON text of a list of objects.

		Returns:
			JSON text of the same list without the excluded keys.
		"""
		modules = json.loads(pipeline)
		for module in modules:
			for to_exclude in self.excludes:
				module.pop(to_exclude, None)
		return json.dumps(modules)

	def _select_for_run_ids(self, query_head, head_bindings, run_ids):
		"""Run ``query_head run_id IN (...)`` for ``run_ids`` and collect the rows.

		The ids are bound as parameters, in chunks of ``_RUN_ID_CHUNK``.
		"""
		rows = []
		for start in range(0, len(run_ids), self._RUN_ID_CHUNK):
			chunk = run_ids[start:start + self._RUN_ID_CHUNK]
			placeholders = ', '.join('?' * len(chunk))
			statement = f'{query_head} run_id IN ({placeholders})'
			rows.extend(self.cur.execute(statement, [*head_bindings, *chunk]).fetchall())
		return rows

	def get_abscissa_and_ordinate(self, pipeline_ids_dataframe):
		"""Fetch the plotted quantities for the selected runs.

		Uses the cursor opened by ``find_pipelines_in_database``. Each name in
		``self.abscissae + self.ordinates`` is read from ``run_results`` when
		it is one of ``RESULT_COLUMNS``, otherwise from the ``params`` rows
		whose name ends with it.

		Args:
			pipeline_ids_dataframe: DataFrame with ``run_id`` and ``pipeline``
				columns.

		Returns:
			DataFrame with ``run_id``, one column per requested name and the
			``pipeline`` column.

		Raises:
			Exception: if a name yields no row, or if the run ids found for
				it differ from the selected run ids.
		"""
		# tolist() yields plain Python values, which sqlite3 can bind
		# (numpy scalars are not supported as parameters).
		run_ids = pipeline_ids_dataframe['run_id'].tolist()
		df = pd.DataFrame(pipeline_ids_dataframe['run_id'])
		for param in self.abscissae + self.ordinates:
			if param in self.RESULT_COLUMNS:
				# Bracket quoting keeps names such as duration(s) from being
				# parsed as a function call.
				rows = self._select_for_run_ids(
					f'SELECT run_id, [{param}] FROM run_results WHERE', [], run_ids)
			else:
				rows = self._select_for_run_ids(
					'SELECT run_id, param_value FROM params WHERE param_name LIKE ? AND',
					[f'%{param}'], run_ids)
			df_parameters = pd.DataFrame(rows, columns=['run_id', param])
			if df_parameters.empty:
				raise Exception(f'Empty dataframe on param = {param}.')

			# array_equal also copes with id lists of different lengths.
			run_ids_match = np.array_equal(np.sort(df_parameters['run_id']),
										   np.sort(df['run_id']))
			if not run_ids_match:
				raise Exception(f"Ids don't match on param = {param}")

			df = pd.merge(df, df_parameters, on='run_id')
		return pd.merge(df, pipeline_ids_dataframe, on='run_id')

	def plot(self, config_section, plot_name):
		"""Draw the figures described by one configuration section.

		One figure is shown per (abscissa, ordinate) pair. For the ``line``
		and ``point`` plot types each pipeline gets an error-bar curve (mean
		and standard deviation of the ordinate per abscissa value); for
		``violin`` a violin plot split by pipeline is drawn.

		Args:
			config_section: parsed section (see ``replace_list_config_section``)
				with the keys ``abscissae``, ``ordinates``, ``include_keys``,
				``include_values``, ``excludes`` and ``plot_type``.
			plot_name: section name, used in the "no data" message.
		"""
		self.include_keys = config_section['include_keys']
		self.include_values = config_section['include_values']
		# These entries are parsed as lists of lists; only the first item of
		# each inner list is used.
		self.abscissae = [entry[0] for entry in config_section['abscissae']]
		self.ordinates = [entry[0] for entry in config_section['ordinates']]
		self.excludes = [entry[0] for entry in config_section['excludes']]
		self.plot_type = config_section['plot_type'][0][0]

		df = self.find_pipelines_in_database()
		if df.empty:
			print(f"No data found in database with section {plot_name}.")
			return
		df = self.get_abscissa_and_ordinate(df)

		df['pipeline'] = df['pipeline'].apply(self.crop_pipeline)
		df['pipeline'] = df['pipeline'].apply(self._format_pipeline)

		list_of_pipelines = df['pipeline'].unique()
		for x_param in self.abscissae:
			for y_param in self.ordinates:
				if self.plot_type in ('line', 'point'):
					for pipeline in list_of_pipelines:
						# Aggregate only the plotted column so that other,
						# possibly non-numeric, columns cannot break mean()/std().
						grouped = df[df['pipeline'] == pipeline].groupby(x_param)[y_param]
						means = grouped.mean()
						stds = grouped.std()
						plt.errorbar(means.index, means.to_numpy(),
									 yerr=stds.to_numpy(), label=pipeline)
				elif self.plot_type == 'violin':
					violinplot(data=df, y=y_param, x=x_param, hue='pipeline',
							   scale='count', inner='box')

				plt.xlabel(x_param)
				plt.ylabel(y_param)
				plt.legend()
				plt.tight_layout()
				plt.show()

	def scheduler(self):
		"""Draw one plot per configuration section, skipping ``global``."""
		for plot_name in self.conf_parser.sections():
			if plot_name == 'global':
				continue
			section = dict(self.conf_parser[plot_name])
			section = self.replace_list_config_section(section)
			self.plot(section, plot_name)




if __name__ == '__main__':
	# Command line: two required positionals, the configuration file and
	# the database file, in that order.
	cli = ArgumentParser()
	for positional in ('config_file', 'database_file'):
		cli.add_argument(positional, type=str)
	options = cli.parse_args()

	# The constructor reads the configuration and draws the plots.
	Perf_Analysis(options.config_file, options.database_file)