#!/usr/bin/env python3

# This file is part of tf-plan.

# tf-plan is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-plan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-plan. If not, see <http://www.gnu.org/licenses/>.


import argparse
import numpy as np

import tfplan
import tfrddlsim.viz

def parse_args():
    """Parse the command-line arguments of the tf-plan script.

    Returns:
        argparse.Namespace: the parsed options, with attributes ``rddl``,
        ``mode``, ``batch_size``, ``horizon``, ``epochs``, ``learning_rate``,
        ``viz`` and ``verbose``.
    """
    description = 'Planning via gradient-based optimization in TensorFlow.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('rddl', type=str, help='RDDL filepath')
    parser.add_argument(
        '-m', '--mode',
        default='offline',
        choices=['offline', 'online'],
        help='planning mode (default=offline)'
    )
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=128,
        help='number of trajectories in a batch (default=128)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=40)'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=500,
        # help text previously said "timesteps" (copied from --horizon)
        help='number of training epochs (default=500)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.01,
        # help text previously advertised 0.001, contradicting the default
        help='optimizer learning rate (default=0.01)'
    )
    parser.add_argument(
        '--viz',
        default='generic',
        # valid choices are the keys of the visualizer registry
        choices=tuple(tfrddlsim.viz.visualizers),
        help='type of visualizer (default=generic)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args()


def print_parameters(args):
    """Print the tf-plan version and planner settings when verbose is on.

    Args:
        args: parsed CLI options; reads ``verbose``, ``rddl``, ``mode``,
            ``horizon``, ``batch_size``, ``epochs`` and ``learning_rate``.
    """
    if not args.verbose:
        return

    # (line template, attribute of `args`) in the order they are reported
    report = (
        ('>> RDDL:            {}', 'rddl'),
        ('>> Planning mode:   {}', 'mode'),
        ('>> Horizon:         {}', 'horizon'),
        ('>> Batch size:      {}', 'batch_size'),
        ('>> Training epochs: {}', 'epochs'),
        ('>> Learning rate:   {}', 'learning_rate'),
    )

    print()
    print('Running tf-plan v{} ...'.format(tfplan.__version__))
    for template, attribute in report:
        print(template.format(getattr(args, attribute)))
    print()


def read_file(path):
    """Return the whole content of the text file at `path` as a string."""
    with open(path, 'r') as source:
        contents = source.read()
    return contents


def parse_rddl(path):
    """Parse the RDDL file at `path` with pyrddl and return the result.

    Args:
        path: filepath of the RDDL source.

    Returns:
        Whatever ``RDDLParser.parse`` returns for the file's content.
    """
    from pyrddl.parser import RDDLParser

    rddl_parser = RDDLParser()
    rddl_parser.build()  # the parser is built before the file is read
    return rddl_parser.parse(read_file(path))


def compile(rddl):
    """Wrap a parsed RDDL model in a batch-mode rddl2tf ``Compiler``.

    NOTE: within this module the name shadows the builtin ``compile``.

    Args:
        rddl: parsed RDDL model (as returned by ``parse_rddl``).

    Returns:
        A ``Compiler`` instance created with ``batch_mode=True``.
    """
    from tfrddlsim.rddl2tf.compiler import Compiler

    return Compiler(rddl, batch_mode=True)


def optimize(rddl2tf, args):
    """Run the planner selected by ``args.mode`` and return its trajectories.

    Args:
        rddl2tf: the RDDL-to-TensorFlow compiler.
        args: parsed CLI options; reads ``mode``, ``batch_size``, ``horizon``,
            ``epochs`` and ``learning_rate``.

    Returns:
        The trajectories returned by the chosen planning function.
    """
    # any mode other than 'online' falls back to offline planning
    if args.mode == 'online':
        run_planner = online_planning
    else:
        run_planner = offline_planning

    return run_planner(
        rddl2tf,
        args.batch_size,
        args.horizon,
        args.epochs,
        args.learning_rate,
    )


def offline_planning(compiler, batch_size, horizon, epochs, learning_rate):
    """Optimize an open-loop plan offline, then evaluate it.

    Args:
        compiler: the RDDL-to-TensorFlow compiler.
        batch_size: passed to ``OfflineOpenLoopPlanner``.
        horizon: passed to both the planner and the evaluation policy.
        epochs: passed to ``planner.run``.
        learning_rate: passed to ``planner.build``.

    Returns:
        The trajectories returned by ``ActionEvaluator.run``.
    """
    from tfplan.planners.offline import OfflineOpenLoopPlanner
    from tfplan.train.policy import OpenLoopPolicy
    from tfplan.test.evaluator import ActionEvaluator

    # optimization phase: only the policy variables are needed afterwards
    optimizer = OfflineOpenLoopPlanner(compiler, batch_size, horizon)
    optimizer.build(learning_rate)
    _, policy_variables = optimizer.run(epochs)

    # evaluation phase: rebuild the policy from the optimized variables
    # (second argument is 1 where the planner received batch_size --
    # presumably a single-trajectory batch; confirm against OpenLoopPolicy)
    test_policy = OpenLoopPolicy(compiler, 1, horizon)
    test_policy.build('test', initializers=policy_variables)
    return ActionEvaluator(compiler, test_policy).run()


def online_planning(compiler, batch_size, horizon, epochs, learning_rate):
    """Plan online with an open-loop planner and return the trajectories.

    Args:
        compiler: the RDDL-to-TensorFlow compiler.
        batch_size: passed to ``OnlineOpenLoopPlanner``.
        horizon: passed to the planner and to ``OnlinePlanning.run``.
        epochs: passed to the planner's ``build``.
        learning_rate: passed to the planner's ``build``.

    Returns:
        The trajectories returned by ``OnlinePlanning.run``.
    """
    from tfplan.planners.environment import OnlinePlanning
    from tfplan.planners.online import OnlineOpenLoopPlanner

    # inner planner used at each step of the online loop
    step_planner = OnlineOpenLoopPlanner(
        compiler, batch_size, horizon, parallel_plans=False
    )
    step_planner.build(learning_rate, epochs)

    # plan-execute-monitor cycle driven by the inner planner
    online_loop = OnlinePlanning(compiler, step_planner)
    online_loop.build()
    return online_loop.run(horizon)


def print_performance(trajectories):
    """Print the sum of the last element of `trajectories` as total reward.

    Args:
        trajectories: indexable sequence whose last element is summed
            with ``np.sum``.
    """
    print()
    total_reward = np.sum(trajectories[-1])
    message = '>> total reward = {:.6f}'.format(total_reward)
    print(message)
    print()


def display(compiler, trajectories, visualizer_type, verbose=True):
    """Render `trajectories` with the selected visualizer when verbose is on.

    Args:
        compiler: the RDDL-to-TensorFlow compiler handed to the visualizer.
        trajectories: the trajectories to render.
        visualizer_type: key into ``tfrddlsim.viz.visualizers``; an unknown
            key falls back to the ``'generic'`` visualizer.
        verbose: nothing is rendered when falsy; also forwarded to the
            visualizer constructor.

    Raises:
        KeyError: if `visualizer_type` is unknown and no ``'generic'``
            visualizer is registered.
    """
    if not verbose:
        return

    registry = tfrddlsim.viz.visualizers
    visualizer = registry.get(visualizer_type)
    if visualizer is None:
        # Fall back to the registered generic visualizer. The former
        # `.get(key, 'generic')` returned the *string* 'generic', which
        # then failed when called as a constructor.
        visualizer = registry['generic']

    viz = visualizer(compiler, verbose)
    viz.render(trajectories)


if __name__ == '__main__':

    # command-line options
    options = parse_args()

    # echo the planner configuration (verbose mode only)
    print_parameters(options)

    # RDDL file -> parsed model -> TensorFlow compiler
    compiler = compile(parse_rddl(options.rddl))

    # run the selected planner
    trajectories = optimize(compiler, options)

    # report the total reward
    print_performance(trajectories)

    # visualize the trajectories (verbose mode only)
    display(compiler, trajectories, options.viz, options.verbose)
