#!/usr/bin/env python

import sys
import argparse

from cloud_training.commands.shutdown import ShutdownCommand
from cloud_training.commands.spot_price import SpotPriceCommand
from cloud_training.commands.sync_session import SyncSessionCommand
from cloud_training.commands.train import TrainCommand
import logging

# Log everything from DEBUG upwards, formatted as "<LEVEL> <message>".
logging.basicConfig(
    format='%(levelname)s %(message)s',
    level=logging.DEBUG,
)

# Root parser: holds the options that are not tied to a single sub-command.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--region',
    default='us-east-2',
    type=str,
    help='AWS Region',
)
parser.add_argument(
    '--s3-path',
    default='models-training/projects',
    type=str,
    help='AWS S3 path to all projects (format: <bucket-name>/<path>)',
)
parser.add_argument(
    '--project-dir',
    default='.',
    type=str,
    help='Local path to the project',
)
parser.add_argument(
    '--model',
    default=None,
    type=str,
    help='Model name',
)

# Every sub-command registers its own sub-parser on this object.
subparsers = parser.add_subparsers(help='sub-command help')

# "spot-price" command
parser_spot_price = subparsers.add_parser(
    'spot-price',
    aliases=['sp'],
    help='Spot Instance Prices',
)
parser_spot_price.add_argument(
    '--instance-type',
    default='p2.xlarge',
    type=str,
    help='Instance Type',
)
parser_spot_price.add_argument(
    '--all',
    action='store_true',
    help='Get prices for all regions',
)
# Store the handling class and this sub-parser on the parsed namespace so
# they can be read back from the parsed arguments.
parser_spot_price.set_defaults(
    parser=parser_spot_price,
    command_class=SpotPriceCommand,
)

# "sync-session" command
parser_sync_session = subparsers.add_parser(
    'sync-session',
    aliases=['ss'],
    help='Get the session from S3 (the last one by default)',
)
parser_sync_session.add_argument(
    '--session',
    default=0,
    type=int,
    help='Session ID',
)
# Store the handling class and this sub-parser on the parsed namespace so
# they can be read back from the parsed arguments.
parser_sync_session.set_defaults(
    parser=parser_sync_session,
    command_class=SyncSessionCommand,
)

# "train" command
parser_train = subparsers.add_parser(
    'train',
    help='Train the model on AWS machine (starts new session by default)',
)
parser_train.add_argument(
    '--session',
    default=0,
    type=int,
    help='Continue the training with this session',
)
parser_train.add_argument(
    '--instance-type',
    default='p2.xlarge',
    type=str,
    help='Instance Type',
)
parser_train.add_argument(
    '--conda-env',
    default=None,
    type=str,
    help='Name of a Conda environment for training',
)
# Store the handling class and this sub-parser on the parsed namespace so
# they can be read back from the parsed arguments.
parser_train.set_defaults(
    parser=parser_train,
    command_class=TrainCommand,
)

# "shutdown" command
parser_shutdown = subparsers.add_parser(
    'shutdown',
    help='Shutdown the instance',
)
# This sub-command defines no options of its own; only the handling class
# and the sub-parser are stored on the parsed namespace.
parser_shutdown.set_defaults(
    parser=parser_shutdown,
    command_class=ShutdownCommand,
)

# Parse the command line.
args = parser.parse_args()

# "command_class" is only set through a sub-parser's defaults, so its absence
# means no sub-command was selected: show the top-level usage and fail.
if 'command_class' not in args:
    parser.print_usage()
    sys.exit(1)

# Instantiate the selected command with the full set of parsed arguments.
command = args.command_class(args)

# Let the command check its arguments; on failure print the usage of the
# sub-parser stored in the namespace and exit with an error status.
if not command.check():
    args.parser.print_usage()
    sys.exit(1)

# Run the command; a falsy result is reported as exit status 1.
res = command.run()
if not res:
    sys.exit(1)
