#!python
import argparse
import nbclient
import nbformat
import nbparameterise
import sys
import os
import time
from collections import defaultdict
# Stage-one parser: extracts the notebook path and the runner's own switches.
# Everything it does not recognise is returned in `unknown`.
# add_help=False keeps -h/--help out of this parser (presumably so that help
# is rendered by the notebook-specific parser built later).
# allow_abbrev=False stops argparse from treating a notebook option such as
# `--use` or `--m` as an abbreviation of one of the runner switches below.
parser = argparse.ArgumentParser(
    description='Notebook Runner',
    add_help=False,
    allow_abbrev=False,
)
parser.add_argument('Notebook')
parser.add_argument(
    '--use_notebook_client',
    default=False,
    action='store_true',
    help="Warning: Using this disables stdin",
)
parser.add_argument('--master_debug', default=False, action='store_true', help="")

nboargs, unknown = parser.parse_known_args()

# Debug switches. Every print-style switch follows --master_debug; the pause
# between cells is a separate, hard-coded number of seconds (0 disables it).
debugMode = nboargs.master_debug
debug_cell_sourcecode = debug_print_cell_numbers = debugMode
debug_print_args = debug_print_params = debug_print_param_params = debugMode
debug_sleep_between_cells = 0
# Resolve the notebook location *before* changing directory: the path given on
# the command line may be relative, and a bare file name has an empty dirname,
# which os.chdir() rejects.
notebook_path = os.path.abspath(nboargs.Notebook)
# Work from the notebook's own folder (presumably so that relative paths used
# inside its cells resolve the same way they do in Jupyter).
os.chdir(os.path.dirname(notebook_path))
with open(notebook_path, 'rb') as f:
    nb = nbformat.read(f, as_version=4)

# Parameters discovered in the notebook; each exposes .name, .type, .value and
# .comment, which drive the command-line interface built below.
notebook_params = nbparameterise.extract_parameters(nb)
if debug_print_params:
    print('orginal params', notebook_params)


# Stage-two parser: its arguments are generated from the notebook parameters.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# NOTE(review): flagglessArgs is only assigned here in the visible code.
flagglessArgs = dict()

# Mandatory options get their own help section; optional ones are attached to
# the parser itself. `gs` selects the target by "is this parameter required?".
requiredNamed = parser.add_argument_group('Required notebook arguments')
optionalNamed = parser
gs = {True: requiredNamed, False: optionalNamed}

# Per-parameter bookkeeping, created on first access with these defaults.
p_prams = defaultdict(
    lambda: dict(
        type=None,
        required=False,
        toLower=False,
        flagglessArgs=False,
    )
)


# Register one command-line argument per notebook parameter.
#
# Trailing '#'-separated segments of a parameter's comment act as markers:
#   ... #required  -> the option becomes mandatory and is spelled with one dash
#   ... #tolower   -> recorded in p_prams[name]['toLower']
#                     (NOTE(review): not acted on anywhere in the visible code)
# Whatever precedes the markers is used as the argument's help text.
for item in notebook_params:
    name = item.name
    p_prams[name]['type'] = item.type
    comment = ''
    value = item.value
    if item.comment is not None:
        segments = item.comment.split('#')
        # Peel marker segments off the end; stop at the first segment that is
        # ordinary text. Surrounding whitespace is ignored when matching.
        while segments:
            marker = segments[-1].strip().lower()
            if marker.endswith("required"):
                p_prams[name]['required'] = True
            elif marker.endswith("tolower"):
                p_prams[name]['toLower'] = True
            else:
                break
            segments.pop()
        # Keep every non-marker segment (the previous implementation dropped
        # the last one), minus any leading '#' and padding.
        comment = '#'.join(segments).lstrip('#').strip()

    # argparse %-formats help strings, so a literal '%' must be doubled.
    help_text = comment.replace('%', '%%')
    is_required = p_prams[name]['required']

    if item.name.startswith('__') and item.name.endswith('__'):
        # Dunder-wrapped parameters become positional arguments. The original
        # notebook name is remembered so the value can be mapped back later.
        name = name.replace('__', '')
        p_prams[name]['flagglessArgs'] = item.name
        if item.type is list:
            # A list parameter enumerates the permitted choices.
            parser.add_argument(name, choices=value, help=help_text)
        elif item.type is bool:
            # type=bool would turn the string 'False' into True, so booleans
            # are taken as the literal strings and converted after parsing.
            parser.add_argument(
                name,
                nargs='?',
                choices=['True', 'False'],
                default=str(value),
                help=f"{help_text}  (default: %(default)s)",
            )
        else:
            parser.add_argument(
                name,
                nargs='?',
                default=value,
                type=item.type,
                help=f"{help_text}  (default: %(default)s)",
            )
    else:
        if is_required:
            flag = f'-{name}'
            # An empty-string placeholder is not a meaningful default.
            if value == '':
                value = None
        else:
            flag = f'--{name}'

        group = gs[is_required]
        if item.type is list:
            group.add_argument(flag, choices=value, help=help_text, required=is_required)
        elif item.type is bool:
            # Default to the notebook's own value so that omitting the flag
            # does not silently turn a True parameter into False.
            group.add_argument(
                flag,
                choices=['True', 'False'],
                default=str(value),
                help=help_text,
                required=is_required,
            )
        else:
            group.add_argument(
                flag,
                nargs='?',
                default=value,
                type=item.type,
                help=f"{help_text}  (default: %(default)s)",
                required=is_required,
            )
if debug_print_param_params:
    print('p_prams', p_prams)
# Give the notebook parser only what the runner's own parser left unclaimed.
# Blindly dropping the first two argv entries assumed the notebook path was
# always argv[1]; a runner switch placed before it broke that assumption.
if debug_print_args:
    print(sys.argv)
args = list(unknown)
# Kept for backward compatibility: sys.argv has always been trimmed in place
# here, and code executed later in this process can observe it.
sys.argv[:] = args
if debug_print_args:
    print(args)
args, unknown = parser.parse_known_args(args)
if debug_print_args:
    print('NBARGS', args, unknown)


# Translate parsed argument names back to notebook parameter names.
# Positional arguments were registered under a name with the '__' markers
# removed; p_prams keeps the original notebook name for them, which is used
# directly instead of being rebuilt (rebuilding breaks names that contain an
# inner '__').
final_args = {}
args = vars(args)
for a in args:
    original_name = p_prams[a]['flagglessArgs']
    if original_name:
        final_args[original_name] = args[a]
    else:
        final_args[a] = args[a]



# Notebook-declared values, captured before any replacement takes place, so a
# boolean that was not supplied on the command line can fall back to them.
notebook_defaults = {p.name: p.value for p in notebook_params}

new_params = nbparameterise.parameter_values(notebook_params, **final_args)
if debug_print_params:
    print('replacement params', new_params)
for item in new_params:
    if p_prams[item.name]['type'] is bool:
        if item.value is None:
            # Flag omitted: keep the notebook's own value rather than False.
            item.value = notebook_defaults[item.name]
        elif not isinstance(item.value, bool):
            # Booleans arrive from the command line as 'True' / 'False'.
            item.value = ('True' == item.value)
        # A value that is already a bool is left untouched.
new_nb = nbparameterise.replace_definitions(nb, new_params)



if not nboargs.use_notebook_client:
    # In-process mode: run each code cell directly in this interpreter.
    # NOTE(review): cells execute in this script's own global namespace, so
    # notebook code can read or overwrite the runner's variables. exec() runs
    # arbitrary code; only use this on notebooks you trust.
    for cell_index, cell in enumerate(new_nb['cells']):
        if cell['cell_type'] != 'code':
            continue
        if debug_cell_sourcecode:
            print(cell)
        if debug_print_cell_numbers:
            print(f"[{cell_index}]")
        # Compiling with a per-cell filename makes tracebacks name the cell.
        exec(compile(cell['source'], f"<cell {cell_index}>", 'exec'))
        if debug_print_cell_numbers:
            print()
        if debug_sleep_between_cells > 0:
            time.sleep(debug_sleep_between_cells)

else:
    # Client mode: execute through nbclient, then echo the captured outputs.
    print("Using Notebook client")
    output = nbclient.execute(new_nb)
    for item in output['cells']:
        # Markdown and raw cells have no 'outputs' entry.
        for o in item.get('outputs', []):
            if 'text' in o:
                # Stream output (stdout / stderr).
                text = o['text']
            elif 'data' in o:
                # execute_result / display_data: use the plain-text rendering.
                text = o['data'].get('text/plain')
            elif 'traceback' in o:
                text = '\n'.join(o['traceback'])
            else:
                text = None
            if text is None:
                continue
            if not isinstance(text, str):
                # Multi-line strings may be stored as a list of lines.
                text = ''.join(text)
            print(text.strip())