#!python
import argparse
import subprocess
import sys

from pathlib import Path

import boto3

from botocore.exceptions import ClientError


# Resolve the default region from the standard boto3 configuration chain.
# Session.region_name is None when no region is configured, whereas creating a
# client without a region raises NoRegionError before a fallback can be applied,
# so the fallback has to be decided *before* any client is built.
default_region = boto3.Session().region_name or 'us-east-1'
# Kept as a module-level name for backward compatibility; it is now always
# created with an explicit region so it cannot fail on a missing configuration.
ec2_client = boto3.client('ec2', region_name=default_region)

# Default key pair locations and login user.
default_key_file_path_private = f'{Path.home()}/.ssh/id_rsa'
default_key_file_path_public = f'{Path.home()}/.ssh/id_rsa.pub'
default_user = 'ec2-user'

# Command line interface. Note: argparse applies ``type`` to string defaults
# too, so the default key files are opened at parse time when the
# corresponding option is not given.
parser = argparse.ArgumentParser(
    description='Connect to any ec2 instance your IAM user has rights for and that are reachable.',
)
parser.add_argument(
    'instance_id',
    type=str,
    help='Id of the instance to connect to. Example: i-0bd8e1251a4713c17',
)
parser.add_argument(
    '--os-user',
    type=str,
    default=default_user,
    help=f'OS username to use for connection. Default: {default_user}',
)
parser.add_argument(
    '--public-key-file',
    type=argparse.FileType('r'),
    default=default_key_file_path_public,
    help=f'Public key file to use for connection. Default: {default_key_file_path_public}',
)
parser.add_argument(
    '--private-key-file',
    type=argparse.FileType('r'),
    default=default_key_file_path_private,
    help=f'Private key file to use for connection. Default: {default_key_file_path_private}',
)
parser.add_argument(
    '--regions',
    type=str,
    nargs='+',
    default=[default_region],
    help=f'Look for the instance in the given regions. Default: {default_region}',
)
args = parser.parse_args()

def connect(region):
    """Find the requested instance in ``region`` and push the SSH public key to it.

    Reads the instance id, OS user and public key file from the module-level
    ``args`` namespace.

    Args:
        region: Name of the AWS region to search, e.g. ``'us-east-1'``.

    Returns:
        The private IP address of the instance once the public key has been
        sent through EC2 Instance Connect, or ``None`` when the instance is
        not found in ``region`` or has no private IP address.

    Raises:
        ClientError: For any ``describe_instances`` failure other than
            ``InvalidInstanceID.NotFound``, and for failures while sending
            the public key.
    """
    ec2_client = boto3.client('ec2', region_name=region)
    print(f'Looking for instance {args.instance_id} in region {region}')
    try:
        response = ec2_client.describe_instances(InstanceIds=[args.instance_id])
    except ClientError as e:
        # Only "not found" means "try the next region"; everything else
        # (permissions, malformed id, throttling, ...) is a real error.
        if e.response['Error']['Code'] != 'InvalidInstanceID.NotFound':
            raise
        print(f'Did not find instance {args.instance_id} in region {region}')
        return None

    # Flatten defensively so an empty result is reported as "not found"
    # instead of failing with an IndexError.
    instances = [
        instance
        for reservation in response.get('Reservations', [])
        for instance in reservation.get('Instances', [])
    ]
    if not instances:
        print(f'Did not find instance {args.instance_id} in region {region}')
        return None
    instance = instances[0]

    # Without a private IP there is nothing to ssh to, so bail out with a
    # readable message before pushing the key rather than raising KeyError.
    private_ip = instance.get('PrivateIpAddress')
    if not private_ip:
        state = instance.get('State', {}).get('Name', 'unknown')
        print(f'Instance {args.instance_id} in region {region} has no private IP address (state: {state})')
        return None

    zone = instance['Placement']['AvailabilityZone']
    connect_client = boto3.client('ec2-instance-connect', region_name=region)
    connect_client.send_ssh_public_key(
        InstanceId=args.instance_id,
        InstanceOSUser=args.os_user,
        SSHPublicKey=args.public_key_file.read(),
        AvailabilityZone=zone,
    )
    return private_ip

# Try the requested regions in order and stop at the first one in which the
# instance was found and the public key was sent.
ip_to_connect_to = None
for region in args.regions:
    ip_to_connect_to = connect(region)
    if ip_to_connect_to:
        break
if not ip_to_connect_to:
    print(f'Error: Did not find instance id {args.instance_id} in any region: {args.regions}')
    sys.exit(1)

# Hand over to the system ssh client. Only the *path* of the private key is
# needed here; ssh reads the file itself.
ssh_result = subprocess.run(["ssh", "-i", args.private_key_file.name, f"{args.os_user}@{ip_to_connect_to}"])
# Propagate ssh's exit status so a failed connection is not reported as success.
sys.exit(ssh_result.returncode)

