#!/usr/bin/env python
#

"""
ec2launch: Launch Amazon AWS EC2 instance
"""

import collections
import json
import os
import random
import re
import sys
import time

import boto
import ec2common

import gterm

# Directory expected to hold the SSH private key files
ssh_dir = os.path.expanduser("~/.ssh")
# Account used for sudo/ssh/scp on the launched instance
gterm_user = "ubuntu"

usage = "usage: %prog [-h ... ] <instance-tag|full-domain-name>"

form_parser = gterm.FormParser(usage=usage, title="Create Amazon EC2 instance with hostname (e.g., tagname.graphterm.net): ", command="ec2launch -f")
form_parser.add_argument(label="", help="Instance hostname")

# (name, default, help text) of each regular option; registered in this order
_form_options = (
    ("type", ("t1.micro", "m1.small", "m1.medium", "m1.large"), "Instance type"),
    ("key_name", "ec2key", "Instance management SSH key name"),
    ("ami", "ami-e5582d8c", "Instance OS (default: Ubuntu 12.04LTS)"),
    ("auth_type", "none", "Authentication type (local/none/name/multiuser)"),
    ("connect_to_gterm", "", "Full domain name of GraphTerm server to connect to"),
    ("install", "", "Tar file to install on startup (with optional '; command' suffix)"),
    ("copy_files", "", "List of files to copy (space separated, e.g., .twitter_auth)"),
    ("pylab", False, "Install PyLab"),
    ("R", False, "Install R"),
    ("auto_users", False, "Automatically create new users"),
    ("allow_share", False, "Allow sharing between users"),
    ("logging", False, "Enable logging"),
    ("oshell", False, "Enable OTrace shell"),
    ("other_opts", "", "Other options for gtermserver (space-separated)"),
    ("dry_run", False, "Dry run"),
    ("verbose", False, "Verbose"),
)
for _opt_name, _opt_default, _opt_help in _form_options:
    form_parser.add_option(_opt_name, _opt_default, help=_opt_help)
##form_parser.add_option("wildcard", False, help="Enable wildcard subdomains")

# Display-mode switches, registered with short flags
form_parser.add_option("fullpage", False, short="f", help="Fullpage display", raw=True)
form_parser.add_option("text", False, short="t", help="Text only", raw=True)

(options, args) = form_parser.parse_args()

# Force text mode unless there is both a GraphTerm cookie and a tty on stdout
if not (gterm.Lterm_cookie and sys.stdout.isatty()):
    options.text = True

if not args:
    # No hostname given: show the input form, or the usage text in text mode
    if not options.text:
        gterm.write_form(form_parser.create_form(), command="ec2launch -f")
    else:
        print >> sys.stderr, form_parser.get_usage()
    sys.exit(1)
    
# The first positional argument names the instance (plain tag or dotted domain name)
instance_tag = args[0]

if re.match(r"^[a-z][a-z0-9\.\-]*$", instance_tag) is None:
    raise Exception("Invalid characters in instance name: "+instance_tag)

# Without the key directory nothing below can work, so bail out early
if not os.path.exists(ssh_dir):
    print >> sys.stderr, "ec2launch: %s directory not found!" % ssh_dir
    sys.exit(1)

# Text before the first "." is the host part; the remainder (may be empty) is the domain
instance_name, sep, instance_domain = instance_tag.partition(".")

# Private key file path derived from the key_name option
key_file = os.path.join(ssh_dir, options.key_name+".pem")

# Parse the "install" option, which has the form "<archive>[; <command>]".
# Results:
#   install_file     - local path of the archive (also queued in copy_files)
#   install_cmd      - optional command to run after installation
#   install_basename - archive file name without directory
#   install_baseroot - archive file name with archive extension(s) removed
#   unarch_cmd       - shell command used to unpack the archive
install_file, install_cmd, install_basename, install_baseroot, unarch_cmd = "", "", "", "", ""
copy_files = []
if options.install:
    install_file, sep, install_cmd = options.install.partition(";")
    # Tolerate blanks around the ";" separator (e.g. "pkg.tar.gz ; cmd");
    # otherwise the file lookup below fails on the trailing space
    install_file = install_file.strip()
    install_cmd = install_cmd.strip()
    if not os.path.isfile(install_file):
        raise Exception("Install file "+install_file+" not found")
    install_basename = os.path.basename(install_file)
    install_baseroot, ext = os.path.splitext(install_basename)
    if ext == ".gz":
        # Only gzipped tarballs are accepted: strip the inner ".tar" as well
        install_baseroot, ext = os.path.splitext(install_baseroot)
        if ext != ".tar":
            raise Exception("Invalid extension for install file: "+install_file)
        unarch_cmd = "tar zxf"
    elif ext == ".tgz":
        unarch_cmd = "tar zxf"
    elif ext == ".tar":
        unarch_cmd = "tar xf"
    elif ext == ".zip":
        unarch_cmd = "unzip"
    else:
        raise Exception("Invalid extension for install file: "+install_file)
    copy_files.append(install_file)

# Additional files to copy to the instance (whitespace separated list)
if options.copy_files:
    copy_files += options.copy_files.split()

auth_type = options.auth_type.strip()
# True for any non-empty auth type other than "name"/"none"
allow_host_connect = auth_type and auth_type not in ("name", "none")

# Any remaining value other than "local"/"multiuser" is rejected
if allow_host_connect and auth_type not in ("local", "multiuser"):
    raise Exception("Specifying auth code as argument is insecure")

# One security group rule: protocol, port range, and either a CIDR block or a source group name
GroupRule = collections.namedtuple(
    "GroupRule",
    "ip_protocol from_port to_port cidr_ip src_group_name")

HostGroupName = "gtermhost"
ServerGroupName = "gtermserver"


def _tcp_rule(port, src_group_name=None):
    """Return a single-port TCP GroupRule with cidr 0.0.0.0/0 and optional source group"""
    return GroupRule("tcp", port, port, "0.0.0.0/0", src_group_name)


# (group name, description, rules) for every security group managed by this script
SecurityGroups = [
    (HostGroupName, "GraphTerm host group", [
        _tcp_rule("22"),
    ]),

    (ServerGroupName, "GraphTerm server group", [
        _tcp_rule("22"),
        _tcp_rule("80"),
        _tcp_rule("8888"),
        # Port 8899 names the host group as source unless allow_host_connect is set
        _tcp_rule("8899", None if allow_host_connect else HostGroupName),
        _tcp_rule("8900"),
    ]),
]

# Marker file touched in the remote user's home directory by the final startup command
setup_file = "SETUP_OVER"

# Shell script lines passed to the instance as user data
startup_commands = [
    "#!/bin/bash",
    "set -e -x",
    "apt-get update && apt-get upgrade -y",
    "apt-get install -y python-setuptools",
    "easy_install tornado",
]

# Skip the stock graphterm install when the install archive name starts with "graphterm"
install_gterm = not install_basename.startswith("graphterm")

if install_gterm:
    startup_commands.extend(["easy_install graphterm", "gterm_setup"])

if options.pylab:
    startup_commands.append("apt-get install -y python-numpy python-scipy python-matplotlib python-scientific python-pandas libnetcdf-dev netcdf-bin ipython")
    startup_commands.append("easy_install pil qrcode")

if options.R:
    startup_commands.append("apt-get install -y r-base libcurl4-openssl-dev libcairo2-dev libxt-dev")
# Template for the daemon start command; the single "%s" is later replaced
# by the program name (gtermhost or gtermserver) using the % operator
command_format = "sudo -u "+gterm_user+" -i %s --daemon=start"
if options.auto_users:
    # Compare the stripped value, consistent with the earlier auth_type validation
    if auth_type != "multiuser":
        raise Exception("Must specify auth_type=multiuser to enable auto_users")
    command_format += " --auto_users --super_users=ubuntu"
if options.allow_share:
    command_format += " --allow_share"
if options.other_opts:
    # Escape "%" so user-supplied options survive the later %-formatting unchanged
    command_format += " " + options.other_opts.strip().replace("%", "%%")
if options.logging:
    command_format += " --logging"
if options.oshell:
    command_format += " --oshell"

# init_command starts the GraphTerm daemon; fwall_command redirects port 80.
# Both are run at first boot and also written to /etc/init.d/graphterm.
init_command = ""
fwall_command = ""
if options.connect_to_gterm:
    # Host mode: instance connects to an existing GraphTerm server
    security_groups = [HostGroupName]
    server_name, sep, server_domain = options.connect_to_gterm.partition(".")
    if not server_domain:
        raise Exception("Must specify fully qualified domain name to connect to GraphTerm server")
    # Use the short host name when instance and server share the same domain
    host_name = instance_name if instance_domain == server_domain else instance_tag
    init_command = (command_format % "gtermhost") + " --server_addr="+ options.connect_to_gterm+" "+host_name
    startup_commands += [ init_command ]
else:
    # Server mode: instance runs its own GraphTerm server
    security_groups = [ServerGroupName]

    # Without a domain, the public hostname is looked up on the instance itself
    # (backquoted curl of the EC2 metadata URL, evaluated by the shell)
    host_domain = instance_tag if instance_domain else "`curl http://169.254.169.254/latest/meta-data/public-hostname`"
    fwall_command = "iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to 8900"
    startup_commands += [ fwall_command ]
    if install_gterm and auth_type:
        # Pass the stripped auth_type (the value that was validated), not the raw option text
        init_command = (command_format % "gtermserver") + ' --auth_type="'+auth_type+'" --host='+host_domain
        ##if options.wildcard:
        ##    init_command += " --blob_host=wildcard"
        startup_commands += [ init_command ]

if init_command or fwall_command:
    # Write the same commands to an init script and register it with update-rc.d
    startup_commands += [ "echo '"+init_command+"' > /etc/init.d/graphterm",
                          "echo '"+fwall_command+"' >> /etc/init.d/graphterm",
                          "chmod +x /etc/init.d/graphterm",
                          "update-rc.d graphterm defaults"]

# Final step: create the marker file that the local side waits for
startup_commands += ["sudo -u "+gterm_user+" -i touch ~" + gterm_user + "/" + setup_file]

startup_script = "\n".join(startup_commands) + "\n"

if options.verbose:
    print >> sys.stderr, "Startup_commands:\n   "+ "\n   ".join(startup_commands) + "\n\n"
# Connect to EC2
ec2 = boto.connect_ec2()

# Route53 state; only populated when the instance tag includes a domain
route53conn = None
hosted_zone = None
nameservers = ""
if instance_domain:
    route53conn = ec2common.Route53Connection()
    hosted_zone = ec2common.get_hosted_zone(route53conn, instance_domain)
    # A dry run must not create AWS resources, so a missing zone is only
    # created for a real launch
    if not hosted_zone and not options.dry_run:
        hosted_zone = ec2common.create_hosted_zone(route53conn, instance_domain)
        # Nameservers of a newly created zone are kept so they can be reported to the user
        nameservers = ec2common.get_nameservers(route53conn, instance_domain)

def get_security_group(ec2conn, group_name):
    """Return the first security group named group_name, or None if there is none"""
    for candidate in ec2conn.get_all_security_groups():
        if candidate.name == group_name:
            return candidate
    return None

def create_or_update_security_group(ec2conn, group_name, description="", rules=None):
    """Create (or update) security group

    Makes sure a security group named group_name exists and that its grants
    match rules: existing grants that are not listed in rules are revoked,
    and listed rules that are not yet present are authorized.

    ec2conn: EC2 connection (used to look up and create security groups)
    group_name: name of the security group
    description: description for a newly created group (group_name if empty)
    rules: iterable of GroupRule tuples (default: no rules)

    Returns the security group object.
    Raises Exception if a rule to be added names a source group that does not exist.
    """
    new_rules = list(rules) if rules else []
    group = get_security_group(ec2conn, group_name)
    if group:
        # Group already exists: compare every existing grant with the wanted rules.
        # Iterate over snapshots because revoke() may prune the group's local
        # rule/grant lists (depends on boto version -- TODO confirm)
        for rule in list(group.rules):
            for grant in list(rule.grants):
                # A grant carries either a CIDR block or a source group name
                if grant.cidr_ip:
                    cidr_ip, src_group_name = grant.cidr_ip, None
                else:
                    cidr_ip, src_group_name = "0.0.0.0/0", grant.name
                old_rule = GroupRule(rule.ip_protocol,
                                     rule.from_port,
                                     rule.to_port,
                                     cidr_ip,
                                     src_group_name)
                if old_rule in new_rules:
                    # Old rule still valid; nothing to add for it
                    new_rules.remove(old_rule)
                else:
                    # Old rule no longer valid
                    group.revoke(ip_protocol=rule.ip_protocol,
                                 from_port=rule.from_port,
                                 to_port=rule.to_port,
                                 cidr_ip=cidr_ip,
                                 src_group=get_security_group(ec2conn, src_group_name) if src_group_name else None)
    else:
        group = ec2conn.create_security_group(group_name, description or group_name)

    for rule in new_rules:
        # Create new rules
        if rule.src_group_name:
            src_group = get_security_group(ec2conn, rule.src_group_name)
            if not src_group:
                raise Exception("Source group %s not found" % rule.src_group_name)
        else:
            src_group = None

        group.authorize(ip_protocol=rule.ip_protocol,
                        from_port=rule.from_port,
                        to_port=rule.to_port,
                        cidr_ip=rule.cidr_ip,
                        src_group=src_group)
    return group

# A dry run only reports what would be launched; it exits before any
# resource (key pair, security group, instance) is created
if options.dry_run:
    print >> sys.stderr, "run_instances:", dict(image_id=options.ami,
                                                instance_type=options.type,
                                                key_name=options.key_name,
                                                security_groups=security_groups)
    print >> sys.stderr, "startup_script:", startup_script
    sys.exit(1)

# Create key pair, if needed (i.e., when the local private key file is missing)
if not os.path.exists(key_file):
    key_pair = ec2.create_key_pair(options.key_name)
    key_pair.save(ssh_dir)
    # Private key must be readable by the owner only
    os.chmod(key_file, 0o600)

# Create security groups as needed (in the order listed in SecurityGroups)
for _group_spec in SecurityGroups:
    group_name, description, rules = _group_spec
    create_or_update_security_group(ec2, group_name, description=description, rules=rules)

# Launch instance, handing over the startup script as user data
_launch_args = dict(image_id=options.ami,
                    instance_type=options.type,
                    key_name=options.key_name,
                    security_groups=security_groups,
                    user_data=startup_script)
reservation = ec2.run_instances(**_launch_args)
instance = reservation.instances[0]

# Wait for instance to start running
Status_template =  """<em>Creating instance</em> <b>%s</b>: status=<b>%s</b> (waiting %ds)"""

status = instance.update()
start_time = time.time()

if options.text:
    print >> sys.stderr, "Waiting for instance to be created..."

# Poll every 3 seconds for as long as the instance reports "pending"
while status == "pending":
    if not options.text:
        # Refresh the progress pagelet with the elapsed whole seconds
        timesec = int(time.time() - start_time)
        gterm.write_pagelet(Status_template % (instance_tag, status, timesec), overwrite=True, display="fullpage")
    time.sleep(3)
    status = instance.update()

# Any final state other than "running" counts as a failed launch
if not status == "running":
    print >> sys.stderr, "ec2launch: ERROR Failed to launch instance: %s" % status
    sys.exit(1)

# Tag instance (the instance tag is passed as the first add_tag argument)
instance.add_tag(instance_tag)

# NOTE(review): despite its name, instance_id holds the reservation id; it is
# used to locate the reservation and is printed below -- confirm this is intended
instance_id = reservation.id
instance_obj = None
all_instances = ec2.get_all_instances()
for r in all_instances:
    if r.id != instance_id:
        continue
    # Matching reservation found: take its first instance
    instance_obj = r.instances[0]
    break

if not instance_obj:
    print >> sys.stderr, "ec2launch: ERROR Unable to find launched instance: %s" % status
    sys.exit(1)

public_dns_name = instance_obj.public_dns_name

if hosted_zone:
    # Create new CNAME entry pointing to instance public DNS
    ec2common.cname(route53conn, hosted_zone, instance_tag, public_dns_name)
    ##if options.wildcard:
    ##    ec2common.cname(route53conn, hosted_zone, "*."+instance_tag, public_dns_name)

print >> sys.stderr, "Created EC2 instance %s: id=%s, public=%s" % (instance_tag, instance_id, public_dns_name)
if options.verbose:
    # Address by tag when it is a full domain name, otherwise by public DNS name
    _ssh_target = instance_tag if instance_domain else public_dns_name
    for _hint in ("To check status, type:",
                  "   ec2ssh ubuntu@"+_ssh_target,
                  "   sudo su -",
                  "   tail -f /var/log/dpkg.log"):
        print >> sys.stderr, _hint

ssh_cmd_args = ["ssh", "-i", os.path.expanduser("~/.ssh/ec2key.pem"),
                "-o", "StrictHostKeyChecking=no", gterm_user+"@"+public_dns_name]
wait_cmd = "while [ ! -f "+setup_file+" ]; do sleep 1; done"

if nameservers:
    print >> sys.stderr, "***NOTE*** Please add the following nameservers for domain "+instance_domain
    print >> sys.stderr, "           ", nameservers
    print >> sys.stderr, ""

print >> sys.stderr, "Waiting for remote setup to complete..."
time.sleep(30)

if copy_files:
    scp_cmd_args = ["scp", "-i", os.path.expanduser("~/.ssh/ec2key.pem"), "-o", "StrictHostKeyChecking=no"] + copy_files + [gterm_user+"@"+public_dns_name+":"]
    print >> sys.stderr, " ".join(scp_cmd_args)
    out, err = gterm.command_output(scp_cmd_args, timeout=60)
    if out:
        print out
    if err:
        print err
    if options.install:
        wait_cmd += "; "+unarch_cmd+" "+install_basename+"; cd "+install_baseroot+"; sudo python setup.py install"
        if install_cmd:
            wait_cmd += "; "+install_cmd

print >> sys.stderr, " ".join(ssh_cmd_args)
print >> sys.stderr, wait_cmd
out, err = gterm.command_output(ssh_cmd_args + [wait_cmd], timeout=600)

if out:
    print out
if err:
    print >> sys.stderr, err
else:
    print >> sys.stderr, "Remote setup completed."
    
