#!python
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Decode from trained T2T models.

This binary performs inference using the Estimator API.

Example usage to decode from dataset:

  t2t-decoder \
      --data_dir ~/data \
      --problems=algorithmic_identity_binary40 \
      --model=transformer \
      --hparams_set=transformer_base

Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode
sources.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir

import tensorflow as tf

# Short aliases for TensorFlow's command-line flag module and the object
# through which parsed flag values are read (e.g. FLAGS.t2t_usr_dir).
flags = tf.flags
FLAGS = flags.FLAGS

# NOTE(review): main() also reads flags such as data_dir, output_dir, model,
# problems and the decode_* options, none of which are defined in this file;
# presumably they are registered by one of the imported tensor2tensor
# modules -- confirm. Only t2t_usr_dir is defined here (default: "").
flags.DEFINE_string("t2t_usr_dir", "",
                    "Path to a Python module that will be imported. The "
                    "__init__.py file should include the necessary imports. "
                    "The imported files should contain registrations, "
                    "e.g. @registry.register_model calls, that will then be "
                    "available to the t2t-decoder.")


def main(_):
  """Builds an estimator from the command-line flags and runs decoding.

  The decode source is chosen from the flags, in this order of precedence:
  interactive decoding if FLAGS.decode_interactive is set, decoding from the
  file named by FLAGS.decode_from_file if that is set, and otherwise decoding
  from the dataset(s) named in FLAGS.problems.

  Args:
    _: Unused positional argument supplied by the caller.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # The user module is imported before the registry is logged and the flags
  # are validated.
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  trainer_utils.validate_flags()

  # Expand a leading "~" in the user-provided directories.
  expanded_data_dir = os.path.expanduser(FLAGS.data_dir)
  expanded_output_dir = os.path.expanduser(FLAGS.output_dir)

  model_hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set,
      FLAGS.problems,
      expanded_data_dir,
      passed_hparams=FLAGS.hparams)

  # Only the first of the two returned components (the estimator) is used.
  experiment_parts = trainer_utils.create_experiment_components(
      hparams=model_hparams,
      output_dir=expanded_output_dir,
      data_dir=expanded_data_dir,
      model_name=FLAGS.model)
  estimator, _ = experiment_parts

  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator)
    return

  if FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file)
    return

  # Default: decode from the dataset of each "-"-separated problem name.
  problem_names = FLAGS.problems.split("-")
  dataset_decode_options = dict(
      return_beams=FLAGS.decode_return_beams,
      beam_size=FLAGS.decode_beam_size,
      max_predictions=FLAGS.decode_num_samples,
      decode_to_file=FLAGS.decode_to_file,
      save_images=FLAGS.decode_save_images,
      identity_output=FLAGS.identity_output)
  decoding.decode_from_dataset(
      estimator, problem_names, **dataset_decode_options)


if __name__ == "__main__":
  # Script entry point. NOTE(review): tf.app.run() is expected to parse the
  # command-line flags and then invoke main() defined in this module -- see
  # the TensorFlow documentation to confirm.
  tf.app.run()
