#! /usr/bin/env python3

# This file is part of rddl2tf.

# rddl2tf is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# rddl2tf is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with rddl2tf. If not, see <http://www.gnu.org/licenses/>.

import argparse
import os
import tensorflow as tf

import rddlgym

from rddl2tf import version
from rddl2tf.compiler import Compiler


def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: parsed values exposing ``rddl`` (str),
        ``batch_size`` (int, default 256) and ``logdir``
        (str, default ``/tmp/rddl2tf``).
    """
    banner = 'rddl2tf (v{}): RDDL2TensorFlow compiler in Python3.'
    cli = argparse.ArgumentParser(description=banner.format(version.__version__))

    # Each entry pairs the flag names with the keyword settings that are
    # forwarded to ``add_argument``; registration order is preserved.
    options = (
        (
            ('rddl',),
            dict(type=str, help='path to RDDL file or rddlgym problem id'),
        ),
        (
            ('-b', '--batch-size'),
            dict(type=int, default=256,
                 help='number of fluents in a batch (default=256)'),
        ),
        (
            ('--logdir',),
            dict(type=str, default='/tmp/rddl2tf',
                 help='log directory for tensorboard graph visualization (default=/tmp/rddl2tf)'),
        ),
    )
    for flags, settings in options:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()


def save_model(model, compiler, logdir):
    """Write the compiler's graph to an event file for TensorBoard.

    The event file is placed under
    ``<logdir>/<model.domain.name>/<model.instance.name>``.

    Args:
        model: object exposing ``domain.name`` and ``instance.name``; both
            are used as path components of the output directory.
        compiler: object exposing a ``graph`` attribute, which is handed to
            the summary file writer.
        logdir (str): base log directory.

    Returns:
        The log directory as reported by the file writer.
    """
    logdir = os.path.join(logdir, model.domain.name, model.instance.name)
    file_writer = tf.summary.FileWriter(logdir, compiler.graph)
    try:
        return file_writer.get_logdir()
    finally:
        # Close the writer so the graph event is flushed to disk and the
        # event-file handle is released instead of being left open until
        # interpreter shutdown.
        file_writer.close()


def main():
    """Compile the requested RDDL problem and export its graph.

    Runs the compiler steps in a fixed order (initial state, default action,
    state invariants, action preconditions, CPFs, reward), saves the
    compiler's graph via ``save_model`` and prints the matching
    ``tensorboard`` command line.
    """
    cli_args = parse_args()
    batch_size = cli_args.batch_size

    rddl_model = rddlgym.make(cli_args.rddl, mode=rddlgym.AST)
    rddl_compiler = Compiler(rddl_model)
    rddl_compiler.batch_mode_on()

    initial_state = rddl_compiler.compile_initial_state(batch_size)
    default_action = rddl_compiler.compile_default_action(batch_size)

    # The results below are not read again in this script; the calls are
    # kept, in order (presumably they populate the compiler's graph --
    # confirm against the Compiler implementation).
    state_invariants = rddl_compiler.compile_state_invariants(initial_state)
    action_preconditions = rddl_compiler.compile_action_preconditions(
        initial_state, default_action)

    transition_scope = rddl_compiler.transition_scope(initial_state, default_action)
    intermediate_fluents, successor_state = rddl_compiler.compile_cpfs(
        transition_scope, batch_size)

    # The reward is compiled against a scope that also holds the next state.
    transition_scope.update(successor_state)
    reward_value = rddl_compiler.compile_reward(transition_scope)

    output_dir = save_model(rddl_model, rddl_compiler, cli_args.logdir)
    print('tensorboard --logdir {}\n'.format(output_dir))


if __name__ == '__main__':
    main()
