#! /usr/bin/env python3

# This file is part of MDP-ProbLog.

# MDP-ProbLog is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# MDP-ProbLog is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with MDP-ProbLog.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import os
import sys
import time

import mdpproblog
from mdpproblog.mdp import MDP
from mdpproblog.value_iteration import ValueIteration
from mdpproblog.simulator import Simulator
from mdpproblog.fluent import StateSpace

# Bundled example models, keyed by the name accepted by -x/--example.
# Each entry is [models subdirectory, domain file, instance file]; the
# script joins the two file names onto the subdirectory to locate them.
MODELS = {
    'sysadmin1': ['sysadmin', 'domain.pl', 'star2.pl'],
    'sysadmin2': ['sysadmin', 'domain2.pl', 'star2.pl'],
    'vm1': ['viralmarketing', 'domain1.pl', 'network1.pl'],
    'vm2': ['viralmarketing', 'domain2.pl', 'network1.pl'],
    'vm3': ['viralmarketing', 'domain3.pl', 'network2.pl'],
    'vm4': ['viralmarketing', 'domain4.pl', 'network2.pl'],
}


def parse(argv=None):
    """Parse the command line of the mdp-problog script.

    Args:
        argv: optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        The ``argparse.Namespace`` with the attributes ``command``,
        ``model``, ``example``, ``gamma``, ``epsilon``, ``trials`` and
        ``horizon``.

    Raises:
        SystemExit: raised by argparse on invalid arguments, and when a
            command other than ``list`` is given without ``-m`` or ``-x``.
    """
    usage = "mdp-problog {list, show, simulate, solve} [-m DOMAIN INSTANCE | -x EXAMPLE] [OPTIONS]"
    description = """
        MDP-ProbLog is a Python3 framework to represent and solve
        Markovian Decision Processes by Probabilistic Logic Programming.

        This project is free software.

        Please check the documentation at http://pythonhosted.org/mdpproblog/.
        """.strip()

    help_commands = """
        available commands: list examples, show and solve models or simulate optimal policy
        """.strip()

    parser = argparse.ArgumentParser(usage=usage, description=description)
    parser.add_argument("command", choices=['list', 'show', 'solve', 'simulate'], help=help_commands)
    parser.add_argument("-m", "--model", nargs=2, help="list of domain and instance files")
    parser.add_argument("-x", "--example", type=str, help="select model from examples")
    parser.add_argument("-g", "--gamma",   type=float, default=0.9, help="discount factor    (default=0.9)")
    parser.add_argument("-e", "--epsilon", type=float, default=0.1, help="maximum error      (default=0.1)")
    parser.add_argument("-t", "--trials",  type=int,   default=100, help="number of trials   (default=100)")
    parser.add_argument("-z", "--horizon", type=int,   default=50,  help="simulation horizon (default=50)")
    args = parser.parse_args(argv)

    # Every command except 'list' operates on a model. Reject the call here,
    # with a usage message, instead of letting a later lookup on the missing
    # model fail with an unrelated TypeError.
    if args.command != 'list' and not (args.model or args.example):
        parser.error(
            "command '{}' requires a model: use -m DOMAIN INSTANCE or -x EXAMPLE".format(args.command)
        )
    return args


def load_model(domain, instance):
    """Read the domain and instance files and return their contents.

    Args:
        domain: path of the domain file.
        instance: path of the instance file.

    Returns:
        A ``(domain_model, instance_model)`` tuple holding the full text of
        each file, in that order.
    """
    sources = []
    # The domain file is read first, then the instance file; each handle is
    # closed before the next one is opened.
    for path in (domain, instance):
        with open(path, 'r') as handle:
            sources.append(handle.read())
    return tuple(sources)


def print_list_examples():
    """Print the names of the examples in MODELS, sorted, on one line."""
    names = sorted(MODELS)
    print('>> Available examples: [{}]'.format(', '.join(names)))


def show_model(domain, instance):
    """Print the text of the domain and instance files under section headers.

    Args:
        domain: path of the domain file.
        instance: path of the instance file.
    """
    domain_text, instance_text = load_model(domain, instance)
    sections = (
        (":: DOMAIN ::", domain_text),
        (":: INSTANCE ::", instance_text),
    )
    # Each section is its header, the file text, then one blank line.
    for header, text in sections:
        print(header)
        print(text)
        print()


def solve_model(mdp, gamma, epsilon):
    """Run value iteration on `mdp` and return its result.

    Args:
        mdp: the MDP handed to ``ValueIteration``.
        gamma: first argument forwarded to ``ValueIteration.run``.
        epsilon: second argument forwarded to ``ValueIteration.run``.

    Returns:
        Whatever ``ValueIteration.run(gamma, epsilon)`` returns; the script
        unpacks it as ``(V, policy, iterations)``.
    """
    return ValueIteration(mdp).run(gamma, epsilon)


def _format_state(state):
    """Render an iterable of (fluent, value) pairs as 'f1=v1, f2=v2'."""
    return ', '.join("{f}={v}".format(f=f, v=v) for f, v in state)


def print_solution(V, policy, iterations, uptime):
    """Print the value function, the policy and the timing summary.

    Args:
        V: mapping from state to its value; each state is an iterable of
            (fluent, value) pairs.
        policy: mapping from state (same representation) to an action.
        iterations: number of iterations value iteration performed.
        uptime: elapsed solving time in seconds.
    """
    print()
    for state, value in V.items():
        print("Value({state}) = {value:.3f}".format(state=_format_state(state), value=value))
    print()
    for state, action in policy.items():
        print("Policy({state}) = {action}".format(state=_format_state(state), action=action))
    print()
    print(">> Value iteration converged in {0:.3f}sec after {1} iterations.".format(uptime, iterations))
    # An average is undefined for zero iterations; skip the line rather than
    # raise ZeroDivisionError after the results have already been printed.
    if iterations > 0:
        print(">> Average time per iteration = {0:.3f}sec.".format(uptime / iterations))


def main():
    """Run the command selected on the command line.

    ``list`` prints the bundled example names. ``show`` prints the model
    sources. ``solve`` runs value iteration and prints values, policy and
    timing. ``simulate`` does the same and then runs the simulator from each
    state of ``StateSpace(mdp.current_state_fluents())``.

    Raises:
        SystemExit: with an explanatory message when the example name is
            unknown or no model was given.
    """
    args = parse()

    if args.command == 'list':
        print_list_examples()
        return

    if args.example:
        if args.example not in MODELS:
            sys.exit(
                ">> Unknown example '{}'. Run 'mdp-problog list' to see the available examples.".format(args.example)
            )
        model_dir, domain, instance = MODELS[args.example]
        model_dir = os.path.join(os.path.dirname(mdpproblog.__file__), 'models/', model_dir)
        domain    = os.path.join(model_dir, domain)
        instance  = os.path.join(model_dir, instance)
    elif args.model:
        domain, instance = args.model
    else:
        sys.exit(">> No model given: use -m DOMAIN INSTANCE or -x EXAMPLE.")

    if args.command == 'show':
        show_model(domain, instance)

    if args.command in ['solve', 'simulate']:
        domain, instance = load_model(domain, instance)
        model = domain + instance
        mdp = MDP(model)
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented replacement for measuring elapsed time.
        start = time.perf_counter()
        V, policy, iterations = solve_model(mdp, args.gamma, args.epsilon)
        uptime = time.perf_counter() - start

        print_solution(V, policy, iterations, uptime)

        if args.command == 'simulate':
            simulator = Simulator(mdp, policy)
            states = StateSpace(mdp.current_state_fluents())
            print()
            # NOTE(review): states are fetched by index, as before; whether
            # iterating StateSpace yields the same objects as indexing is not
            # visible here, so confirm before simplifying this loop.
            for i, _ in enumerate(states):
                avg, _, _ = simulator.run(args.trials, args.horizon, states[i], args.gamma)
                state = ', '.join(["{f}={v}".format(f=f, v=v) for f, v in states[i]])
                print("Expectation({state}) = {value:.3f}".format(state=state, value=avg))


if __name__ == '__main__':
    main()
