#!python
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Produces the training and dev data for --problem into --data_dir.

generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
protocol buffers for a variety of datasets registered in this file.

All datasets are registered in _SUPPORTED_PROBLEM_GENERATORS. Each entry maps a
string name (selectable on the command-line with --problem) to a pair of
zero-argument functions - the first for the training data, the second for the
dev data. Each function returns a generator that yields, for each example, a
dictionary mapping string feature names to lists of {string, int, float}.
"""

import random
import tempfile

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing

import tensorflow as tf

# Short aliases for TensorFlow's flag module and its flag-value container; the
# rest of this file reads every command-line option through FLAGS.
flags = tf.flags
FLAGS = flags.FLAGS

# Output directory. main() falls back to tempfile.gettempdir() (and logs a
# warning) when this is left empty.
flags.DEFINE_string("data_dir", "", "Data directory.")
# Passed as the first argument to the dataset-backed generators registered in
# _SUPPORTED_PROBLEM_GENERATORS.
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
                    "Temporary storage directory.")
# Must be one of the keys of _SUPPORTED_PROBLEM_GENERATORS; main() raises
# ValueError otherwise.
flags.DEFINE_string("problem", "",
                    "The name of the problem to generate data for.")
# Shard count for the training files only; main() always writes the dev data
# with a shard count of 1.
flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
# Forwarded to generate_files for the training split only; the dev split is
# generated without it.
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
# Consumed by set_random_seed() to seed TensorFlow, `random` and NumPy.
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")

# Mapping from problems that we can generate data for to their generators.
#
# Each value is a (training, dev) pair of zero-argument callables; main()
# unpacks the pair and calls each one to obtain the example generator for that
# split. Wrapping the calls in lambdas means FLAGS.tmp_dir is read only when a
# lambda is invoked, not when this module is loaded.
#
# Naming pattern visible below for the image and audio problems: in the
# "*_tune" entries both lambdas pass True as the second argument and the dev
# lambda passes one additional positional number; in the "*_test" entries the
# dev lambda passes False instead.
# NOTE(review): presumably "tune" holds out the tail of the training set as dev
# data while "test" uses the dataset's own test split - confirm against the
# generator signatures in the image/audio modules.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
    # Synthetic problems from the algorithmic and algorithmic_math modules.
    "algorithmic_identity_binary40": (
        lambda: algorithmic.identity_generator(2, 40, 100000),
        lambda: algorithmic.identity_generator(2, 400, 10000)),
    "algorithmic_identity_decimal40": (
        lambda: algorithmic.identity_generator(10, 40, 100000),
        lambda: algorithmic.identity_generator(10, 400, 10000)),
    "algorithmic_shift_decimal40": (
        lambda: algorithmic.shift_generator(20, 10, 40, 100000),
        lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
    "algorithmic_reverse_binary40": (
        lambda: algorithmic.reverse_generator(2, 40, 100000),
        lambda: algorithmic.reverse_generator(2, 400, 10000)),
    "algorithmic_reverse_decimal40": (
        lambda: algorithmic.reverse_generator(10, 40, 100000),
        lambda: algorithmic.reverse_generator(10, 400, 10000)),
    "algorithmic_addition_binary40": (
        lambda: algorithmic.addition_generator(2, 40, 100000),
        lambda: algorithmic.addition_generator(2, 400, 10000)),
    "algorithmic_addition_decimal40": (
        lambda: algorithmic.addition_generator(10, 40, 100000),
        lambda: algorithmic.addition_generator(10, 400, 10000)),
    "algorithmic_multiplication_binary40": (
        lambda: algorithmic.multiplication_generator(2, 40, 100000),
        lambda: algorithmic.multiplication_generator(2, 400, 10000)),
    "algorithmic_multiplication_decimal40": (
        lambda: algorithmic.multiplication_generator(10, 40, 100000),
        lambda: algorithmic.multiplication_generator(10, 400, 10000)),
    "algorithmic_algebra_inverse": (
        lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
    "algorithmic_algebra_simplify": (
        lambda: algorithmic_math.algebra_simplify(8, 0, 2, 100000),
        lambda: algorithmic_math.algebra_simplify(8, 3, 3, 10000)),
    "algorithmic_calculus_integrate": (
        lambda: algorithmic_math.calculus_integrate(8, 0, 2, 100000),
        lambda: algorithmic_math.calculus_integrate(8, 3, 3, 10000)),
    # Parsing problems from the wmt and wsj_parsing modules.
    "wmt_parsing_characters": (
        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, False)),
    "wmt_parsing_tokens_8k": (
        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
    "wsj_parsing_tokens_16k": (
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
                                                    2**14, 2**9),
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                    2**14, 2**9)),
    "wsj_parsing_tokens_32k": (
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
                                                    2**15, 2**9),
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                    2**15, 2**9)),
    # "enfr" and "ende" problems from the wmt module.
    "wmt_enfr_characters": (
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
    "wmt_enfr_tokens_8k": (
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
    ),
    "wmt_enfr_tokens_32k": (
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
    ),
    "wmt_enfr_tokens_128k": (
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
    ),
    "wmt_ende_characters": (
        lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
    "wmt_ende_bpe32k": (
        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
    "wmt_ende_tokens_8k": (
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
    ),
    "wmt_ende_tokens_32k": (
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
    ),
    "wmt_ende_tokens_128k": (
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
    ),
    # Problems from the image module (mnist, cifar10 and mscoco generators).
    "image_mnist_tune": (
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
    "image_mnist_test": (
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 60000),
        lambda: image.mnist_generator(FLAGS.tmp_dir, False, 10000)),
    "image_cifar10_tune": (
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 48000),
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 2000, 48000)),
    "image_cifar10_test": (
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
        lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
    "image_mscoco_characters_tune": (
        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000),
        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)),
    "image_mscoco_characters_test": (
        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
        lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
    # The "tokens_<N>k" variants pass a vocab_filename of the form
    # "tokens.vocab.<size>" together with the matching vocab_size.
    "image_mscoco_tokens_8k_tune": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            10000,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "image_mscoco_tokens_8k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            80000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            False,
            40000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "image_mscoco_tokens_32k_tune": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            10000,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "image_mscoco_tokens_32k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            80000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            False,
            40000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "image_mscoco_tokens_128k_tune": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**17,
            vocab_size=2**17),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            10000,
            70000,
            vocab_filename="tokens.vocab.%d" % 2**17,
            vocab_size=2**17)),
    "image_mscoco_tokens_128k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            80000,
            vocab_filename="tokens.vocab.%d" % 2**17,
            vocab_size=2**17),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            False,
            40000,
            vocab_filename="tokens.vocab.%d" % 2**17,
            vocab_size=2**17)),
    # Problem from the snli module.
    "snli_32k": (
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
    ),
    # Problems from the audio module (timit generator).
    "audio_timit_characters_tune": (
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1374),
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 344, 1374)),
    "audio_timit_characters_test": (
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1718),
        lambda: audio.timit_generator(FLAGS.tmp_dir, False, 626)),
    "audio_timit_tokens_8k_tune": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            344,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "audio_timit_tokens_8k_test": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1718,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            False,
            626,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "audio_timit_tokens_32k_tune": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            344,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "audio_timit_tokens_32k_test": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1718,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            False,
            626,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
}

# pylint: enable=g-long-lambda

# Marker that main() inserts into the names of the freshly generated files and
# strips again from the name when it writes the shuffled copy.
UNSHUFFLED_SUFFIX = "-unshuffled"


def set_random_seed():
  """Seeds TensorFlow, Python's `random` and NumPy from --random_seed."""
  seed = FLAGS.random_seed
  # Same order as always: TensorFlow first, then `random`, then NumPy.
  for seed_fn in (tf.set_random_seed, random.seed, np.random.seed):
    seed_fn(seed)


def main(_):
  """Generates, shuffles and writes the train and dev data for --problem.

  The generators registered for --problem are first written to files whose
  names contain UNSHUFFLED_SUFFIX. Each of those files is then read back,
  shuffled in memory, written again under the same name without the marker,
  and the unshuffled file is removed.

  Args:
    _: Unused; present because the caller passes one positional argument.

  Raises:
    ValueError: If --problem is not a key of _SUPPORTED_PROBLEM_GENERATORS.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.problem not in _SUPPORTED_PROBLEM_GENERATORS:
    problems_str = "\n  * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    raise ValueError(error_msg)

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning("It is strongly recommended to specify --data_dir. "
                       "Data will be written to default data_dir=%s.",
                       FLAGS.data_dir)

  # Seed before any generator runs so that generation and the shuffle below
  # both draw from the seeded random state.
  set_random_seed()

  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[FLAGS.problem]

  tf.logging.info("Generating training data for %s.", FLAGS.problem)
  train_output_files = generator_utils.generate_files(
      training_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-train",
      FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)

  # The dev data always goes into a single shard and is not capped by
  # --max_cases.
  tf.logging.info("Generating development data for %s.", FLAGS.problem)
  dev_output_files = generator_utils.generate_files(
      dev_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)

  tf.logging.info("Shuffling data...")
  for fname in train_output_files + dev_output_files:
    records = generator_utils.read_records(fname)
    random.shuffle(records)
    # Strip only the LAST occurrence of the marker. A plain str.replace would
    # also rewrite any "-unshuffled" inside --data_dir and send the shuffled
    # file to a different directory. The marker was added to the output name
    # above, so the last occurrence is expected to be that one.
    head, marker, tail = fname.rpartition(UNSHUFFLED_SUFFIX)
    out_fname = head + tail if marker else fname
    generator_utils.write_records(records, out_fname)
    # If the marker was absent the shuffled data was written in place; removing
    # fname would then delete the only copy.
    if out_fname != fname:
      tf.gfile.Remove(fname)


# Script entry point. tf.app.run() is called without arguments, so it is
# TensorFlow that locates and invokes main() above.
# NOTE(review): tf.app.run() is expected to parse the command-line flags
# before calling main - confirm against the TensorFlow version in use.
if __name__ == "__main__":
  tf.app.run()
