#!/usr/bin/env python3

# This file is part of tf-plan.

# tf-plan is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-plan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-plan. If not, see <http://www.gnu.org/licenses/>.


import argparse
import numpy as np

import rddlgym
import tfrddlsim.viz

import tfplan
from tfplan.planners.environment import OnlinePlanning
from tfplan.planners.online import OnlineOpenLoopPlanner
from tfplan.planners.offline import OfflineOpenLoopPlanner
from tfplan.train.policy import OpenLoopPolicy
from tfplan.test.evaluator import ActionEvaluator


def parse_args(argv=None):
    """Parse the tf-plan command-line arguments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with attributes ``rddl``, ``mode``,
        ``batch_size``, ``horizon``, ``epochs``, ``learning_rate``, ``viz``
        and ``verbose``.
    """
    description = '''
    tf-plan (v{}): Planning via gradient-based optimization in TensorFlow.
    '''.format(tfplan.__version__)
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        'rddl',
        type=str,
        help='RDDL file or rddlgym domain id'
    )
    # Help strings use argparse's %(default)s specifier so the advertised
    # default can never drift away from the actual ``default=`` value.
    parser.add_argument(
        '-m', '--mode',
        default='offline',
        choices=['offline', 'online'],
        help='planning mode (default=%(default)s)'
    )
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=128,
        help='number of trajectories in a batch (default=%(default)s)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=%(default)s)'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=500,
        help='number of training epochs (default=%(default)s)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.01,
        help='optimizer learning rate (default=%(default)s)'
    )
    parser.add_argument(
        '--viz',
        default='generic',
        # Valid choices are the names registered in tfrddlsim.viz.visualizers.
        choices=tuple(tfrddlsim.viz.visualizers),
        help='type of visualizer (default=%(default)s)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args(argv)


def print_parameters(args):
    """Print the planner configuration when verbose mode is enabled.

    Args:
        args: Parsed CLI namespace; reads ``verbose`` and, when it is set,
            ``rddl``, ``mode``, ``horizon``, ``batch_size``, ``epochs`` and
            ``learning_rate``.

    Nothing is printed when ``args.verbose`` is falsy.
    """
    if not args.verbose:
        return

    # (line template, attribute name) pairs, listed in display order.
    rows = (
        ('>> RDDL:            {}', 'rddl'),
        ('>> Planning mode:   {}', 'mode'),
        ('>> Horizon:         {}', 'horizon'),
        ('>> Batch size:      {}', 'batch_size'),
        ('>> Training epochs: {}', 'epochs'),
        ('>> Learning rate:   {}', 'learning_rate'),
    )

    print()
    print('Running tf-plan v{} ...'.format(tfplan.__version__))
    for template, attribute in rows:
        print(template.format(getattr(args, attribute)))
    print()


def load_model(args):
    """Create the rddlgym model for ``args.rddl`` and enable batch mode.

    Args:
        args: Parsed CLI namespace; only ``args.rddl`` is read.

    Returns:
        The object returned by ``rddlgym.make(..., mode=rddlgym.SCG)``,
        after ``batch_mode_on()`` has been called on it.
    """
    model = rddlgym.make(args.rddl, mode=rddlgym.SCG)
    model.batch_mode_on()
    return model


def optimize(compiler, args):
    """Run the planning routine selected by ``args.mode``.

    Args:
        compiler: Model object forwarded unchanged to the planning routine.
        args: Parsed CLI namespace; reads ``mode``, ``batch_size``,
            ``horizon``, ``epochs`` and ``learning_rate``.

    Returns:
        Whatever the chosen routine returns: a ``(trajectories, stats)``
        pair from either ``online_planning`` or ``offline_planning``.
    """
    # Any mode other than 'online' is handled by the offline planner.
    if args.mode == 'online':
        run_planner = online_planning
    else:
        run_planner = offline_planning
    return run_planner(
        compiler,
        args.batch_size,
        args.horizon,
        args.epochs,
        args.learning_rate,
    )


def offline_planning(compiler, batch_size, horizon, epochs, learning_rate):
    """Optimize an open-loop plan offline, then evaluate it.

    Args:
        compiler: Model object handed to the planner, policy and evaluator.
        batch_size: Batch size given to ``OfflineOpenLoopPlanner``.
        horizon: Horizon given to both the planner and the test policy.
        epochs: Argument passed to the planner's ``run``.
        learning_rate: Argument passed to the planner's ``build``.

    Returns:
        A ``(trajectories, None)`` pair; the second slot is ``None`` because
        this mode collects no timing statistics.
    """
    # Stage 1: optimize the actions.
    optimizer = OfflineOpenLoopPlanner(compiler, batch_size, horizon)
    optimizer.build(learning_rate)
    # Only the second element of the result is needed below.
    _actions, learned_variables = optimizer.run(epochs)

    # Stage 2: rebuild the policy with a batch size of 1, initialized from
    # the optimized variables, and evaluate it.
    test_policy = OpenLoopPolicy(compiler, 1, horizon)
    test_policy.build('test', initializers=learned_variables)
    rollout = ActionEvaluator(compiler, test_policy).run()
    return rollout, None


def online_planning(compiler, batch_size, horizon, epochs, learning_rate):
    """Run the online plan-execute-monitor cycle and return its results.

    Args:
        compiler: Model object handed to the planner and the online driver.
        batch_size: Batch size given to ``OnlineOpenLoopPlanner``.
        horizon: Horizon given to the planner and to the driver's ``run``.
        epochs: Second argument of the open-loop planner's ``build``.
        learning_rate: First argument of the open-loop planner's ``build``.

    Returns:
        The ``(trajectories, stats)`` pair unpacked from
        ``OnlinePlanning.run``.
    """
    # Inner open-loop planner, built with parallel plans disabled.
    inner_planner = OnlineOpenLoopPlanner(
        compiler, batch_size, horizon, parallel_plans=False)
    inner_planner.build(learning_rate, epochs)

    # Outer driver wrapping the inner planner; its run() yields both the
    # trajectories and the statistics.
    driver = OnlinePlanning(compiler, inner_planner)
    driver.build()
    rollout, timing = driver.run(horizon)
    return rollout, timing


def report_performance(trajectories, stats):
    """Print the total reward and, when available, the timing statistics.

    Args:
        trajectories: Sequence whose last element is summed with ``np.sum``
            and reported as the total reward.
        stats: Falsy value (e.g. ``None``) to skip the timing report, or a
            mapping with ``'build'`` and ``'optimization'`` entries holding
            the times to aggregate.
    """
    print()
    print('>> total reward = {:.6f}'.format(np.sum(trajectories[-1])))
    print()

    # No statistics to report (e.g. offline planning).
    if not stats:
        return

    build_times = stats['build']
    opt_times = stats['optimization']

    build_total, build_mean = np.sum(build_times), np.mean(build_times)
    opt_total = np.sum(opt_times)
    opt_mean = np.mean(opt_times)
    opt_std = np.std(opt_times)

    print('>> total building time = {:.6e}'.format(build_total))
    print('>> avg   building time = {:.6e}'.format(build_mean))
    print()
    print('>> total  optimization time = {:.6e}'.format(opt_total))
    print('>> avg    optimization time   = {:.6e}'.format(opt_mean))
    print('>> stddev optimization time   = {:.6e}'.format(opt_std))


def display(compiler, trajectories, visualizer_type, verbose=True):
    """Render the trajectories with the selected visualizer.

    Args:
        compiler: Model object passed to the visualizer constructor.
        trajectories: Data passed to the visualizer's ``render`` method.
        visualizer_type: Name looked up in ``tfrddlsim.viz.visualizers``;
            unknown names fall back to the ``'generic'`` visualizer.
        verbose: When falsy nothing is rendered; otherwise the flag is also
            forwarded to the visualizer constructor.
    """
    if not verbose:
        return

    registry = tfrddlsim.viz.visualizers
    visualizer = registry.get(visualizer_type)
    if visualizer is None:
        # Fall back to the visualizer registered under 'generic' rather than
        # to the string 'generic' itself, which is not callable.
        # NOTE(review): assumes 'generic' is always a registered key (it is
        # the CLI default for --viz) -- confirm against tfrddlsim.
        visualizer = registry['generic']

    viz = visualizer(compiler, verbose)
    viz.render(trajectories)


if __name__ == '__main__':

    # Script entry point: parse the CLI options and echo them (the echo is
    # a no-op unless --verbose was given).
    args = parse_args()
    print_parameters(args)

    # Load the RDDL model named on the command line.
    compiler = load_model(args)

    # Plan in the requested mode; stats is None for offline planning.
    trajectories, stats = optimize(compiler, args)

    # Print the total reward and any timing statistics.
    report_performance(trajectories, stats)

    # Visualize the trajectories (only rendered in verbose mode).
    display(compiler, trajectories, args.viz, verbose=args.verbose)
