#!/usr/bin/env python
"""
Shodan CLI

Note: Always run "shodan init <api key>" before trying to execute any other command!

A simple interface to search Shodan, download data and parse compressed JSON files.
The following commands are currently supported:

    alert
    count
    download
    info
    init
    myip
    parse
    scan
    search
    stats
    stream

"""

import click
import collections
import datetime
import gzip
import os
import os.path
import shodan
import simplejson
import time

# Constants
# Directory that holds the CLI's settings ("~" is expanded at use time)
SHODAN_CONFIG_DIR = '~/.shodan/'
# Delimiter for list-valued banner properties
ARRAY_SEPARATOR = ';'
# Foreground colour applied to each banner property in colourised output
COLORIZE_FIELDS = dict(
    ip_str='green',
    port='yellow',
    data='white',
    hostnames='magenta',
    org='cyan',
)


# Utility methods
def get_api_key():
    """Return the Shodan API key stored in the CLI's configuration directory.

    Returns:
        The contents of the ``api_key`` file with surrounding whitespace
        removed.

    Raises:
        click.ClickException: if the key file does not exist yet.
    """
    shodan_dir = os.path.expanduser(SHODAN_CONFIG_DIR)
    keyfile = os.path.join(shodan_dir, 'api_key')

    # If the file doesn't yet exist let the user know that they need to
    # initialize the shodan cli
    if not os.path.exists(keyfile):
        raise click.ClickException('Please run "shodan init <api key>" before using this command')

    # Make sure only the owner can read/write the key file
    os.chmod(keyfile, 0o600)

    with open(keyfile, 'r') as fin:
        return fin.read().strip()


def escape_data(args):
    """Return `args` as single-line, ASCII-only text.

    Newlines, carriage returns and tabs are replaced by their backslash
    escapes (``\\n``, ``\\r``, ``\\t``) and every non-ASCII character is
    replaced by ``?``.

    The escaping is done on the text *before* encoding so the function also
    works where ``encode()`` returns ``bytes`` (Python 3); the result is
    decoded again so callers always get text back.
    """
    escaped = args.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')
    return escaped.encode('ascii', 'replace').decode('ascii')

def timestr():
    """Return the current UTC date formatted as YYYY-MM-DD."""
    now = datetime.datetime.utcnow()
    return '{0:%Y-%m-%d}'.format(now)

def open_file(directory, timestr):
    """Open ``<directory>/<timestr>.json.gz`` for appending.

    The file is opened through gzip in append mode with compression level 1
    and the resulting file object is returned.
    """
    path = '%s/%s.json.gz' % (directory, timestr)
    return gzip.open(path, mode='a', compresslevel=1)


@click.group()
def main():
    # Root command group of the CLI. It does no work itself; the commands
    # and sub-groups in this file attach to it through @main.command() and
    # @main.group().
    pass


@main.command()
@click.argument('key', metavar='<api key>')
def init(key):
    """Initialize the Shodan command-line"""
    # key: the API key given on the command line; it is stored (stripped of
    # surrounding whitespace) in the "api_key" file of the config directory.
    # Raises click.ClickException if the config directory cannot be created.

    # Create the directory if necessary
    shodan_dir = os.path.expanduser(SHODAN_CONFIG_DIR)
    if not os.path.isdir(shodan_dir):
        try:
            os.mkdir(shodan_dir)
        except OSError:
            raise click.ClickException('Unable to create directory to store the Shodan API key (%s)' % shodan_dir)

    # Store the API key in the user's directory. The file is created with
    # owner-only permissions from the start so the key is never readable by
    # other users, not even briefly.
    keyfile = os.path.join(shodan_dir, 'api_key')
    fd = os.open(keyfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as fout:
        fout.write(key.strip())

    # The creation mode is ignored for a file that already existed, so
    # tighten the permissions explicitly as well
    os.chmod(keyfile, 0o600)

    click.echo(click.style('Successfully initialized', fg='green'))


@main.group()
def alert():
    # Sub-group of `main`; the alert_list and alert_remove commands below
    # are registered on it as "list" and "remove" via @alert.command(...).
    pass

@alert.command(name='list')
@click.option('--expired', help='Whether or not to show expired alerts.', default=True, type=bool)
def alert_list(expired):
    """List the alerts that have been created on the account"""
    # expired: forwarded to the API as `include_expired`.
    # Raises click.ClickException when the API reports an error.
    key = get_api_key()

    # Get the list
    api = shodan.Shodan(key)
    try:
        results = api.alerts(include_expired=expired)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    if not results:
        click.echo("You haven't created any alerts yet.")
        return

    click.echo('# {:14} {:<21} {:<15s}'.format('Alert ID', 'Name', 'IP/ Network'))
    for item in results:
        # The row is printed without a newline so that the "expired" marker
        # (or just the line break) can be appended afterwards.
        click.echo(
            '{:16} {:<30} {:<35} '.format(
                click.style(item['id'], fg='yellow'),
                click.style(item['name'], fg='cyan'),
                click.style(', '.join(item['filters']['ip']), fg='white')
            ),
            nl=False
        )

        if item.get('expired'):
            click.echo(click.style('expired', fg='red'))
        else:
            click.echo('')


@alert.command(name='remove')
@click.argument('alert_id', metavar='<alert ID>')
def alert_remove(alert_id):
    """Remove the specified alert"""
    # alert_id: identifier of the alert to delete.
    # Raises click.ClickException when the API reports an error.
    key = get_api_key()

    # Delete the alert
    api = shodan.Shodan(key)
    try:
        api.delete_alert(alert_id)
    except shodan.APIError as e:
        raise click.ClickException(e.value)
    click.echo("Alert deleted")


@main.command()
@click.argument('query', metavar='<search query>', nargs=-1)
def count(query):
    """Returns the number of results for a search"""
    # query: tuple of command-line words that together form the search query.
    # Raises click.ClickException for an empty query or an API error.
    key = get_api_key()

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    # Perform the search
    api = shodan.Shodan(key)
    try:
        results = api.count(query)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    click.echo(results['total'])


@main.command()
@click.option('--limit', help='The number of results you want to download. -1 to download all the data possible.', default=1000, type=int)
@click.argument('filename', metavar='<filename>')
@click.argument('query', metavar='<search query>', nargs=-1)
def download(limit, filename, query):
    """Download search results and save them in a compressed JSON file."""
    # limit:    maximum number of results to save; a negative value means
    #           "everything that is available".
    # filename: output file; ".json.gz" is appended when missing.
    # query:    tuple of command-line words that form the search query.
    # Raises click.ClickException for an empty query/filename or when the
    # initial API requests fail.
    key = get_api_key()

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    filename = filename.strip()
    if filename == '':
        raise click.ClickException('Empty filename')

    # Add the appropriate extension if it's not there atm
    if not filename.endswith('.json.gz'):
        filename += '.json.gz'

    # Perform the search
    api = shodan.Shodan(key)

    try:
        total = api.count(query)['total']
        account = api.info()
    except shodan.APIError as e:
        # Show the reason reported by the API (e.g. a rejected key or query)
        raise click.ClickException(e.value)
    except Exception:
        raise click.ClickException('The Shodan API is unresponsive at the moment, please try again later.')

    # Print some summary information about the download request
    click.echo('Search query:\t\t\t%s' % query)
    click.echo('Total number of results:\t%s' % total)
    click.echo('Query credits left:\t\t%s' % account['unlocked_left'])
    click.echo('Output file:\t\t\t%s' % filename)

    # A negative limit means that we should download all the data, and we
    # can never save more results than the search has.
    if limit < 0 or limit > total:
        limit = total

    saved = 0
    with gzip.open(filename, 'w') as fout:
        # With nothing to save there is no reason to start the search cursor
        if limit > 0:
            try:
                cursor = api.search_cursor(query)
                with click.progressbar(cursor, length=limit) as bar:
                    for banner in bar:
                        fout.write(simplejson.dumps(banner) + '\n')
                        saved += 1

                        # Stop right after the last wanted banner so the
                        # cursor isn't advanced any further than necessary
                        if saved >= limit:
                            break
            except KeyboardInterrupt:
                # Let the user stop early and keep what was saved so far
                pass
            except Exception as e:
                # Best effort: keep the partial file but say why we stopped
                click.echo(click.style('Notice: the download stopped early (%s)' % repr(e), fg='yellow'))

    # Let the user know we're done
    if saved < limit:
        click.echo(click.style('Notice: fewer results were saved than requested', fg='yellow'))
    click.echo(click.style('Saved %s results into file %s' % (saved, filename), fg='green'))


@main.command()
def info():
    """Shows general information about your account"""
    # Prints the 'query_credits' and 'scan_credits' values returned by the
    # API. Raises click.ClickException when the API reports an error.
    key = get_api_key()
    api = shodan.Shodan(key)
    try:
        results = api.info()
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    click.echo("""Query credits available: {0}
Scan credits available: {1}
    """.format(results['query_credits'], results['scan_credits']))


@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to output.', default='ip_str,port,hostnames,data')
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.argument('filename', metavar='<filename>', type=click.Path(exists=True))
def parse(color, fields, separator, filename):
    """Extract information out of compressed JSON files."""
    # color:     colourise known properties (see COLORIZE_FIELDS).
    # fields:    comma-separated names of the banner properties to print.
    # separator: text written after every property of a row.
    # filename:  a ".json" or ".json.gz" file with one JSON banner per line.
    # Raises click.ClickException for a wrong file extension or when no
    # property was requested.

    # Make sure it's some sort of json file
    if not filename.endswith(('.json.gz', '.json')):
        raise click.ClickException('Invalid file, please make sure it is a valid Shodan JSON file')

    # Strip out any whitespace in the fields and turn them into an array
    fields = [item.strip() for item in fields.split(',')]

    # split() always returns at least one (possibly empty) item, so look at
    # the content instead of the length
    if not any(fields):
        raise click.ClickException('Please define at least one property to show')

    # Pick the opener depending on the filetype; the "with" block makes sure
    # the file is closed again, also when a line fails to parse
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, 'r') as fin:
        for line in fin:
            # Skip blank lines (e.g. a trailing newline at the end of the file)
            if not line.strip():
                continue

            # Convert the JSON into a native Python object
            banner = simplejson.loads(line)
            row = ''

            # Loop over all the fields and print the banner as a row
            for field in fields:
                tmp = ''
                value = banner.get(field)
                if value:
                    # If the field is an array then merge it together
                    if isinstance(value, list):
                        tmp = ARRAY_SEPARATOR.join(value)
                    elif hasattr(value, 'encode'):
                        # Text: make it safe to print on a single line
                        tmp = escape_data(value)
                    else:
                        # Numbers, booleans, nested objects, ...
                        tmp = str(value)

                    # Colorize certain fields if the user wants it
                    if color:
                        tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                    # Add the field information to the row
                    row += tmp
                row += separator

            click.echo(row)


@main.command()
def myip():
    """Print your external IP address"""
    # Raises click.ClickException when the API reports an error.
    key = get_api_key()

    api = shodan.Shodan(key)
    try:
        click.echo(api.tools.myip())
    except shodan.APIError as e:
        raise click.ClickException(e.value)


@main.command()
@click.option('--wait', help='How long to wait for results to come back. If this is set to "0" or below return immediately.', default=30, type=int)
@click.argument('netblocks', metavar='<ip address>', nargs=-1)
def scan(wait, netblocks):
    """Scan an IP/ netblock using Shodan."""
    # wait:      passed as the timeout of the alert stream; with a value of
    #            0 or below the command returns right after submitting.
    # netblocks: tuple of IPs/netblocks given on the command line.
    # Raises click.ClickException when the API reports an error or when
    # reading the result stream fails.
    key = get_api_key()
    api = shodan.Shodan(key)

    try:
        # Setup an alert to wait for responses
        alert = api.create_alert('Scan', netblocks)

        # Submit the scan
        try:
            scan_info = api.scan(netblocks)
        except shodan.APIError:
            # The scan was rejected, so don't leave the alert behind
            api.delete_alert(alert['id'])
            raise

        click.echo(click.style('Success (%s)! ' % scan_info['id'], fg='green') + '%s host(s) submitted (%s scan credits remaining)' % (scan_info['count'], scan_info['credits_left']))

        # NOTE(review): when not waiting, the alert created above is kept
        # (this matches the previous behaviour) -- confirm that is intended.
        if wait > 0:
            click.echo('Waiting for results...')

            try:
                hosts = _collect_scan_results(api, scan_info['id'], wait)
            finally:
                # Cleanup, also when waiting was aborted by an error
                click.echo('Cleaning up...')
                api.delete_alert(alert['id'])

            if hosts:
                _print_scan_summary(hosts)
            else:
                click.echo('No open ports found or the host has been recently crawled and cant get scanned again so soon.')
    except shodan.APIError as e:
        raise click.ClickException(e.value)


def _collect_scan_results(api, scan_id, wait):
    """Read banners from the alert stream until the scan reports 'DONE'.

    Returns a dict that maps each IP (``ip_str``) to the list of banners
    received for it.
    """
    hosts = collections.defaultdict(list)
    done = False
    while not done:
        try:
            for banner in api.stream.alert(timeout=wait):
                click.echo('Open: {0}:{1}'.format(banner['ip_str'], banner['port']))
                hosts[banner['ip_str']].append(banner)
        except shodan.APIError:
            # The stream stopped (presumably the timeout elapsed); whether
            # we are finished is decided by the status check below
            pass
        except Exception as e:
            raise click.ClickException(repr(e))

        # Poll the scan status every time the stream stops, whether or not
        # it raised, so the loop cannot keep reconnecting forever
        done = api.scan_status(scan_id)['status'] == 'DONE'
    return hosts


def _print_scan_summary(hosts):
    """Print location, organization, OS and open ports for every host."""
    click.echo('')
    click.echo('Summary')
    click.echo('-------')
    click.echo('')

    for ip in hosts:
        click.echo(click.style(ip, fg='cyan'))

        # The general host information is taken from the last banner received
        host = hosts[ip][-1]
        location = host.get('location') or {}
        if location.get('country_name'):
            click.echo('Country: {0}'.format(location['country_name']))

            if location.get('city'):
                click.echo('City: {0}'.format(location['city']))
        if host.get('org'):
            click.echo('Organization: {0}'.format(host['org']))
        if host.get('os'):
            click.echo('Operating system: {0}'.format(host['os']))
        click.echo('')

        # Print all the open ports:
        for banner in hosts[ip]:
            click.echo('Port: {0}'.format(click.style(str(banner['port']), fg='yellow')))
            click.echo(banner['data'].strip())
            click.echo('')


@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to show in the search results.', default='ip_str,port,hostnames,data')
@click.option('--limit', help='The number of search results that should be returned. Maximum: 1000', default=100, type=int)
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.argument('query', metavar='<search query>', nargs=-1)
def search(color, fields, limit, separator, query):
    """Search the Shodan database"""
    # color:     colourise known properties (see COLORIZE_FIELDS).
    # fields:    comma-separated names of the banner properties to print.
    # limit:     number of results to request (at most 1000).
    # separator: text written after every property of a row.
    # query:     tuple of command-line words that form the search query.
    # Raises click.ClickException for invalid arguments or an API error.
    key = get_api_key()

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    # For now we only allow up to 1000 results at a time
    if limit > 1000:
        raise click.ClickException('Too many results requested, maximum is 1,000')

    # Strip out any whitespace in the fields and turn them into an array
    fields = [item.strip() for item in fields.split(',')]

    # split() always returns at least one (possibly empty) item, so look at
    # the content instead of the length
    if not any(fields):
        raise click.ClickException('Please define at least one property to show')

    # Perform the search
    api = shodan.Shodan(key)
    try:
        results = api.search(query, limit=limit)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # We buffer the entire output so we can use click's pager functionality
    lines = []
    for banner in results['matches']:
        row = ''

        # Loop over all the fields and print the banner as a row
        for field in fields:
            tmp = ''
            value = banner.get(field)
            if value:
                # If the field is an array then merge it together
                if isinstance(value, list):
                    tmp = ARRAY_SEPARATOR.join(value)
                elif hasattr(value, 'encode'):
                    # Text: make it safe to print on a single line
                    tmp = escape_data(value)
                else:
                    # Numbers, booleans, nested objects, ...
                    tmp = str(value)

                # Colorize certain fields if the user wants it
                if color:
                    tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                # Add the field information to the row
                row += tmp
            row += separator

        lines.append(row + '\n')
    click.echo_via_pager(''.join(lines))


@main.command()
@click.option('--limit', help='The number of results to return.', default=10, type=int)
@click.option('--facets', help='List of facets to get statistics for.', default='country,org')
@click.argument('query', metavar='<search query>', nargs=-1)
def stats(limit, facets, query):
    """Provide summary information about a search query"""
    # limit:  number of values requested per facet.
    # facets: comma-separated facet names.
    # query:  tuple of command-line words that form the search query.
    # Raises click.ClickException for an empty query or an API error.

    # Setup Shodan
    key = get_api_key()
    api = shodan.Shodan(key)

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    # Ask for the top `limit` values of every facet
    facets = [(facet.strip(), limit) for facet in facets.split(',')]

    # Perform the search
    try:
        results = api.count(query, facets=facets)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # Print the stats tables
    for facet in results['facets']:
        click.echo('# Top %s %s' % (limit, facet))

        for item in results['facets'][facet]:
            value = item['value']
            if hasattr(value, 'encode'):
                # Text: replace anything that isn't ASCII
                value = value.encode('ascii', 'replace').decode('ascii')
            else:
                # Not every facet value is text (e.g. port numbers)
                value = str(value)

            click.echo('  {:28s} {:12,d}'.format(value, item['count']))

        click.echo('')

@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to output.', default='ip_str,port,hostnames,data')
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.option('--limit', help='The number of results you want to download. -1 to download all the data possible.', default=-1, type=int)
@click.option('--datadir', help='Save the stream data into the specified directory as .json.gz files.', default=None, type=str)
@click.option('--ports', help='A comma-separated list of ports to grab data on.', default=None, type=str)
@click.option('--quiet', help='Disable the printing of information to the screen.', is_flag=True)
def stream(color, fields, separator, limit, datadir, ports, quiet):
    """Stream data in real-time."""
    # color:     colourise known properties (see COLORIZE_FIELDS).
    # fields:    comma-separated names of the banner properties to print.
    # separator: text written after every property of a row.
    # limit:     stop after this many banners; 0 or below means no limit.
    # datadir:   when set, banners are appended to <datadir>/<date>.json.gz.
    # ports:     comma-separated ports; when set only those are streamed.
    # quiet:     don't print the banners to the screen.
    # Raises click.ClickException for invalid arguments or an API error.

    # Setup the Shodan API
    key = get_api_key()
    api = shodan.Shodan(key)

    # Strip out any whitespace in the fields and turn them into an array
    fields = [item.strip() for item in fields.split(',')]

    # split() always returns at least one (possibly empty) item, so look at
    # the content instead of the length
    if not any(fields):
        raise click.ClickException('Please define at least one property to show')

    # Turn the list of ports into integers
    if ports:
        try:
            ports = [int(item.strip()) for item in ports.split(',')]
        except ValueError:
            raise click.ClickException('Invalid list of ports')

    counter = 0
    finished = False
    last_time = timestr()
    fout = None

    if datadir:
        fout = open_file(datadir, last_time)

    try:
        while not finished:
            try:
                # Subscribe inside the loop so that every retry opens a new
                # stream; decide which one based on whether or not ports
                # were selected
                if ports:
                    banners = api.stream.ports(ports)
                else:
                    banners = api.stream.banners()

                for banner in banners:
                    # Limit the number of results to output
                    if limit > 0:
                        counter += 1

                        if counter > limit:
                            finished = True
                            break

                    # Write the data to the file
                    if datadir:
                        # Start a new file whenever the (UTC) date changes
                        cur_time = timestr()
                        if cur_time != last_time:
                            last_time = cur_time
                            fout.close()
                            fout = open_file(datadir, last_time)
                        fout.write(simplejson.dumps(banner) + '\n')

                    # Print the banner information to stdout
                    if not quiet:
                        row = ''

                        # Loop over all the fields and print the banner as a row
                        for field in fields:
                            tmp = ''
                            value = banner.get(field)
                            if value:
                                # If the field is an array then merge it together
                                if isinstance(value, list):
                                    tmp = ARRAY_SEPARATOR.join(value)
                                elif hasattr(value, 'encode'):
                                    # Text: make it safe to print on a single line
                                    tmp = escape_data(value)
                                else:
                                    # Numbers, booleans, nested objects, ...
                                    tmp = str(value)

                                # Colorize certain fields if the user wants it
                                if color:
                                    tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                                # Add the field information to the row
                                row += tmp
                            row += separator

                        click.echo(row)
            except KeyboardInterrupt:
                finished = True
            except shodan.APIError as e:
                # Errors reported by the API are shown to the user instead
                # of being retried silently forever
                raise click.ClickException(e.value)
            except Exception:
                # For other errors lets just wait a few seconds and try to reconnect again
                time.sleep(2)
            else:
                # The stream ended without an error; pause before
                # reconnecting so we don't reconnect in a tight loop
                if not finished:
                    time.sleep(2)
    finally:
        # Make sure the data written so far is flushed to disk
        if fout:
            fout.close()

# Run the CLI only when the file is executed as a script, not when imported
if __name__ == '__main__':
    main()
