#!/usr/bin/env python

import argparse
import json
import traceback
import pandas as pd
from lpsds.logger import log
from scripts.operation import ProductionModelOperation


def get_options():
    """
    Essa função define os parãmetros necessários para a execução deste script que devem ser passados por linha de comando.
    
    NÃO MEXA EM NADA NESTA FUNÇÃO. 
    """
    parser = argparse.ArgumentParser(
                    prog='Script de treinamento do classificador.',
                    description='Treina um modelo. O modelo treinado será salvo em ',
                    epilog='Este programa roda dentro de ima imagem docker própria apra garantir compatibilidade.'
                  )
    parser.add_argument('-t', '--trn_dataset', default=None, help='Um arquivo parquet contendo o conjunto de treino.')
    parser.add_argument('-v', '--val_dataset', default=None, help='Um arquivo parquet contendo o conjunto de validação.')
    parser.add_argument('-c', '--config', default=None, help='Um arquivo JSON contendo os hiperparâmetros do modelo.')
    parser.add_argument('-g', '--gpu', dest='gpu', action='store_true', default=False, help='If set, the code will be run in GPU (if one is available).')
    parser.add_argument('-o', '--output', dest='output', default=None, help='Onde o modelo treinado será salvo.')

    args = parser.parse_args()

    assert args.trn_dataset is not None, 'Você precisa especificar um arquivo parquet contendo as infos de treino. Veja o help.'
    assert args.val_dataset is not None, 'Você precisa especificar um arquivo parquet contendo as infos de validação. Veja o help.'
    assert args.config is not None, 'Você precisa especificar o json com as configurações do seu modelo. Veja o help.'
    assert args.output is not None, 'Você precisa especificar onde salvará o modelo de saída. Veja o help.'
    return args



def main():
    """
    Função principal do script. Ela organiza o workflow necessário para a operação do modelo.

    NÃO MEXA NESTA FUNÇÃO
    """
    opt = get_options()
    device_type = 'gpu' if  opt.gpu else 'cpu'
    
    log.info('Carregando as imagens do dataset de treino e validação.')
    trn_df = pd.read_parquet(opt.trn_dataset)
    val_df = pd.read_parquet(opt.val_dataset)
    log.info('  Total de %d imagens carregadas para treino.', trn_df.shape[0])
    log.info('  Total de %d imagens carregadas para validação.', val_df.shape[0])

    log.info('Carregando a base de hiperparâmetros para o modelo.')
    hyperparams = json.load(open(opt.config, 'r'))
    for k,v in hyperparams.items():
        log.info('  %s: %s', k, v)

    log.info('Começando o treinamento do modelo')
    op = ProductionModelOperation(device_type=device_type)
    model = op.train_model(trn_df, val_df, **hyperparams)

    log.info('Salvando o modelo treinado em "%s".', opt.output)
    model_save_path = op.save_model(model, opt.output)

    log.info('Fim da execução. Nenhum problema encontrado.')


if __name__ == "__main__":
    try: main()
    except Exception as e:
        log.fatal ('    ' + str(e))
        traceback.print_exc()
