#!/usr/bin/env python3
# coding: utf-8

import re
import sys
import json
import sqlite3
from configparser import ConfigParser
from mosaic import Scheduler
from argparse import ArgumentParser

def usage(name=sys.argv[0]):
    """Build the usage text handed to ``ArgumentParser(usage=...)``.

    Args:
        name: Program name shown at the start of the usage line. The default
            is ``sys.argv[0]`` as it was when this function was defined.

    Returns:
        A multi-line string listing the positional arguments and the flags.
    """
    # The flag names listed here must match the ones registered on the parser
    # (``-param_name_value*``); the whitespace layout is kept as it was.
    lines = [
        f'{name} [monitor_config_path] [database_path]',
        ' \t[-h]',
        '\t[-id ID, RANGE, LIST | str]',
        '\t[-request SQL_REQUEST | str]',
        '\t[-param_name_value NAME VALUE | str str]',
        '\t[-param_name_value_low_threshold NAME VALUE | str str]',
        '\t[-param_name_value_high_threshold NAME VALUE | str str]',
        '',
        '\t[-epoch EPOCH | int]',
        '\t',
    ]
    return '\n'.join(lines)

def flag_parse(args, database_path):
    """Resolve the selection flag that was supplied into a set of run ids.

    The flags are checked in a fixed order (``id``, ``param_name_value``,
    ``param_name_value_low_threshold``, ``param_name_value_high_threshold``,
    ``request``) and the first one whose value is not ``None`` decides the
    result.

    Args:
        args: Mapping of parsed command-line arguments; it must contain all
            five keys listed above (a missing key raises ``KeyError``).
        database_path: Path of the SQLite database, forwarded to the parsers
            that need to query it.

    Returns:
        The set produced by the matching ``args_parse_*`` function, or an
        empty set when every selection flag is ``None``.
    """
    if args['id'] is not None:
        return args_parse_id(args['id'])

    if args['param_name_value'] is not None:
        return args_parse_param_name_value(args['param_name_value'], database_path)

    if args['param_name_value_low_threshold'] is not None:
        return args_parse_param_low_threshold(args['param_name_value_low_threshold'], database_path)

    if args['param_name_value_high_threshold'] is not None:
        return args_parse_param_high_threshold(args['param_name_value_high_threshold'], database_path)

    if args['request'] is not None:
        # args_parse_request needs the database path as its second argument.
        return args_parse_request(args['request'], database_path)

    # Empty *set* (not ``{}``, which is a dict) so every path returns a set.
    return set()


def args_parse_id(value):
    """Expand an ``-id`` argument into a set of run ids (as strings).

    Accepted syntax is a comma-separated list whose elements are either a
    single number (``7``) or an inclusive range (``3-5``), e.g. ``1,3-5,9``.

    Args:
        value: The raw ``-id`` string.

    Returns:
        A set of id strings. Single numbers are kept exactly as written;
        ranges are expanded with ``range`` and rendered through ``str``.
        The set is empty when ``value`` does not follow the syntax above or
        when a range is reversed (``5-3``).
    """
    # The whole string must match: malformed input such as "1-23-4" is
    # rejected rather than being expanded from a matching prefix.
    id_list_pattern = re.compile(r"\d+(?:-\d+)?(?:,\d+(?:-\d+)?)*")
    if not id_list_pattern.fullmatch(value):
        return set()

    ids_set = set()
    for elem in value.split(','):
        first, dash, last = elem.partition('-')
        if dash:
            ids_set.update(map(str, range(int(first), int(last) + 1)))
        else:
            ids_set.add(elem)
    return ids_set

def args_parse_param_name_value(arg, database_path):
    """Return the run ids whose parameter ``name`` is equal to ``value``.

    Args:
        arg: Two-item sequence ``(param_name, param_value)``.
        database_path: Path of the SQLite database containing the ``params``
            table.

    Returns:
        A set holding the ``run_id`` of every matching row; empty when no
        row matches.

    On an SQLite error the error is printed and the process exits with
    status 1.
    """
    name, value = arg
    try:
        db_con = sqlite3.connect(database_path)
        try:
            rows = db_con.execute(
                'SELECT run_id FROM params WHERE param_name=? AND param_value=?',
                (name, value),
            ).fetchall()
        finally:
            # Always release the connection, including when the query fails.
            db_con.close()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    return {row[0] for row in rows}

def args_parse_param_low_threshold(arg, database_path):
    """Return the run ids whose parameter ``name`` is ``<=`` ``value``.

    Args:
        arg: Two-item sequence ``(param_name, threshold)``.
        database_path: Path of the SQLite database containing the ``params``
            table.

    Returns:
        A set holding the ``run_id`` of every matching row; empty when no
        row matches.

    On an SQLite error the error is printed and the process exits with
    status 1.
    """
    name, value = arg
    # NOTE(review): the threshold is bound exactly as received; whether SQLite
    # compares it numerically or as text presumably depends on the declared
    # type of ``param_value`` -- confirm against the schema.
    try:
        db_con = sqlite3.connect(database_path)
        try:
            rows = db_con.execute(
                'SELECT run_id FROM params WHERE param_name=? AND param_value<=?',
                (name, value),
            ).fetchall()
        finally:
            # Always release the connection, including when the query fails.
            db_con.close()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    return {row[0] for row in rows}

def args_parse_param_high_threshold(arg, database_path):
    """Return the run ids whose parameter ``name`` is ``>=`` ``value``.

    Args:
        arg: Two-item sequence ``(param_name, threshold)``.
        database_path: Path of the SQLite database containing the ``params``
            table.

    Returns:
        A set holding the ``run_id`` of every matching row; empty when no
        row matches.

    On an SQLite error the error is printed and the process exits with
    status 1.
    """
    name, value = arg
    # NOTE(review): the threshold is bound exactly as received; whether SQLite
    # compares it numerically or as text presumably depends on the declared
    # type of ``param_value`` -- confirm against the schema.
    try:
        db_con = sqlite3.connect(database_path)
        try:
            rows = db_con.execute(
                'SELECT run_id FROM params WHERE param_name=? AND param_value>=?',
                (name, value),
            ).fetchall()
        finally:
            # Always release the connection, including when the query fails.
            db_con.close()
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)
    return {row[0] for row in rows}

def args_parse_request(request, database_path):
    """Run a user-supplied SQL statement and collect its first column.

    Args:
        request: SQL text to execute.
        database_path: Path of the SQLite database to query.

    Returns:
        A set built from the first column of every returned row.

    If the database cannot be opened, or the statement fails, the error and
    an explanatory message are printed and the process exits with status 1.
    """
    try:
        db_con = sqlite3.connect(database_path)
    except sqlite3.Error as e:
        print(e)
        print(f'Unable to connect to {database_path}.')
        sys.exit(1)

    try:
        # NOTE: `request` is executed verbatim; this is only acceptable
        # because it comes from the person running the script.
        rows = db_con.execute(request).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error, so it is
        # listed explicitly.
        print(e)
        print(f'Unable to request "{request}" to database.')
        sys.exit(1)
    finally:
        # Runs on success and on the SystemExit path alike.
        db_con.close()

    return {row[0] for row in rows}

if __name__ == '__main__':
    # Script entry point: select run ids from the command line, load their
    # stored pipelines and parameters from the database, and pass everything
    # to a mosaic Scheduler.

    parser = ArgumentParser(usage=usage())
    parser.add_argument('monitor_config_path', type=str)
    parser.add_argument('database_path', type=str)

    # Exactly one of the run-selection flags below must be supplied.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-id', type=str, help='test', metavar='ID, RANGE, LIST')
    group.add_argument('-request', type=str, help='', metavar='\t')
    group.add_argument('-param_name_value', type=str, nargs=2, help='', metavar='\t')
    group.add_argument('-param_name_value_low_threshold', type=str, nargs=2, help='', metavar='\t')
    group.add_argument('-param_name_value_high_threshold', type=str, nargs=2, help='', metavar='\t')

    parser.add_argument('-epoch', type=int, help='epoch', required=True)

    args = vars(parser.parse_args())

    ids = list(map(int, flag_parse(args, args['database_path'])))
    if not ids:
        print('No ids selected.')
        sys.exit(0)

    try:
        con = sqlite3.connect(args['database_path'])
        cur = con.cursor()
    except sqlite3.Error:
        print('Unable to connect to database.')
        sys.exit(1)

    config = ConfigParser()
    config.read(args['monitor_config_path'])

    # Validate the configuration before running any query.
    for section in ('MONITOR', 'PROCESS'):
        if section not in config.sections():
            print(f'{section} section not found in config file.')
            sys.exit(1)

    # Ids are bound in batches of 500 per statement (presumably to stay under
    # SQLite's limit on bound parameters -- TODO confirm). Stepping by the
    # batch size never produces an empty batch.
    chunk_size = 500
    pipeline_rows = []
    select_results = []
    for start in range(0, len(ids), chunk_size):
        chunk = ids[start:start + chunk_size]
        placeholders = ','.join('?' * len(chunk))
        # 'pipeline' is single-quoted: double quotes denote identifiers in SQL.
        pipeline_rows.extend(cur.execute(
            "SELECT param_value FROM params "
            f"WHERE (param_name='pipeline' AND run_id IN ({placeholders}))",
            chunk,
        ).fetchall())
        select_results.extend(cur.execute(
            f'SELECT * FROM params WHERE run_id IN ({placeholders})',
            chunk,
        ).fetchall())
    con.close()

    # Each stored pipeline value is a JSON document.
    pipelines = [json.loads(row[0]) for row in pipeline_rows]

    monitor_info = dict(config['MONITOR'])
    monitor_info['database_path'] = args['database_path']

    # Group the stored parameters per run. Rows are unpacked as
    # (run_id, param_name, param_value), i.e. `params` is assumed to have
    # exactly these three columns in this order. The requested epoch count and
    # the database path always override stored values of the same name.
    rerun_params = {}
    for run_id, param_name, param_value in select_results:
        run_params = rerun_params.setdefault(run_id, {})
        run_params[param_name] = param_value
        run_params['epochs'] = args['epoch']
        run_params['database_path'] = args['database_path']

    monitor = Scheduler(pipelines, monitor_info, None, rerun_ids=ids, rerun_params=rerun_params)
