#!/usr/bin/env python

from argparse import ArgumentParser
from prettytable import PrettyTable
from awscore.AwsConnection import AwsConnection


def listVpcs(awsconnection, doption):
    """Print the VPCs returned by ``describe_vpcs`` as a table.

    Args:
        awsconnection: Object exposing an ``ec2client`` attribute whose
            ``describe_vpcs()`` result is read as ``{'Vpcs': [...]}``.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    allVpcs = awsconnection.ec2client.describe_vpcs()
    columns = ["Vpc Id", "Name", "CIDR Block"]
    x = PrettyTable(columns)
    for column in columns:
        x.align[column] = "l"

    print('\n')
    print('{0}'.format("\033[1m" + "VPCs in the Account" + "\033[0m"))

    for vpc in allVpcs['Vpcs']:
        # Start from the placeholder so a VPC that has tags but no "Name"
        # tag is shown the same way as a VPC with no tags at all.
        name = "--"
        for tag in vpc.get("Tags") or []:
            if tag["Key"] == "Name":
                name = tag["Value"]

        x.add_row([vpc['VpcId'], name, vpc['CidrBlock']])

    if doption == "table":
        print(x)
    elif doption == "html":
        print(x.get_html_string())


def listRoute53Records(awsconnection, doption):
    """Print the A and CNAME record sets of every hosted zone.

    Args:
        awsconnection: Object exposing an ``r53client`` attribute that
            provides ``list_hosted_zones()`` and
            ``list_resource_record_sets(HostedZoneId=...)``.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    r53client = awsconnection.r53client
    print("DNS Records(Hosted Zones, A Records, CNAMEs)")
    columns = ["Name/Host/Alias", "Record Type", "Value", "Hosted Zone"]
    x = PrettyTable(columns)
    for column in columns:
        x.align[column] = "l"

    wanted_types = ("A", "CNAME")
    for zone in r53client.list_hosted_zones()['HostedZones']:
        response = r53client.list_resource_record_sets(HostedZoneId=zone['Id'])
        for rrset in response['ResourceRecordSets']:
            if rrset['Type'] not in wanted_types:
                continue

            # Alias record sets carry an "AliasTarget" instead of
            # "ResourceRecords"; fall back to it rather than raising.
            records = rrset.get("ResourceRecords") or []
            if records:
                # Only the first value of the record set is displayed.
                value = records[0]["Value"]
            else:
                alias = rrset.get("AliasTarget") or {}
                value = alias.get("DNSName", "--")

            # [:-1] drops the last character of the record name.
            x.add_row([rrset['Name'][:-1], rrset['Type'], value, zone['Name']])

    if doption == "table":
        print(x)
    elif doption == "html":
        print(x.get_html_string())


def ec2Instances(awsconnection, doption):
    """Print one table of running EC2 instances per VPC.

    For each VPC returned by ``describe_vpcs`` a heading with the VPC id is
    printed, followed by a table of the instances in that VPC whose state
    name is ``'running'``, sorted by the "Instance Tag Name" column.

    Args:
        awsconnection: Object exposing an ``ec2client`` attribute that
            provides ``describe_vpcs()`` and ``describe_instances(...)``.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the headings.
    """
    ec2client = awsconnection.ec2client
    allVpcs = ec2client.describe_vpcs()
    print("EC2 Instances")

    columns = ["Instance Tag Name", "DNS Name", "Private IP Address",
               "Public IP Address", "AMI Id"]

    for vpc in allVpcs['Vpcs']:
        vpcid = vpc['VpcId']
        print('\n')
        print('{0}'.format("\033[1m" + "Vpc Id" + "\033[0m"))
        print('{0}'.format("------"))
        print('{0}'.format(vpcid))
        print('\n')

        x = PrettyTable(list(columns))
        for column in columns:
            x.align[column] = "l"

        # Restrict the query to the VPC whose heading was just printed;
        # without the filter every VPC table would list every instance.
        reservations = ec2client.describe_instances(
            Filters=[{'Name': 'vpc-id', 'Values': [vpcid]}]
        )['Reservations']

        for reservation in reservations:
            # A reservation may hold several instances; report each one.
            for instance in reservation['Instances']:
                if instance['State']['Name'] != 'running':
                    continue

                # Use a string placeholder instead of None: the table is
                # sorted by this column and None cannot be ordered against
                # str on Python 3.
                tagName = "--"
                for tag in instance.get('Tags') or []:
                    if tag['Key'] == 'Name':
                        tagName = tag['Value']

                x.add_row([
                    tagName,
                    instance.get('PublicDnsName') or "--",
                    instance.get('PrivateIpAddress') or "--",
                    instance.get('PublicIpAddress') or "--",
                    instance.get('ImageId') or "--",
                ])

        if doption == "table":
            print(x.get_string(sortby="Instance Tag Name"))
        elif doption == "html":
            print(x.get_html_string(sortby="Instance Tag Name"))


def subnetInfo(awsconnection, doption):
    """Print one table of subnets per VPC.

    For each VPC returned by ``describe_vpcs`` a heading with the VPC id is
    printed, followed by a table of that VPC's subnets sorted by the
    "Subnet - Friendly Name" column. The "No. of Used IPs" column holds the
    number of instances found in the subnet.

    Args:
        awsconnection: Object exposing ``ec2client`` (for ``describe_vpcs``)
            and ``ec2resource`` (for ``Vpc(id).subnets.all()``) attributes.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the headings.
    """
    ec2client = awsconnection.ec2client
    ec2resource = awsconnection.ec2resource
    allVpcs = ec2client.describe_vpcs()

    columns = ["Subnet Id", "CIDR Block", "Subnet - Friendly Name",
               "Avail. IP Count", "No. of Used IPs", "Availability Zone"]

    for vpc in allVpcs['Vpcs']:
        vpcid = vpc['VpcId']
        print('\n')
        print('{0}'.format("\033[1m" + "Vpc Id" + "\033[0m"))
        print('{0}'.format("------"))
        print('{0}'.format(vpcid))
        print('\n')

        subnets = ec2resource.Vpc(vpcid).subnets.all()
        print("Subnet Information")
        x = PrettyTable(list(columns))
        for column in columns:
            x.align[column] = "l"

        for subnet in subnets:
            instance_count = sum(1 for _ in subnet.instances.all())

            # Use a string placeholder instead of None: the table is sorted
            # by this column and None cannot be ordered against str on
            # Python 3.
            subnet_name = "--"
            for tag in subnet.tags or []:
                if tag['Key'] == 'Name':
                    subnet_name = tag['Value']

            x.add_row([subnet.subnet_id,
                       subnet.cidr_block,
                       subnet_name,
                       subnet.available_ip_address_count,
                       instance_count,
                       subnet.availability_zone])

        if doption == "table":
            print(x.get_string(sortby="Subnet - Friendly Name"))
        elif doption == "html":
            print(x.get_html_string(sortby="Subnet - Friendly Name"))


def sgInfo(awsconnection, doption):
    """Print one table of security groups and their ingress CIDRs per VPC.

    For each VPC returned by ``describe_vpcs`` a heading with the VPC id is
    printed, followed by a table. Each security group contributes a row with
    its id and name, then one row per CIDR range of each entry in
    ``ip_permissions`` together with that entry's ``ToPort``. A row of
    dashes separates consecutive groups.

    Args:
        awsconnection: Object exposing ``ec2client`` (for ``describe_vpcs``)
            and ``ec2resource`` (for ``Vpc(id).security_groups.all()``)
            attributes.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the headings.
    """
    ec2client = awsconnection.ec2client
    ec2resource = awsconnection.ec2resource
    allVpcs = ec2client.describe_vpcs()

    columns = ["Security Group Id", "Security Group Name",
               "ACLs - Source IP", "ACLs - ToPort"]

    for vpc in allVpcs['Vpcs']:
        vpcid = vpc['VpcId']
        print('\n')
        print('{0}'.format("\033[1m" + "Vpc Id" + "\033[0m"))
        print('{0}'.format("------"))
        print('{0}'.format(vpcid))
        print('\n')

        # Materialise once so the group count and the rows come from the
        # same listing instead of iterating the collection twice.
        secgroups = list(ec2resource.Vpc(vpcid).security_groups.all())
        print("Security Groups Information")
        x = PrettyTable(list(columns))
        for column in columns:
            x.align[column] = "l"

        last_index = len(secgroups) - 1
        for index, secgrp in enumerate(secgroups):
            x.add_row([secgrp.group_id, secgrp.group_name, "", ""])

            for perm in secgrp.ip_permissions or []:
                # Read port and ranges fresh for every permission so values
                # from a previous permission are never reused.
                to_port = perm.get('ToPort')
                if to_port is None:
                    to_port = "--"
                for ip_range in perm.get('IpRanges') or []:
                    x.add_row(["", "", ip_range['CidrIp'], to_port])

            # Separator between groups, but not after the last one.
            if index < last_index:
                x.add_row(["-------------------",
                           "-------------------------------------------------",
                           "------------------",
                           "---------------"])

        if doption == "table":
            print(x)
        elif doption == "html":
            print(x.get_html_string())


def lisKeyPairs(awsconnection, doption):
    """Print the key pairs returned by ``describe_key_pairs`` as a table.

    Args:
        awsconnection: Object exposing an ``ec2client`` attribute.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    response = awsconnection.ec2client.describe_key_pairs()

    # The column headers double as the keys read from each key-pair entry.
    headers = ["KeyName", "KeyFingerprint"]
    table = PrettyTable(headers)
    for header in headers:
        table.align[header] = "l"

    print("KeyPairs Available")
    for entry in response['KeyPairs']:
        table.add_row([entry[header] for header in headers])

    if doption == "table":
        print(table)
    elif doption == "html":
        print(table.get_html_string())


def listAvailZones(awsconnection, doption):
    """Print the zones returned by ``describe_availability_zones``.

    Args:
        awsconnection: Object exposing an ``ec2client`` attribute.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    response = awsconnection.ec2client.describe_availability_zones()

    # The column headers double as the keys read from each zone entry.
    headers = ["RegionName", "ZoneName"]
    table = PrettyTable(headers)
    for header in headers:
        table.align[header] = "l"

    print("Availability Zones")
    for entry in response['AvailabilityZones']:
        table.add_row([entry[header] for header in headers])

    if doption == "table":
        print(table)
    elif doption == "html":
        print(table.get_html_string())


def listLoadBalancers(awsconnection, doption):
    """Print the load balancers returned by ``describe_load_balancers``.

    Each row shows the DNS name, the scheme and the group name of the
    load balancer's source security group.

    Args:
        awsconnection: Object exposing an ``elbclient`` attribute.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    client = awsconnection.elbclient
    print("Load Balancers")

    table = PrettyTable(["DNS Name", "Scheme", "SourceSecurityGroup"])
    # Only the DNS name column is left-aligned.
    table.align["DNS Name"] = "l"

    descriptions = client.describe_load_balancers()['LoadBalancerDescriptions']
    for balancer in descriptions:
        dns_name = balancer['DNSName']
        scheme = balancer['Scheme']
        source_group = balancer['SourceSecurityGroup']['GroupName']
        table.add_row([dns_name, scheme, source_group])

    if doption == "table":
        print(table)
    elif doption == "html":
        print(table.get_html_string())

def listCfStacks(awsconnection, doption):
    """Print the stacks returned by ``list_stacks`` as a table.

    Only stacks whose ``StackStatus`` is one of the statuses listed below
    are shown. A stack with no ``LastUpdatedTime`` shows ``"--"`` in that
    column.

    Args:
        awsconnection: Object exposing a ``cfclient`` attribute.
        doption: ``"table"`` prints the plain-text table, ``"html"`` prints
            the HTML rendering; any other value prints only the heading.
    """
    client = awsconnection.cfclient

    # Statuses that are displayed; stacks in any other status are skipped.
    shown_statuses = {
        'CREATE_IN_PROGRESS',
        'CREATE_FAILED',
        'CREATE_COMPLETE',
        'ROLLBACK_IN_PROGRESS',
        'ROLLBACK_FAILED',
        'DELETE_IN_PROGRESS',
        'DELETE_FAILED',
        'UPDATE_IN_PROGRESS',
        'UPDATE_COMPLETE_CLEANUP_IN_PROGRESS',
        'UPDATE_COMPLETE',
        'UPDATE_ROLLBACK_IN_PROGRESS',
        'UPDATE_ROLLBACK_FAILED',
        'UPDATE_ROLLBACK_COMPLETE_CLEANUP_IN_PROGRESS',
        'UPDATE_ROLLBACK_COMPLETE',
    }

    headers = ["Stack Name", "Stack Status", "Creation Time", "Last Updated Time"]
    table = PrettyTable(headers)
    for header in headers:
        table.align[header] = "l"

    print('\n')
    print("\033[1m" + "Cloud Formation Stacks in the Account" + "\033[0m")

    for summary in client.list_stacks()['StackSummaries']:
        if summary['StackStatus'] not in shown_statuses:
            continue
        table.add_row([
            summary['StackName'],
            summary['StackStatus'],
            summary['CreationTime'],
            summary.get('LastUpdatedTime') or "--",
        ])

    if doption == "table":
        print(table)
    elif doption == "html":
        print(table.get_html_string())

def listRds(awsconnection, doption):
    """Print name, endpoint address and endpoint port of each DB instance.

    One space-separated line is printed per entry of ``'DBInstances'`` in
    the ``describe_db_instances()`` result. A missing ``DBName`` or a
    missing ``Endpoint`` (or endpoint field) is printed as ``"--"`` instead
    of raising ``KeyError``.

    Args:
        awsconnection: Object exposing an ``rdsclient`` attribute.
        doption: Accepted for parity with the other listing functions; it
            is not used here and the output is always plain text.
    """
    rdsclient = awsconnection.rdsclient
    for instance in rdsclient.describe_db_instances()['DBInstances']:
        # Both "DBName" and "Endpoint" are treated as optional.
        endpoint = instance.get('Endpoint') or {}
        print(instance.get('DBName') or "--",
              endpoint.get('Address', "--"),
              endpoint.get('Port', "--"))

def listAll(awsconnection, doption):
    """Run the subnet, security group, EC2, key pair, availability zone,
    load balancer and Route 53 listings in that order.

    Args:
        awsconnection: Connection object handed unchanged to every listing.
        doption: Display option handed unchanged to every listing.
    """
    # Each entry: (print a blank spacer first?, listing function).
    steps = (
        (False, subnetInfo),
        (False, sgInfo),
        (False, ec2Instances),
        (False, lisKeyPairs),
        (True, listAvailZones),
        (True, listLoadBalancers),
        (True, listRoute53Records),
    )
    for needs_spacer, listing in steps:
        if needs_spacer:
            print('\n')
        listing(awsconnection, doption)


def parseCmdLineArgs(argv=None):
    """Parse the command-line options and return them as a dict.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which
            is what the no-argument call did before this parameter existed.

    Returns:
        dict with the keys ``"awsaccountid"``, ``"doption"`` and
        ``"resources"``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or a value is not among the allowed choices.
    """
    parser = ArgumentParser(description='AwsViewCmdConsole')

    parser.add_argument("-n", "--aws-account-id",
            action="store",
            required=True,
            dest="awsaccountid",
            help="AWS Account Number")

    parser.add_argument("-d", "--display-option",
            action="store",
            required=True,
            choices=["table", "html"],
            dest="doption",
            help="How do you want to display the results?")

    parser.add_argument("-l", "--list",
            action="store",
            required=True,
            choices=["all", "vpcs", "subnets", "route53", "elbs", "cloudformation", "rds"],
            dest="resources",
            help="What AWS resource details you need")

    return vars(parser.parse_args(argv))

if __name__ == "__main__":
    # Script entry point: parse the options, open the AWS connection for
    # the requested account and run the listing selected with --list.
    args = parseCmdLineArgs()
    awsconnection = AwsConnection(args["awsaccountid"])

    # Map each --list choice straight to its function instead of building
    # source strings and passing them to eval().
    callfunc = {
        'all': listAll,
        'vpcs': listVpcs,
        'subnets': subnetInfo,
        'route53': listRoute53Records,
        'elbs': listLoadBalancers,
        'cloudformation': listCfStacks,
        'rds': listRds,
    }

    # argparse restricts --list to these keys, so the lookup cannot miss.
    callfunc[args["resources"]](awsconnection, args["doption"])
