#!/usr/bin/python
"""
Shodan CLI

Note: Always run "shodan init <api key>" before trying to execute any other command!

A simple interface to search Shodan, download data and parse compressed JSON files.
The following commands are currently supported:

    alert
    count
    download
    host
    info
    init
    myip
    parse
    scan
    search
    stats
    stream

"""

import click
import collections
import datetime
import gzip
import itertools
import os
import os.path
import shodan
import simplejson
import socket
import sys
import threading
import time

# Constants
SHODAN_CONFIG_DIR = '~/.shodan/'    # per-user directory holding the "api_key" file
ARRAY_SEPARATOR = ';'               # presumably meant for joining list fields; commands below hard-code ';'

# Foreground color used for each banner property when colorized output is on
COLORIZE_FIELDS = dict(
    ip_str='green',
    port='yellow',
    data='white',
    hostnames='magenta',
    org='cyan',
    vulns='red',
)

# Make "-h" work like "--help"
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}

# Utility methods
def get_api_key():
    """Return the Shodan API key saved by "shodan init".

    Raises:
        click.ClickException: if the key file has not been created yet.
    """
    shodan_dir = os.path.expanduser(SHODAN_CONFIG_DIR)
    keyfile = os.path.join(shodan_dir, 'api_key')

    # If the file doesn't yet exist let the user know that they need to
    # initialize the shodan cli
    if not os.path.exists(keyfile):
        raise click.ClickException('Please run "shodan init <api key>" before using this command')

    # Keep the key private: owner read/write only (0o600 works on Python 2.6+ and 3)
    os.chmod(keyfile, 0o600)

    with open(keyfile, 'r') as fin:
        return fin.read().strip()


def escape_data(args):
    """Return *args* as a single-line ASCII string suitable for tabular output.

    Non-ASCII characters are replaced with '?', and newline, carriage return
    and tab characters are escaped as the two-character sequences \\n, \\r
    and \\t so that a value always fits on one output row.

    Values without an ``encode`` method (e.g. dicts or booleans taken from a
    banner) are converted with ``str()`` first instead of raising
    AttributeError.
    """
    if not hasattr(args, 'encode'):
        args = str(args)

    data = args.encode('ascii', 'replace')
    # On Python 3 encode() returns bytes; decode back so str.replace() below works
    if not isinstance(data, str):
        data = data.decode('ascii')

    return data.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')

def timestr():
    """Return the current UTC date formatted as YYYY-MM-DD."""
    today = datetime.datetime.utcnow()
    return today.strftime('%Y-%m-%d')

def open_file(directory, timestr):
    """Open <directory>/<timestr>.json.gz for appending, using gzip compression level 1."""
    path = '{0}/{1}.json.gz'.format(directory, timestr)
    return gzip.open(path, 'a', 1)


@click.group(context_settings=CONTEXT_SETTINGS)
def main():
    # Root "shodan" command group: the commands and groups below register on it
    # through @main.command() / @main.group(). CONTEXT_SETTINGS makes "-h" an
    # alias of "--help". (Click would show a docstring here as the group's help.)
    pass


@main.command()
@click.argument('key', metavar='<api key>')
def init(key):
    """Initialize the Shodan command-line"""
    # Create the directory if necessary
    shodan_dir = os.path.expanduser(SHODAN_CONFIG_DIR)
    if not os.path.isdir(shodan_dir):
        try:
            os.mkdir(shodan_dir)
        except OSError:
            raise click.ClickException('Unable to create directory to store the Shodan API key (%s)' % shodan_dir)

    # Make sure it's a valid API key: an APIError from the account lookup is
    # reported as an invalid key
    key = key.strip()
    try:
        shodan.Shodan(key).info()
    except shodan.APIError:
        raise click.ClickException('Invalid API key')

    # Store the API key in the user's directory. Create the file with
    # owner-only permissions so the key is never readable by other users,
    # not even between writing it and fixing up its mode.
    keyfile = os.path.join(shodan_dir, 'api_key')
    fd = os.open(keyfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as fout:
        fout.write(key)

    # The mode given to os.open() only applies when the file is newly
    # created; tighten a pre-existing key file as well
    os.chmod(keyfile, 0o600)

    click.echo(click.style('Successfully initialized', fg='green'))


@main.group()
def alert():
    """Manage the network alerts for your account"""
    # Click shows the docstring above as this group's help in "shodan --help";
    # the subcommands (clear, list, remove) are registered via @alert.command()


@alert.command(name='clear')
def alert_clear():
    """Remove all alerts"""
    api = shodan.Shodan(get_api_key())

    # Walk every alert on the account and delete it, reporting each one
    try:
        for entry in api.alerts():
            click.echo('Removing {} ({})'.format(entry['name'], entry['id']))
            api.delete_alert(entry['id'])
    except shodan.APIError as err:
        raise click.ClickException(err.value)

    click.echo("Alerts deleted")

@alert.command(name='list')
@click.option('--expired', help='Whether or not to show expired alerts.', default=True, type=bool)
def alert_list(expired):
    """List all the active alerts"""
    api = shodan.Shodan(get_api_key())
    try:
        results = api.alerts(include_expired=expired)
    except shodan.APIError as err:
        raise click.ClickException(err.value)

    if len(results) == 0:
        click.echo("You haven't created any alerts yet.")
        return

    click.echo('# {:14} {:<21} {:<15s}'.format('Alert ID', 'Name', 'IP/ Network'))
    for entry in results:
        # The values are styled (ANSI escape codes) before padding, so these
        # widths also count those invisible characters
        alert_id = click.style(entry['id'], fg='yellow')
        name = click.style(entry['name'], fg='cyan')
        networks = click.style(', '.join(entry['filters']['ip']), fg='white')
        click.echo('{:16} {:<30} {:<35} '.format(alert_id, name, networks), nl=False)

        # Finish the row, tagging expired alerts
        click.echo(click.style('expired', fg='red') if entry.get('expired') else '')


@alert.command(name='remove')
@click.argument('alert_id', metavar='<alert ID>')
def alert_remove(alert_id):
    """Remove the specified alert"""
    api = shodan.Shodan(get_api_key())

    # Delete the alert; API failures become a user-facing CLI error
    try:
        api.delete_alert(alert_id)
    except shodan.APIError as err:
        raise click.ClickException(err.value)

    click.echo("Alert deleted")


@main.command()
@click.argument('query', metavar='<search query>', nargs=-1)
def count(query):
    """Returns the number of results for a search"""
    key = get_api_key()

    # The query arrives as a tuple of words; rebuild it into one string and
    # refuse to run an empty search
    search_query = ' '.join(query).strip()
    if not search_query:
        raise click.ClickException('Empty search query')

    api = shodan.Shodan(key)
    try:
        total = api.count(search_query)['total']
    except shodan.APIError as err:
        raise click.ClickException(err.value)

    click.echo(total)


@main.command()
@click.option('--limit', help='The number of results you want to download. -1 to download all the data possible.', default=1000, type=int)
@click.argument('filename', metavar='<filename>')
@click.argument('query', metavar='<search query>', nargs=-1)
def download(limit, filename, query):
    """Download search results and save them in a compressed JSON file."""
    key = get_api_key()

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    filename = filename.strip()
    if filename == '':
        raise click.ClickException('Empty filename')

    # Add the appropriate extension if it's not there atm
    if not filename.endswith('.json.gz'):
        filename += '.json.gz'

    api = shodan.Shodan(key)

    # Look up the number of results and the remaining query credits. API
    # errors (e.g. an invalid key) are reported verbatim; any other failure is
    # treated as the API being unreachable.
    try:
        total = api.count(query)['total']
        account = api.info()
    except shodan.APIError as e:
        raise click.ClickException(e.value)
    except Exception:
        raise click.ClickException('The Shodan API is unresponsive at the moment, please try again later.')

    # Print some summary information about the download request
    click.echo('Search query:\t\t\t%s' % query)
    click.echo('Total number of results:\t%s' % total)
    click.echo('Query credits left:\t\t%s' % account['unlocked_left'])
    click.echo('Output file:\t\t\t%s' % filename)

    # A negative limit (documented as -1) means "download everything"; never
    # request more results than the search has
    if limit < 0 or limit > total:
        limit = total

    saved = 0
    with gzip.open(filename, 'w') as fout:
        try:
            # The limit check below runs after a banner is written, so only
            # start the cursor when at least one result is wanted
            if limit > 0:
                cursor = api.search_cursor(query)
                with click.progressbar(cursor, length=limit) as bar:
                    for banner in bar:
                        fout.write(simplejson.dumps(banner) + '\n')
                        saved += 1

                        if saved >= limit:
                            break
        except KeyboardInterrupt:
            # Let the user stop a long download early; keep what was saved
            pass
        except Exception as e:
            # Keep the partial download but tell the user why it stopped
            click.echo(click.style('Error while downloading results: {}'.format(e), fg='red'))

    # Let the user know we're done
    if saved < limit:
        click.echo(click.style('Notice: fewer results were saved than requested', 'yellow'))
    click.echo(click.style('Saved %s results into file %s' % (saved, filename), 'green'))


@main.command()
@click.option('--format', help='The output format for the host information. Possible values are: pretty, csv, tsv. (placeholder)', default='pretty', type=str)
@click.argument('ip', metavar='<ip address>')
def host(format, ip):
    """View all available information for an IP address"""
    # NOTE: --format is accepted but not used yet; its help text marks it as a placeholder
    key = get_api_key()
    api = shodan.Shodan(key)

    try:
        host_info = api.host(ip)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # General info
    click.echo(click.style(ip, fg='green'))
    if len(host_info['hostnames']) > 0:
        click.echo('{:25s}{}'.format('Hostnames:', ';'.join(host_info['hostnames'])))

    # Optional properties, shown only when present and non-empty
    for field, label in (('city', 'City:'),
                         ('country_name', 'Country:'),
                         ('os', 'Operating System:'),
                         ('org', 'Organization:')):
        if host_info.get(field):
            click.echo('{:25s}{}'.format(label, host_info[field]))

    click.echo('{:25s}{}'.format('Number of open ports:', len(host_info['ports'])))

    # Output the vulnerabilities the host has: entries starting with '!' are
    # skipped and CVE-2014-0160 is shown by its common name
    vulns = []
    for vuln in host_info.get('vulns') or []:
        if vuln.startswith('!'):
            continue
        name = 'Heartbleed' if vuln.upper() == 'CVE-2014-0160' else vuln
        vulns.append(click.style(name, fg='red'))

    if vulns:
        click.echo('{:25s}'.format('Vulnerabilities:'), nl=False)
        for vuln in vulns:
            click.echo(vuln + '\t', nl=False)
        click.echo('')

    click.echo('')

    click.echo('Ports:')
    for banner in sorted(host_info['data'], key=lambda k: k['port']):
        product = banner.get('product', '')
        version = '({})'.format(banner['version']) if 'version' in banner else ''

        click.echo(click.style('{:>7d} '.format(banner['port']), fg='cyan'), nl=False)
        click.echo('{} {}'.format(product, version))

        # Show optional ssl info (tolerating a missing or null "ssl" entry)
        ssl = banner.get('ssl') or {}

        # Only list SSL versions that were successfully tested (the others start with '-')
        ssl_versions = [v for v in sorted(ssl.get('versions') or []) if not v.startswith('-')]
        if ssl_versions:
            click.echo('\t|-- SSL Versions: {}'.format(', '.join(ssl_versions)))

        dhparams = ssl.get('dhparams')
        if dhparams:
            click.echo('\t|-- Diffie-Hellman Parameters:')
            click.echo('\t\t{:15s}{}\n\t\t{:15s}{}'.format('Bits:', dhparams['bits'], 'Generator:', dhparams['generator']))
            if 'fingerprint' in dhparams:
                click.echo('\t\t{:15s}{}'.format('Fingerprint:', dhparams['fingerprint']))


@main.command()
def info():
    """Shows general information about your account"""
    key = get_api_key()
    api = shodan.Shodan(key)
    try:
        results = api.info()
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # Print each line separately so no trailing whitespace-only line is emitted
    click.echo('Query credits available: {0}'.format(results['query_credits']))
    click.echo('Scan credits available: {0}'.format(results['scan_credits']))


@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to output.', default='ip_str,port,hostnames,data')
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.argument('filename', metavar='<filename>', type=click.Path(exists=True))
def parse(color, fields, separator, filename):
    """Extract information out of compressed JSON files."""
    # Make sure it's some sort of json file
    if not filename.endswith('.json.gz') and not filename.endswith('.json'):
        raise click.ClickException('Invalid file, please make sure it is a valid Shodan JSON file')

    # Strip whitespace from the field names and drop empty entries, so an
    # empty --fields value is actually caught by the check below
    fields = [item.strip() for item in fields.split(',') if item.strip()]

    if len(fields) == 0:
        raise click.ClickException('Please define at least one property to show')

    # Pick the opener depending on the filetype; the "with" block guarantees
    # the handle is closed even if a line fails to decode
    opener = gzip.open if filename.endswith('.gz') else open

    with opener(filename, 'r') as fin:
        for line in fin:
            # Skip blank lines (e.g. a trailing empty line) instead of failing to decode them
            if not line.strip():
                continue

            # Convert the JSON into a native Python object
            banner = simplejson.loads(line)
            row = ''

            # Loop over all the fields and print the banner as a row
            for field in fields:
                tmp = ''
                if field in banner and banner[field]:
                    field_type = type(banner[field])

                    # If the field is an array then merge it together
                    if field_type == list:
                        tmp = ';'.join(banner[field])
                    elif field_type in [int, float]:
                        tmp = str(banner[field])
                    else:
                        tmp = escape_data(banner[field])

                    # Colorize certain fields if the user wants it
                    if color:
                        tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                    # Add the field information to the row
                    row += tmp
                row += separator

            click.echo(row)


@main.command()
def myip():
    """Print your external IP address"""
    api = shodan.Shodan(get_api_key())

    # Ask the API which address our requests come from
    try:
        address = api.tools.myip()
    except shodan.APIError as err:
        raise click.ClickException(err.value)

    click.echo(address)


@main.group()
def scan():
    """Scan an IP/ netblock or the Internet using Shodan."""
    # Click shows the docstring above as this group's help in "shodan --help";
    # the subcommands (internet, protocols, submit) are registered via @scan.command()


@scan.command(name='internet')
@click.option('--quiet', help='Disable the printing of information to the screen.', default=False, is_flag=True)
@click.argument('port', type=int)
@click.argument('protocol', type=str)
def scan_internet(quiet, port, protocol):
    """Scan the Internet for a specific port and protocol using the Shodan infrastructure."""
    key = get_api_key()
    api = shodan.Shodan(key)

    try:
        # Submit the request to Shodan
        click.echo('Submitting Internet scan to Shodan...', nl=False)
        job = api.scan_internet(port, protocol)
        click.echo('Done')

        # A port that Shodan already crawls has no clear "finished" point, so
        # hand control straight back and let the user follow the stream
        if port in api.ports():
            click.echo('The requested port is already indexed by Shodan. A new scan for the port has been launched, please subscribe to the real-time stream for results.')
            return

        filename = '{0}-{1}.json.gz'.format(port, protocol)
        found = 0
        with gzip.open(filename, 'w') as fout:
            click.echo('Saving results to file: {0}'.format(filename))
            click.echo('Waiting for data, please stand by...')

            finished = False
            while not finished:
                try:
                    for banner in api.stream.ports([port], timeout=30):
                        found += 1
                        fout.write(simplejson.dumps(banner) + '\n')

                        if quiet:
                            continue

                        click.echo('{0:<40} {1:<20} {2}'.format(
                            click.style(banner['ip_str'], fg=COLORIZE_FIELDS['ip_str']),
                            click.style(str(banner['port']), fg=COLORIZE_FIELDS['port']),
                            ';'.join(banner['hostnames'])
                        ))
                except (shodan.APIError, socket.timeout):
                    # The stream stopped (an error, or the 30s timeout expired):
                    # finish once the crawlers report the scan as done,
                    # otherwise reconnect and keep listening
                    job = api.scan_status(job['id'])
                    finished = job['status'] == 'DONE'
                except Exception as err:
                    raise click.ClickException(repr(err))

        click.echo('Scan finished: {0} devices found'.format(found))
    except shodan.APIError as err:
        raise click.ClickException(err.value)


@scan.command(name='protocols')
def scan_protocols():
    """List the protocols that you can scan with using Shodan."""
    key = get_api_key()
    api = shodan.Shodan(key)
    try:
        protocols = api.protocols()
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # Sort by name so the listing is stable between runs; items() works on
    # both Python 2 and 3, unlike iteritems()
    for name, description in sorted(protocols.items()):
        click.echo(click.style('{0:<30}'.format(name), fg='cyan') + description)


@scan.command(name='submit')
@click.option('--wait', help='How long to wait for results to come back. If this is set to "0" or below return immediately.', default=20, type=int)
@click.option('--filename', help='Save the results in the given file.', default='', type=str)
@click.argument('netblocks', metavar='<ip address>', nargs=-1)
def scan_submit(wait, filename, netblocks):
    """Scan an IP/ netblock using Shodan."""
    key = get_api_key()
    api = shodan.Shodan(key)
    alert = None  # alert created to receive the scan results; deleted in "finally"
    fout = None   # optional gzip output file; closed in "finally"

    # Submit the IPs for scanning
    try:
        # Submit the scan
        scan = api.scan(netblocks)

        now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')

        click.echo('')
        click.echo('Starting Shodan scan at {} ({} scan credits left)'.format(now, scan['credits_left']))

        # Return immediately
        if wait <= 0:
            click.echo('Exiting now, not waiting for results. Use the API or website to retrieve the results of the scan.')
            return

        # Setup an alert to wait for responses
        alert = api.create_alert('Scan: {}'.format(', '.join(netblocks)), netblocks)

        # Create the output file if necessary
        filename = filename.strip()
        if filename != '':
            # Add the appropriate extension if it's not there atm
            if not filename.endswith('.json.gz'):
                filename += '.json.gz'
            fout = gzip.open(filename, 'w')

        # Start a spinner
        finished_event = threading.Event()
        progress_bar_thread = threading.Thread(target=async_spinner, args=(finished_event,))
        progress_bar_thread.start()

        # Collect the banners streamed for the alert: hosts maps
        # ip_str -> {port: banner}; seen holds "ip:port" keys already stored
        hosts = collections.defaultdict(dict)
        seen = set()
        done = False
        scan_start = time.time()
        try:
            while not done:
                try:
                    for banner in api.stream.alert(aid=alert['id'], timeout=wait):
                        ip = banner.get('ip', banner.get('ipv6', None))
                        if not ip:
                            continue

                        # Don't show duplicate banners
                        cache_key = '{}:{}'.format(ip, banner['port'])
                        if cache_key not in seen:
                            hosts[banner['ip_str']][banner['port']] = banner
                            seen.add(cache_key)

                        # If we've grabbed data for more than 60 seconds it might just be a busy network and we should move on
                        if time.time() - scan_start >= 60:
                            scan = api.scan_status(scan['id'])
                            if scan['status'] == 'DONE':
                                done = True
                                break

                except shodan.APIError:
                    # If the connection failed before the wait period elapsed, the
                    # streaming server is presumably unavailable: wait briefly and
                    # try to connect again
                    if (time.time() - scan_start) < wait:
                        time.sleep(0.5)
                        continue

                    scan = api.scan_status(scan['id'])
                    if scan['status'] == 'DONE':
                        done = True
                except socket.timeout:
                    # Timed out before the wait period elapsed: reconnect;
                    # otherwise stop waiting for results
                    if (time.time() - scan_start) < wait:
                        continue

                    done = True
                except Exception as e:
                    raise click.ClickException(repr(e))
        finally:
            # Always stop the spinner, even when an error or Ctrl-C ends the
            # wait early: the thread is not a daemon, so leaving it running
            # would keep the process from exiting
            finished_event.set()
            progress_bar_thread.join()

        def print_field(name, value):
            click.echo('  {:25s}{}'.format(name, value))

        def print_banner(banner):
            click.echo('    {:20s}'.format(click.style(str(banner['port']), fg='green') + '/' + banner['transport']), nl=False)

            if 'product' in banner:
                click.echo(banner['product'], nl=False)

                if 'version' in banner:
                    click.echo(' ({})'.format(banner['version']), nl=False)

            click.echo('')

            # Show optional ssl info
            if 'ssl' in banner:
                if 'versions' in banner['ssl']:
                    # Only print SSL versions if they were successfully tested
                    versions = [version for version in sorted(banner['ssl']['versions']) if not version.startswith('-')]
                    if len(versions) > 0:
                        click.echo('    |-- SSL Versions: {}'.format(', '.join(versions)))
                if 'dhparams' in banner['ssl']:
                    click.echo('    |-- Diffie-Hellman Parameters:')
                    click.echo('        {:15s}{}\n        {:15s}{}'.format('Bits:', banner['ssl']['dhparams']['bits'], 'Generator:', banner['ssl']['dhparams']['generator']))
                    if 'fingerprint' in banner['ssl']['dhparams']:
                        click.echo('        {:15s}{}'.format('Fingerprint:', banner['ssl']['dhparams']['fingerprint']))

        if hosts:
            # Remove the remaining spinner character
            click.echo('\b ')

            for ip in sorted(hosts):
                # Host-level properties (location, org, ...) are read from the
                # first stored banner of this host
                host = next(iter(hosts[ip].values()))

                click.echo(click.style(ip, fg='cyan'), nl=False)
                if 'hostnames' in host and host['hostnames']:
                    click.echo(' ({})'.format(', '.join(host['hostnames'])), nl=False)
                click.echo('')

                if 'location' in host and 'country_name' in host['location'] and host['location']['country_name']:
                    print_field('Country', host['location']['country_name'])

                    if 'city' in host['location'] and host['location']['city']:
                        print_field('City', host['location']['city'])
                if 'org' in host and host['org']:
                    print_field('Organization', host['org'])
                if 'os' in host and host['os']:
                    print_field('Operating System', host['os'])
                click.echo('')

                # Output the vulnerabilities the host has
                if 'vulns' in host and len(host['vulns']) > 0:
                    vulns = []
                    for vuln in host['vulns']:
                        if vuln.startswith('!'):
                            continue
                        if vuln.upper() == 'CVE-2014-0160':
                            vulns.append(click.style('Heartbleed', fg='red'))
                        else:
                            vulns.append(click.style(vuln, fg='red'))

                    if len(vulns) > 0:
                        click.echo('  {:25s}'.format('Vulnerabilities:'), nl=False)

                        for vuln in vulns:
                            click.echo(vuln + '\t', nl=False)

                        click.echo('')

                # Print all the open ports:
                click.echo('  Open Ports:')
                for port in sorted(hosts[ip]):
                    print_banner(hosts[ip][port])

                    # Save the banner in a file if necessary
                    if fout is not None:
                        fout.write(simplejson.dumps(hosts[ip][port]) + '\n')

                click.echo('')
        else:
            # Prepend a \b to remove the spinner
            click.echo('\bNo open ports found or the host has been recently crawled and cant get scanned again so soon.')
    except shodan.APIError as e:
        raise click.ClickException(e.value)
    finally:
        # Close the output file so the gzip stream is properly finished
        if fout is not None:
            fout.close()

        # Remove any alert
        if alert:
            api.delete_alert(alert['id'])


@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to show in the search results.', default='ip_str,port,hostnames,data')
@click.option('--limit', help='The number of search results that should be returned. Maximum: 1000', default=100, type=int)
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.argument('query', metavar='<search query>', nargs=-1)
def search(color, fields, limit, separator, query):
    """Search the Shodan database"""
    key = get_api_key()

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    # For now we only allow up to 1000 results at a time
    if limit > 1000:
        raise click.ClickException('Too many results requested, maximum is 1,000')

    # Strip whitespace from the field names and drop empty entries, so an
    # empty --fields value is actually caught by the check below
    fields = [item.strip() for item in fields.split(',') if item.strip()]

    if len(fields) == 0:
        raise click.ClickException('Please define at least one property to show')

    # Perform the search
    api = shodan.Shodan(key)
    try:
        results = api.search(query, limit=limit)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # We buffer the entire output so we can use click's pager functionality.
    # Rows are collected in a list and joined once instead of growing a string.
    lines = []
    for banner in results['matches']:
        row = ''

        # Loop over all the fields and print the banner as a row
        for field in fields:
            tmp = ''
            if field in banner and banner[field]:
                field_type = type(banner[field])

                # If the field is an array then merge it together
                if field_type == list:
                    tmp = ';'.join(banner[field])
                elif field_type in [int, float]:
                    tmp = str(banner[field])
                else:
                    tmp = escape_data(banner[field])

                # Colorize certain fields if the user wants it
                if color:
                    tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                # Add the field information to the row
                row += tmp
            row += separator

        lines.append(row + '\n')

    click.echo_via_pager(''.join(lines))


@main.command()
@click.option('--limit', help='The number of results to return.', default=10, type=int)
@click.option('--facets', help='List of facets to get statistics for.', default='country,org')
@click.argument('query', metavar='<search query>', nargs=-1)
def stats(limit, facets, query):
    """Provide summary information about a search query"""
    # Setup Shodan
    key = get_api_key()
    api = shodan.Shodan(key)

    # Create the query string out of the provided tuple
    query = ' '.join(query).strip()

    # Make sure the user didn't supply an empty string
    if query == '':
        raise click.ClickException('Empty search query')

    # Request every facet as a (name, limit) pair, ignoring blank entries
    facets = [(facet.strip(), limit) for facet in facets.split(',') if facet.strip()]
    if not facets:
        raise click.ClickException('Please define at least one facet')

    # Perform the search
    try:
        results = api.count(query, facets=facets)
    except shodan.APIError as e:
        raise click.ClickException(e.value)

    # Print the stats tables
    for facet in results['facets']:
        click.echo('Top {} Results for Facet: {}'.format(limit, facet))

        for item in results['facets'][facet]:
            value = item['value']
            # NOTE(review): some facets presumably return numeric values (e.g. port) - confirm.
            # Text is squashed to ASCII and anything else is stringified, so
            # the '{:28s}' format below always receives a str.
            if hasattr(value, 'encode'):
                value = value.encode('ascii', 'replace')
                if not isinstance(value, str):
                    # Python 3: encode() returns bytes
                    value = value.decode('ascii')
            else:
                value = str(value)

            click.echo(click.style('{:28s}'.format(value), fg='cyan'), nl=False)
            click.echo(click.style('{:12,d}'.format(item['count']), fg='green'))

        click.echo('')


@main.command()
@click.option('--color/--no-color', default=True)
@click.option('--fields', help='List of properties to output.', default='ip_str,port,hostnames,data')
@click.option('--separator', help='The separator between the properties of the search results.', default='\t')
@click.option('--limit', help='The number of results you want to download. -1 to download all the data possible.', default=-1, type=int)
@click.option('--datadir', help='Save the stream data into the specified directory as .json.gz files.', default=None, type=str)
@click.option('--ports', help='A comma-separated list of ports to grab data on.', default=None, type=str)
@click.option('--quiet', help='Disable the printing of information to the screen.', is_flag=True)
@click.option('--streamer', help='Specify a custom Shodan stream server to use for grabbing data.', default='https://stream.shodan.io', type=str)
def stream(color, fields, separator, limit, datadir, ports, quiet, streamer):
    """Stream data in real-time."""
    # Setup the Shodan API
    key = get_api_key()
    api = shodan.Shodan(key)

    # Temporarily change the baseurl
    api.stream.base_url = streamer

    # Strip whitespace from the field names and drop empty entries, so an
    # empty --fields value is actually caught by the check below
    fields = [item.strip() for item in fields.split(',') if item.strip()]

    if len(fields) == 0:
        raise click.ClickException('Please define at least one property to show')

    # Turn the list of ports into integers
    if ports:
        try:
            ports = [int(item.strip()) for item in ports.split(',')]
        except ValueError:
            raise click.ClickException('Invalid list of ports')

    counter = 0
    stop = False
    last_time = timestr()
    fout = None

    # With --datadir, banners go into one .json.gz file per UTC date (see timestr/open_file)
    if datadir:
        fout = open_file(datadir, last_time)

    try:
        while not stop:
            # Open a fresh stream on every pass: a generator that has finished
            # or raised yields nothing when iterated again, so reusing one
            # would loop forever without ever reconnecting
            if ports:
                banners = api.stream.ports(ports)
            else:
                banners = api.stream.banners()

            try:
                for banner in banners:
                    # Limit the number of results to output (a limit <= 0 means no limit)
                    if limit > 0:
                        counter += 1

                        if counter > limit:
                            stop = True
                            break

                    # Write the data to the file, switching files when the date changes
                    if fout is not None:
                        cur_time = timestr()
                        if cur_time != last_time:
                            last_time = cur_time
                            fout.close()
                            fout = open_file(datadir, last_time)
                        fout.write(simplejson.dumps(banner) + '\n')

                    # Print the banner information to stdout
                    if not quiet:
                        row = ''

                        # Loop over all the fields and print the banner as a row
                        for field in fields:
                            tmp = ''
                            if field in banner and banner[field]:
                                field_type = type(banner[field])

                                # If the field is an array then merge it together
                                if field_type == list:
                                    tmp = ';'.join(banner[field])
                                elif field_type in [int, float]:
                                    tmp = str(banner[field])
                                else:
                                    tmp = escape_data(banner[field])

                                # Colorize certain fields if the user wants it
                                if color:
                                    tmp = click.style(tmp, fg=COLORIZE_FIELDS.get(field, 'white'))

                                # Add the field information to the row
                                row += tmp
                            row += separator

                        click.echo(row)
            except Exception:
                # For other errors lets just wait a few seconds and try to reconnect again
                time.sleep(2)
    except KeyboardInterrupt:
        # Ctrl-C (while streaming or while waiting to reconnect) ends the command cleanly
        pass
    finally:
        # Close the current data file so the gzip stream is properly finished
        if fout is not None:
            fout.close()


def async_spinner(finished):
    """Draw a rotating -/|\\ spinner on stdout until the *finished* event is set.

    Each frame is preceded by a backspace so it overwrites the previous
    character; a new frame is drawn at most every 0.2 seconds, and the loop
    exits as soon as *finished* is set.
    """
    spinner = itertools.cycle(['-', '/', '|', '\\'])
    while not finished.is_set():
        # next() works on Python 2.6+ and 3, unlike the Python 2-only .next() method
        sys.stdout.write('\b{}'.format(next(spinner)))
        sys.stdout.flush()
        finished.wait(0.2)

# Run the click command group when this file is executed directly as a script
if __name__ == '__main__':
    main()
