#!/usr/bin/env python

import sys
import argparse
import logging
from cloud_training.commands.configure import ConfigureCommand
from cloud_training.commands.shutdown import ShutdownCommand
from cloud_training.commands.spot_price import SpotPriceCommand
from cloud_training.commands.sync_session import SyncSessionCommand
from cloud_training.commands.train import TrainCommand


# Top-level command-line parser. The options defined here apply to every
# subcommand and must be given before the subcommand name.
parser = argparse.ArgumentParser()
parser.add_argument('--profile', type=str, default='default', help='Settings profile')
parser.add_argument('--project-dir', type=str, default='.', help='Local path to the project')
parser.add_argument('--model', type=str, default=None, help='Model name')
parser.add_argument('--debug', action='store_true', help='Show debug messages')

# Each subcommand registers two defaults on the parsed namespace:
#   command_class - the command class that is instantiated with the parsed args
#   parser        - the subcommand's own parser, used to print its usage on error
subparsers = parser.add_subparsers()

# "configure" command
parser_configure = subparsers.add_parser('configure', help='Configure the tool')
parser_configure.set_defaults(command_class=ConfigureCommand, parser=parser_configure)

# "spot-price" command
parser_spot_price = subparsers.add_parser('spot-price', help='Spot Instance Prices')
parser_spot_price.add_argument('--region', type=str, default=None, help='Prices for a particular AWS region')
parser_spot_price.add_argument('--instance-type', type=str, default=None, help='Instance Type')
parser_spot_price.add_argument('--all-regions', action='store_true', help='Get prices for all regions')
parser_spot_price.set_defaults(command_class=SpotPriceCommand, parser=parser_spot_price)

# "sync-session" command
parser_sync_session = subparsers.add_parser('sync-session', help='Get the session from S3 (the last one by default)')
# default=None matches the implicit argparse default; stated explicitly for
# consistency with the "train" command's --session option.
parser_sync_session.add_argument('--session', type=str, default=None, help='Session ID')
parser_sync_session.set_defaults(command_class=SyncSessionCommand, parser=parser_sync_session)

# "train" command
parser_train = subparsers.add_parser('train', help='Train the model on AWS machine (starts new session by default)')
parser_train.add_argument('--session', type=str, default=None, help='Continue the training with this session')
parser_train.add_argument('--instance-type', type=str, default=None, help='Instance Type')
parser_train.add_argument('--conda-env', type=str, default=None, help='Name of a Conda environment for training')
parser_train.set_defaults(command_class=TrainCommand, parser=parser_train)

# "shutdown" command
parser_shutdown = subparsers.add_parser('shutdown', help='Shutdown the instance')
parser_shutdown.set_defaults(command_class=ShutdownCommand, parser=parser_shutdown)

# Parse the command line.
args = parser.parse_args()

# Logging: --debug enables DEBUG-level output; otherwise only warnings and above.
logging_level = logging.DEBUG if args.debug else logging.WARNING
logging.basicConfig(level=logging_level, format='%(levelname)s %(message)s')

# Only the subcommand parsers set `command_class`, so its absence means no
# subcommand was given. This is an error exit, so the usage goes to stderr.
if not hasattr(args, 'command_class'):
    parser.print_usage(sys.stderr)
    sys.exit(1)

# Instantiate the selected command with the parsed arguments. A ValueError
# from the constructor is reported together with the top-level usage, on
# stderr, and the script exits with status 1.
try:
    command = args.command_class(args)
except ValueError as e:
    parser.print_usage(sys.stderr)
    print(e, file=sys.stderr)
    sys.exit(1)

# Run the command. A ValueError raised while running is reported together
# with the usage of the subcommand's own parser (stored via set_defaults).
try:
    res = command.run()
except ValueError as e:
    args.parser.print_usage(sys.stderr)
    print(e, file=sys.stderr)
    sys.exit(1)

# A falsy return value from run() is treated as failure (exit status 1).
if not res:
    sys.exit(1)
