#!/usr/bin/env python

import argparse
import ruamel.yaml as ry
from ruamel.yaml.comments import CommentedMap, CommentedSeq
# requires conda install -c conda-forge ruamel.yaml=0.11.7
from collections import OrderedDict


parser = argparse.ArgumentParser(description='Dolo command line')
parser.add_argument('infile', type=argparse.FileType('r'))
args = parser.parse_args()

# The handle is kept in `ff` because its name is reported later as the
# 'source' of every diagnostic.
ff = args.infile
txt = ff.read()

# Load with the round-trip loader: the checks below read line/column data
# from the `.lc` attribute of the loaded nodes.
try:
    data = ry.load(txt, ry.RoundTripLoader)
except Exception:
    # `except Exception` (rather than a bare except) lets KeyboardInterrupt
    # and SystemExit propagate.  An unparsable file currently yields an empty
    # diagnostics list.
    print('[]')  # should return parse error
    # SystemExit is a builtin; `exit()` is only injected by the `site` module.
    raise SystemExit

# Model type assumed for every linted file; the checks use it to look up
# which symbol groups and which equation recipe apply.
model_type = 'dtcscc'

# Symbol groups accepted in the 'symbols' section, keyed by model type.
known_symbol_types = {
    'dtcscc': [
        'states',
        'controls',
        'auxiliaries',
        'values',
        'shocks',
        'parameters',
    ],
}

def check_symbol_validity(s):
    """Check that ``s`` is exactly one bare Python identifier.

    Returns None when ``s`` is valid.

    Raises ValueError when ``s`` is not text, does not parse as Python, or
    parses to anything other than a single name spelled exactly like ``s``
    (for instance ``'a+b'``, ``'(a)'``, ``'1a'``, ``'a\\nb'`` or ``''``).
    """
    import ast

    try:
        body = ast.parse(s).body
    except (SyntaxError, ValueError, TypeError) as err:
        # SyntaxError: not valid Python; ValueError: e.g. embedded null bytes
        # on some interpreter versions; TypeError: `s` is not source text.
        raise ValueError("Invalid symbol {!r}: {}".format(s, err))

    is_single_name = (
        len(body) == 1
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Name)
        # Comparing with the source text rejects spellings that merely
        # *evaluate* to a name, such as '(a)' or 'a '.
        and body[0].value.id == s
    )
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if not is_single_name:
        raise ValueError("Invalid symbol {!r}".format(s))

def check_symbols(cm_symbols):
    """Validate the 'symbols' section of a model file.

    ``cm_symbols`` is the ``CommentedMap`` holding the section: one entry per
    symbol type, each expected to be a ``CommentedSeq`` of symbol names.

    Returns ``(cm_symbols, exceptions)``.  ``exceptions`` is a list of
    ``Exception`` objects, each carrying a ``pos`` attribute of four numbers
    (presumably start line, start column, end line, end column, with
    zero-based lines as reported by ruamel.yaml -- TODO confirm).  The
    problems reported are:

    - unknown symbol type for the module-level ``model_type``;
    - symbol type whose content is not a list;
    - invalid symbol name;
    - symbol already declared earlier in the section.

    Raises TypeError if ``cm_symbols`` is not a ``CommentedMap``.
    """
    # add: not declared if missing 'states', 'controls' ?

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not isinstance(cm_symbols, CommentedMap):
        raise TypeError("'symbols' section should be a mapping, got {}".format(type(cm_symbols).__name__))

    exceptions = []
    already_declared = {}  # symbol -> (symbol type, (line, column))

    for key, values in cm_symbols.items():
        if key not in known_symbol_types[model_type]:
            l0, c0, l1, c1 = cm_symbols.lc.data[key]
            exc = Exception("Unknown symbol type '{}' for model type '{}'".format(key, model_type))
            exc.pos = (l0, c0, l1, c1)
            exceptions.append(exc)

        # A scalar or empty section has no per-item positions to read;
        # report it at the key and move on instead of crashing.
        if not isinstance(values, CommentedSeq):
            exc = Exception("Symbol type '{}' should contain a list of symbols.".format(key))
            exc.pos = tuple(cm_symbols.lc.data[key])
            exceptions.append(exc)
            continue

        for i, v in enumerate(values):
            (l0, c0) = values.lc.data[i]
            l1 = l0
            # str() so that a non-string entry (e.g. a number) still gets a span
            c1 = c0 + len(str(v))
            try:
                check_symbol_validity(v)
            except Exception:
                exc = Exception("Invalid symbol '{}'".format(v))
                exc.pos = (l0, c0, l1, c1)
                exceptions.append(exc)

            if not isinstance(v, str):
                # Already reported as invalid above; it may not even be
                # hashable, so it is left out of the duplicate bookkeeping.
                continue

            if v in already_declared:
                ll = already_declared[v]
                # +1 on the line for display in the message
                exc = Exception("Symbol '{}' already declared as '{}'. (pos {})".format(v, ll[0], (ll[1][0]+1, ll[1][1])))
                exc.pos = (l0, c0, l1, c1)
                exceptions.append(exc)
            else:
                already_declared[v] = (key, (l0, c0))

    # Return the argument itself rather than relying on a module-level name.
    return cm_symbols, exceptions
import ast
def check_equations(data, symbols, model_type):
    """Validate the 'equations' section of a model file.

    ``data`` is the loaded document; its 'equations' and 'symbols' entries
    and their ``.lc.data`` position tables are read.  ``symbols`` is accepted
    for interface compatibility but not used (``data['symbols']`` is read
    instead).  ``model_type`` selects the recipe in
    ``dolo.compiler.recipes.recipes``.

    Returns a list of ``Exception`` objects, each with a ``pos`` attribute of
    four numbers locating the problem.  The problems reported are:

    - equation type required by the recipe but missing;
    - equation type unknown to the recipe;
    - equation that does not parse as Python;
    - for equation types with a 'target' in the recipe: a left-hand side that
      is not a declared target symbol, is out of order, or belongs to an
      equation beyond the number of declared target symbols.
    """
    pos0 = data.lc.data['equations']
    equations = data['equations']

    exceptions = []
    from dolo.compiler.recipes import recipes
    specs = recipes[model_type]['specs']

    # NOTE(review): a spec without an 'optional' key is treated as optional
    # (default True), so this only fires for specs that explicitly set
    # 'optional' to a false value -- confirm this matches the recipe files.
    for eq_type in specs.keys():
        if (eq_type not in equations) and (not specs[eq_type].get('optional', True)):
            exc = Exception("Missing equation type {}.".format(eq_type))
            exc.pos = pos0
            exceptions.append(exc)

    already_declared = {}
    unknown = []
    for eq_type in equations.keys():
        pos = equations.lc.data[eq_type]
        if eq_type not in specs:
            exc = Exception("Unknown equation type {}.".format(eq_type))
            exc.pos = pos
            exceptions.append(exc)
            unknown.append(eq_type)

        # BUG: doesn't produce an error when a block is declared twice
        # (each key is visited once when iterating a mapping, so this branch
        # is not expected to trigger); should be raised by ruamel.yaml ?
        elif eq_type in already_declared:
            exc = Exception("Equation type {} declared twice at ({})".format(eq_type, pos))
            exc.pos = pos
            exceptions.append(exc)
        else:
            already_declared[eq_type] = pos

    for eq_type in [k for k in equations.keys() if k not in unknown]:
        # An empty block loads as None; treat it as having no equations.
        block = equations[eq_type] or []

        for n, eq in enumerate(block):
            pos = block.lc.data[n]
            try:
                ast.parse(eq)
            except SyntaxError as e:
                # `offset` may be None for some syntax errors
                offset = e.offset or 0
                exc = Exception("Syntax Error.")
                exc.pos = [pos[0], pos[1]+offset, pos[0], pos[1]+offset]
                exceptions.append(exc)

        # TEMP: incorrect ordering
        if not specs[eq_type].get('target'):
            continue

        target = specs[eq_type]['target'][0]
        # The target group may be undeclared or empty in 'symbols'.
        targets = data['symbols'].get(target) or []

        for n, eq in enumerate(block):
            pos = block.lc.data[n]
            lhs_name = str.split(eq, '=')[0].strip()
            lhs_pos = [pos[0], pos[1], pos[0], pos[1]+len(lhs_name)]
            if lhs_name not in targets:
                exc = Exception("Undeclared assignement target '{}'. Add it to '{}'.".format(lhs_name, target))
                exc.pos = lhs_pos
                exceptions.append(exc)
            elif n >= len(targets):
                # More equations than declared targets: there is no expected
                # name at this position, so report it instead of indexing
                # past the end of the list.
                exc = Exception("Too many equations: '{}' declares only {} symbol(s).".format(target, len(targets)))
                exc.pos = lhs_pos
                exceptions.append(exc)
            else:
                right_name = targets[n]
                if lhs_name != right_name:
                    exc = Exception("Left hand side should be '{}' instead of '{}'.".format(right_name, lhs_name))
                    exc.pos = lhs_pos
                    exceptions.append(exc)
    return exceptions

def check_calibration(data):
    """Validate the 'calibration' section of a model file.

    ``data`` is the loaded document; its 'symbols' and 'calibration' entries
    are read, together with ``data.lc.data['calibration']`` and
    ``data['calibration'].lc.data`` for positions.

    Returns a list of ``Exception`` objects with a ``pos`` attribute:

    - one per declared symbol without a calibrated value; these also carry
      ``type = 'warning'`` and are located at the 'calibration' key;
    - one "Syntax Error." per calibrated value whose text does not parse as
      Python, located at the entry's position shifted by the error offset.
    """
    ## what happens here if symbols are not clean ?
    symbols = data['symbols']
    pos0 = data.lc.data['calibration']
    calibration = data['calibration']
    exceptions = []

    all_symbols = []
    for group in symbols.values():
        # An empty group loads as None; there is nothing to collect from it.
        if not group:
            continue
        all_symbols += group

    for s in all_symbols:
        if s not in calibration:
            # should skip invalid symbols there
            exc = Exception("Symbol {} has no calibrated value.".format(s))
            exc.pos = pos0
            exc.type = 'warning'
            exceptions.append(exc)

    for s in calibration.keys():
        val = str(calibration[s])
        try:
            ast.parse(val)
        except (SyntaxError, ValueError) as e:
            # ValueError: some interpreter versions raise it (instead of
            # SyntaxError) for text containing null bytes.
            pos = calibration.lc.data[s]
            # `offset` is missing on ValueError and may be None on SyntaxError
            offset = getattr(e, 'offset', None) or 0
            exc = Exception("Syntax Error.")
            exc.pos = [pos[0], pos[1]+offset, pos[0], pos[1]+offset]
            exceptions.append(exc)
    return exceptions


# The 'symbols' section, bound at module level before the checks run.
symbols = data['symbols']

import json

# Run every check and collect the diagnostics in order.
symbols, exceptions = check_symbols(data['symbols'])
exceptions += check_equations(data, symbols, model_type)
exceptions += check_calibration(data)

# Emit one JSON object per diagnostic: severity, file name, range as
# ((line, column), (line, column)) and message.
source = ff.name
output = []
for exc in exceptions:
    output.append({
        # Diagnostics without an explicit `type` attribute are errors.
        'type': getattr(exc, 'type', 'error'),
        'source': source,
        'range': ((exc.pos[0], exc.pos[1]), (exc.pos[2], exc.pos[3])),
        'text': exc.args[0],
    })
print(json.dumps(output))

# TODO:
# - check name (already defined by somebody else?)
# - description: ?
# - calibration:
#      - incorrect key
#          - warning if not a known symbol ?
#          - not a recognized identifier
#          - defined twice
#      - impossible to solve in closed form (depends on ...)
#      - incorrect equation
#           - grammatically incorrect
#           - contains timed variables
#      - warnings:
#           - missing values
# - equations: symbols already known (beware of speed issues)
#     - unknown group of equations
#     - incorrect syntax
#     - undeclared variable (and not a function)
#     - indexed parameter
#     - incorrect order
#     - incorrect complementarities
#     - incorrect recipe: unexpected symbol type
#     - nonzero residuals (warning, to be done without compiling)
# - options: if present
#     - approximation_space:
#          - inconsistent boundaries
#                - must equal number of states
#     - distribution:
#          - same size as shocks
