#!/usr/bin/env python3

import argparse
import sys, os
import traceback

import hyperchamber as hc
import hypergan as hg
import hypergan.cli as cli
from hypergan.viewer import GlobalViewer

class CommandParser:
    """Builds the argparse parser for the command line.

    The parser exposes the subcommands ``train``, ``test``, ``sample``,
    ``build`` and ``new``.  The same set of optional flags is registered on
    the top-level parser and on every subcommand.
    """

    # Accepted spellings for boolean-valued options (compared lower-cased).
    _TRUE_STRINGS = frozenset(('1', 'true', 't', 'yes', 'y', 'on'))
    _FALSE_STRINGS = frozenset(('', '0', 'false', 'f', 'no', 'n', 'off'))

    @staticmethod
    def _str_to_bool(value):
        """Convert a command-line string into a real boolean.

        ``type=bool`` must not be used with argparse: it calls ``bool()`` on
        the raw string, so ``"False"`` (any non-empty text) becomes ``True``.

        Args:
            value: The raw option value (or an actual ``bool``).

        Returns:
            ``True`` or ``False``.

        Raises:
            argparse.ArgumentTypeError: If the text is not a recognised
                boolean spelling; argparse reports it as a usage error.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in CommandParser._TRUE_STRINGS:
            return True
        if text in CommandParser._FALSE_STRINGS:
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    def common(self, parser, directory=True):
        """Register the arguments shared by the subcommands on ``parser``.

        Args:
            parser: The argparse parser to extend.
            directory: When true, also add the positional ``directory``
                argument before the optional flags.
        """
        if directory:
            parser.add_argument('directory', action='store', type=str, help='The location of your data.  Subdirectories are treated as different classes.  You must have at least 1 subdirectory.')
        self.common_flags(parser)

    def common_flags(self, parser):
        """Register every optional flag on ``parser``.

        Args:
            parser: The argparse parser to extend.
        """
        parser.add_argument('--size', '-s', type=str, default='64x64x3', help='Size of your data.  For images it is widthxheightxchannels.')
        parser.add_argument('--batch_size', '-b', type=int, default=0, help='Number of samples to include in each batch.  If using batch norm, this needs to be preserved when in server mode')
        parser.add_argument('--config', '-c', action='store', default=None, type=str, help='The configuration file to load.')
        parser.add_argument('--input_config', '-i', action='store', default=None, type=str, help='The input configuration to use.')
        parser.add_argument('--toml', '-T', action='store_true', help='Use TOML as configration file format instead of JSON')
        parser.add_argument('--device', '-d', type=str, default='/gpu:0', help='In the form "/gpu:0", "/cpu:0", etc.  Always use a GPU (or TPU) to train')
        parser.add_argument('--crop', dest='crop', action='store_true', help='If your images are perfectly sized you can skip cropping.')
        parser.add_argument('--random_crop', dest='random_crop', action='store_true', help='Randomly crop images.')
        parser.add_argument('--resize', dest='resize', action='store_true', help='If your images are perfectly sized you can skip resize.')
        parser.add_argument('--align', '-a', nargs='?', action='append', dest='align', help='Adds an additional input folder.')
        parser.add_argument('--save_every', type=int, default=-1, help='Saves the model every n steps.')
        parser.add_argument('--sample_every', type=int, default=100, help='Saves a sample every X steps.')
        parser.add_argument('--save_samples', action='store_true', help='Saves samples to the local `samples` directory.')
        parser.add_argument('--sampler', type=str, default=None, help='Select a sampler.  Some choices: static_batch, batch, grid, progressive')
        parser.add_argument('--save_file', type=str, default=None, help='The save file to load and save from.  Use gs://bucket/mymodel/model.ckpt for TPU saving')
        # Parsed with an explicit converter so that "--ipython False" is False.
        parser.add_argument('--ipython', type=self._str_to_bool, default=False, help='Enables iPython embedded mode.')
        parser.add_argument('--steps', type=int, default=-1, help='Number of steps to train for.  -1 is unlimited (default)')
        parser.add_argument('--noviewer', dest='viewer', action='store_false', help='Disables the display of samples in a window.')
        parser.add_argument('--viewer_size', '-z', type=float, dest='viewer_size', default=1, help='Size of the viewer window as a multiplier. WARNING: values above 60 may cause crashes')
        parser.add_argument('--list-templates', '-l', dest='list_templates', action='store_true', help='List available templates.')
        parser.add_argument('--debug', dest='debug', action='store_true', help='Start debugger.')
        parser.add_argument('--version', action='version', version='%(prog)s 0.10.0 alpha')
        parser.add_argument('--nomenu', dest='menu', action='store_false', help='Disables the file menu.')
        parser.add_argument('--width', '-w', type=int, default=8, help='Number of samples per row')
        parser.add_argument('--loss_every', type=int, default=1, help='Saves a losses graph every X steps.')
        parser.add_argument('--save_losses', action='store_true', help='Saves losses graph to the local `losses` directory.')

    def get_parser(self):
        """Create the top-level parser with all subcommands attached.

        The chosen subcommand is stored in ``args.method`` and is mandatory.
        Every subcommand except ``test`` takes a positional ``directory``.

        Returns:
            The configured ``argparse.ArgumentParser``.
        """
        parser = argparse.ArgumentParser(description='Train, run, and deploy your GANs.', add_help=True)
        subparsers = parser.add_subparsers(dest='method')
        train_parser = subparsers.add_parser('train')
        test_parser = subparsers.add_parser('test')
        sample_parser = subparsers.add_parser('sample')
        build_parser = subparsers.add_parser('build')
        new_parser = subparsers.add_parser('new')
        subparsers.required = True
        self.common_flags(parser)
        self.common(sample_parser)
        self.common(train_parser)
        self.common(test_parser, directory=False)
        self.common(build_parser)
        self.common(new_parser)

        return parser

# Command-line driver: parse the arguments, resolve the GAN and input
# configurations, then hand everything to the HyperGAN CLI.  The viewer is
# always closed on the way out, whatever happened.
try:
    args = CommandParser().get_parser().parse_args()

    # Refuse to train when no save interval was given and the viewer is off.
    if args.method == 'train' and args.save_every == -1 and not args.viewer:
        print("Error, there is no way to save the model.  Refusing to run.  Please add --save_every N")
        sys.exit(1)

    # "WxHxC" -> [W, H, C]; the None padding lets a shorter spec fall back to
    # the defaults below.
    if args.size:
        size = [int(x) for x in args.size.split("x")] + [None, None, None]
    else:
        size = [None, None, None]
    width = size[0] or 64
    height = size[1] or 64
    channels = size[2] or 3

    use_toml = bool(args.toml)
    config_format = '.toml' if use_toml else '.json'

    config_name = args.config or 'default'
    config_filename = hg.Configuration.find(config_name, config_format=config_format)
    config = hc.Selector().load(config_filename, load_toml=use_toml)

    if config.input is None and args.input_config is None:
        # No input section anywhere: build an image loader from the CLI flags.
        # NOTE(review): the `test` subcommand registers no `directory`
        # positional, so this lookup raises AttributeError there -- confirm
        # whether `test` is expected to always supply an input config.
        directories = [args.directory]
        if args.align:
            directories += args.align
        input_config = hc.Config({
            "class": "class:hypergan.inputs.image_loader.ImageLoader",
            "batch_size": args.batch_size,
            "directories": directories,
            "channels": channels,
            "crop": args.crop,
            "height": height,
            "random_crop": args.random_crop,
            "resize": args.resize,
            "shuffle": True,
            "width": width
        })
    elif config.input:
        # The GAN config carries its own input section; it takes precedence
        # over --input_config.  A non-zero --batch_size overrides its value.
        if args.batch_size:
            config.input["batch_size"] = args.batch_size
        input_config = hc.Config(config.input)
    else:
        input_config_filename = hg.Configuration.find(args.input_config, config_format=config_format)
        input_config = hc.Selector().load(input_config_filename, load_toml=use_toml)

    if 'inherit' in config:
        base_filename = hg.Configuration.find(config['inherit']+config_format, config_format=config_format)
        base_config = hc.Selector().load(base_filename, load_toml=use_toml)
        # Keys of the inheriting config win over the base config.  The merge
        # result is a plain dict.
        config = {**base_config, **config}

    if args.list_templates:
        for template_name in hg.Configuration.list(config_format=config_format):
            try:
                template = hg.Configuration.load(template_name+config_format, verbose=False, use_toml=use_toml)
            except Exception as e:
                # Still list a template that fails to load, showing the error.
                template = hc.Config({"description": str(e), "publication": "error"})
            print("%-30s - %-60s" %  (template_name, str(template.description)+"("+str(template.publication)+")"))
            if template.runtime and "train" in template.runtime:
                print("  > %s" %  (template.runtime["train"]))
        sys.exit(0)

    gancli = cli.CLI(args=vars(args), input_config=input_config, gan_config=config)
    gancli.run()
except Exception:
    # Only real errors land here.  SystemExit (argparse --help/--version and
    # the deliberate exits above) and KeyboardInterrupt pass through, so they
    # no longer produce a bogus traceback or lose their exit status.
    traceback.print_exc()
    print("Error detected, HyperGAN exiting")
    sys.exit(1)
finally:
    GlobalViewer.close()
