#!/usr/bin/env python

from __future__ import print_function

import argparse
import boto3
from six.moves.urllib import parse
from six.moves import queue
import threading

import sys
import time
from botocore.exceptions import EndpointConnectionError, ClientError


def safe_print(to_print, **kwargs):
    """Print ``to_print`` without garbling the progress line on stderr.

    Output is written while holding ``print_lock``. Before anything is
    printed, the status line previously drawn on stderr is erased with
    "\\r\\033[K" (carriage return + erase to end of line).

    Keyword args:
        status_refresh: when true (the default), redraw the progress line
            via print_status() once the text has been printed.
        Any other keyword is forwarded unchanged to print().
    """
    refresh = kwargs.pop('status_refresh', True)
    erase_status = "\r\033[K"

    with print_lock:
        if kwargs.get('file') == sys.stderr:
            # Same stream as the status line: erase it inline.
            print(erase_status + to_print, **kwargs)
        else:
            # Different stream: erase the status on stderr first.
            print(erase_status, end="", file=sys.stderr)
            print(to_print, **kwargs)

    if refresh:
        print_status()


def print_status(**kwargs):
    """Draw the one-line progress summary on stderr.

    Nothing is drawn when --disable_progress is set, or once the matched
    record count exceeds a non-zero --limit. The only keyword honoured is
    ``end`` (default ""), the line terminator passed to safe_print; any
    other keyword is ignored.
    """
    line_end = kwargs.pop('end', "")

    suppressed = args.disable_progress or (
        args.limit != 0
        and counters['matches_count'].value() > args.limit)
    if suppressed:
        return

    status_line = ('Files processed: {}/{}  Records matched: {}  '
                   'Failed requests: {}'
                   .format(counters['files_processed'].value(),
                           counters['total_files'].value(),
                           counters['matches_count'].value(),
                           counters['failed_requests'].value()))
    safe_print(status_line, end=line_end, file=sys.stderr,
               status_refresh=False)


class S3ListThread(threading.Thread):
    """Background lister feeding object keys under a prefix into a queue.

    Pages through ``list_objects_v2`` for ``the_bucket``/``the_prefix``.
    Every key with a non-zero size is put on ``files_queue``, and the
    ``total_files`` counter is incremented once per key. Listing stops early
    when the prefix is empty or a non-zero --limit has been reached.
    """

    def __init__(self, the_bucket, the_prefix, files_queue):
        threading.Thread.__init__(self)
        self.the_bucket = the_bucket
        self.the_prefix = the_prefix
        self.files_queue = files_queue

    def run(self):
        pages = s3.get_paginator('list_objects_v2').paginate(
            Bucket=self.the_bucket,
            Prefix=self.the_prefix)

        print_status()
        for page in pages:
            if page['KeyCount'] == 0:
                safe_print('No files found for prefix {}'
                           .format(self.the_prefix),
                           file=sys.stderr, status_refresh=False)
                return

            limit_hit = (args.limit != 0
                         and counters['matches_count'].value() >= args.limit)
            if limit_hit or 'Contents' not in page:
                # Either no more list results are needed, or there is
                # nothing to enqueue.
                return

            # 0-byte files are skipped: the boto3 deserializer throws
            # exceptions for them.
            sized_keys = (entry['Key'] for entry in page['Contents']
                          if entry['Size'] != 0)
            for key in sized_keys:
                self.files_queue.put(key)
                counters['total_files'].inc()

            print_status()


class ScanOneKey(threading.Thread):
    """Run one S3 Select query against a single object and emit its matches.

    Matching records are printed to stdout, or only counted with --count,
    and the shared counters are updated. ``select()`` waits until
    ``files_processed`` equals ``total_files``. Every path that does not
    stop because --limit was reached therefore ends by incrementing
    ``files_processed``, including keys that fail permanently. Otherwise the
    whole program would wait forever.
    """

    # Error codes treated as transient and retried. Any other 4xx error
    # (bad SQL in --where, missing key, access denied, malformed input, ...)
    # would fail the same way on every attempt.
    _RETRIABLE_ERROR_CODES = frozenset([
        'SlowDown', 'Throttling', 'ThrottlingException', 'RequestTimeout',
        'RequestLimitExceeded', 'InternalError', 'ServiceUnavailable',
    ])

    # Seconds to wait between retries of a transient failure.
    _RETRY_DELAY = 0.4

    def __init__(self, the_bucket, the_key):

        threading.Thread.__init__(self)
        self.the_bucket = the_bucket
        self.the_key = the_key

    @staticmethod
    def _limit_reached():
        """Return True when a non-zero --limit has already been satisfied."""
        return (args.limit != 0
                and counters['matches_count'].value() >= args.limit)

    @staticmethod
    def _compression_for(key):
        """Return the S3 Select CompressionType for ``key``, or None.

        Chosen from the file extension, case-insensitively.
        """
        lowered = key.lower()
        if lowered.endswith('.gz'):
            return 'GZIP'
        if lowered.endswith('.bz2'):
            return 'BZIP2'
        return None

    @classmethod
    def _serializations(cls, opts, key):
        """Return (InputSerialization, OutputSerialization) for ``key``.

        ``opts`` is the parsed command line; ``delim`` and ``count`` are read.
        JSON documents are assumed unless a CSV delimiter was given.
        """
        if opts.delim is not None:
            input_ser = {'CSV': {"FieldDelimiter": opts.delim,
                                 "FileHeaderInfo": "NONE"}}
            output_ser = {'CSV': {"FieldDelimiter": opts.delim}}
        else:
            input_ser = {'JSON': {"Type": "Document"}}
            output_ser = {'JSON': {}}

        if opts.count:
            # no need to parse JSON if we are only expecting the count of rows
            output_ser = {'CSV': {"FieldDelimiter": " "}}

        compression = cls._compression_for(key)
        if compression is not None:
            input_ser['CompressionType'] = compression
        return input_ser, output_ser

    @staticmethod
    def _build_query(opts):
        """Build the SQL expression from the parsed command line ``opts``."""
        if opts.count:
            projection = "count(*)"
        elif opts.output_fields:
            projection = opts.output_fields
        else:
            projection = "*"

        clauses = ["SELECT", projection, "FROM s3object s"]
        if opts.where is not None:
            clauses.append("WHERE " + opts.where)
        if opts.limit != 0:
            clauses.append("LIMIT " + str(opts.limit))
        return " ".join(clauses)

    @classmethod
    def _is_retriable(cls, error_response):
        """Return True if a failed request is worth retrying.

        ``error_response`` is the ``response`` dict of a botocore
        ClientError. Server errors (HTTP 5xx), HTTP 429 and the codes in
        ``_RETRIABLE_ERROR_CODES`` are retried. Anything else is permanent.
        """
        code = error_response.get('Error', {}).get('Code', '')
        status = (error_response.get('ResponseMetadata', {})
                  .get('HTTPStatusCode') or 0)
        return (code in cls._RETRIABLE_ERROR_CODES
                or status == 429 or status >= 500)

    def _report_error(self, error):
        """Count a failed request and report it on stderr."""
        counters['failed_requests'].inc()
        safe_print('Exception caught when querying {}: {}'
                   .format(self.the_key, str(error)),
                   file=sys.stderr)

    def _request(self, query, input_ser, output_ser):
        """Call select_object_content, retrying transient failures.

        Returns the response dict. Returns None when the error is permanent
        or --limit was reached while retrying.
        """
        while True:
            try:
                return s3.select_object_content(
                    Bucket=self.the_bucket,
                    Key=self.the_key,
                    ExpressionType='SQL',
                    Expression=query,
                    InputSerialization=input_ser,
                    OutputSerialization=output_ser,
                )
            except (EndpointConnectionError, ClientError) as e:
                self._report_error(e)
                if (isinstance(e, ClientError)
                        and not self._is_retriable(e.response)):
                    return None
            if self._limit_reached():
                return None
            time.sleep(self._RETRY_DELAY)

    def _stream_results(self, payload):
        """Consume the SelectObjectContent event stream.

        Returns True if the stream was abandoned because --limit was reached.
        Returns False once it was read to the end.
        """
        import codecs

        # A Records event can end in the middle of a record, and possibly in
        # the middle of a multi-byte UTF-8 character. Decode incrementally,
        # and in --limit mode hand only complete lines to print_and_inc.
        decoder = codecs.getincrementaldecoder('utf-8')()
        pending = ""

        for event in payload:
            if self._limit_reached():
                # requested limit reached
                return True
            if 'Records' in event:
                records = decoder.decode(event['Records']['Payload'])
                if args.limit != 0:
                    complete, newline, pending = (
                        pending + records).rpartition("\n")
                    if newline:
                        counters['matches_count'].print_and_inc(
                            complete, args.limit)
                elif args.count:
                    counters['matches_count'].inc(int(records))
                else:
                    safe_print(records, end='')
                    counters['matches_count'].inc(records.count("\n"))
            elif 'Stats' in event:
                details = event['Stats']['Details']
                counters['bytes_scanned'].inc(details['BytesScanned'])
                counters['bytes_returned'].inc(details['BytesReturned'])

        pending += decoder.decode(b'', final=True)
        if args.limit != 0 and pending:
            # last record was not newline-terminated
            counters['matches_count'].print_and_inc(pending, args.limit)
        return False

    def run(self):
        if self._limit_reached():
            return

        input_ser, output_ser = self._serializations(args, self.the_key)
        response = self._request(self._build_query(args),
                                 input_ser, output_ser)

        if response is not None:
            try:
                if self._stream_results(response['Payload']):
                    return
            except (EndpointConnectionError, ClientError,
                    UnicodeDecodeError) as e:
                # Covers e.g. an error event in the middle of the stream. As
                # far as we know, botocore raises it as EventStreamError, a
                # ClientError subclass. Report it and still mark the file
                # processed.
                self._report_error(e)
        elif self._limit_reached():
            return

        counters['files_processed'].inc()
        print_status()


class AtomicInteger(object):
    """Integer counter shared between threads.

    ``inc`` and ``print_and_inc`` update the value under an internal lock.
    ``value`` reads it without taking the lock. In this script it is used
    for progress display and limit checks.
    """

    def __init__(self, value=0):
        self._value = value
        self._lock = threading.Lock()

    def inc(self, addition=1):
        """Add ``addition`` (default 1) to the counter."""
        with self._lock:
            self._value += addition

    @staticmethod
    def _split_records(records):
        """Split newline-delimited text into individual records.

        A trailing newline terminates the last record rather than starting
        an empty one. So "a\\nb\\n" and "a\\nb" both yield ["a", "b"], and ""
        yields [].
        """
        lines = records.split("\n")
        if lines[-1] == "":
            lines.pop()
        return lines

    def print_and_inc(self, records, limit):
        """Print records one per line, counting each, until ``limit`` is hit.

        The lock is held for the whole batch, so concurrent callers cannot
        together print more than ``limit`` records.
        """
        with self._lock:
            for record in self._split_records(records):
                if self._value >= limit:
                    return
                self._value += 1
                safe_print(record)

    def value(self):
        """Return the current value (read without taking the lock)."""
        return self._value


def select():
    """Run the query against every object under ``args.prefix``.

    ``args.prefix`` is parsed as a URL. Its host part is the bucket, and its
    path (minus the leading '/') is the key prefix. A background
    S3ListThread feeds keys into a bounded queue while this function starts
    one ScanOneKey thread per key. New threads are held back while
    ``threading.active_count()`` exceeds ``args.thread_count``. It then
    waits until every listed file is processed or ``args.limit`` matches
    were found. Finally it prints the status line and, with --verbose, an
    estimated cost.
    """
    url_parse = parse.urlparse(args.prefix)
    bucket = url_parse.netloc
    prefix = url_parse.path[1:]

    # Bounded: the lister blocks in put() once 20000 keys are waiting.
    files_queue = queue.Queue(20000)

    listing_thread = S3ListThread(bucket, prefix, files_queue)
    # Once --limit is reached nobody drains files_queue any more, so the
    # lister may block in put() forever. As a daemon it cannot keep the
    # interpreter from exiting.
    listing_thread.daemon = True
    listing_thread.start()

    # Issue all ScanOneKey requests. is_alive() must be checked before
    # empty(). If the lister is already dead, every put() happened before
    # the check and empty() sees it. The reverse order can miss a final key
    # enqueued between the two checks, leaving total_files > files_processed
    # forever.
    while ((args.limit == 0
            or counters['matches_count'].value() < args.limit)
           and (listing_thread.is_alive() or not files_queue.empty())):

        if threading.active_count() > args.thread_count:
            time.sleep(0.2)
            continue
        try:
            ScanOneKey(bucket, files_queue.get(timeout=3)).start()
        except queue.Empty:
            # No key arrived in time (lister still working, or it found
            # nothing); re-check the loop condition.
            pass

    # wait until we process all files or reach limit
    while True:
        all_files_processed = (counters['files_processed'].value()
                               == counters['total_files'].value())
        limit_reached = 0 < args.limit <= counters['matches_count'].value()
        if all_files_processed or limit_reached:
            break

        time.sleep(0.1)

    print_status(end="\n")

    if args.verbose:
        # Hard-coded USD rates: 0.002 per GiB scanned, 0.0007 per GiB
        # returned, 0.0004 per 1000 requests. These are presumably S3 Select
        # list prices at the time of writing; verify against current AWS
        # pricing. Requests are counted as one per listed file, so retries
        # are not included.
        gib = 1024 ** 3
        price_for_bytes_scanned = (
            0.002 * counters['bytes_scanned'].value() / gib)
        price_for_bytes_returned = (
            0.0007 * counters['bytes_returned'].value() / gib)
        price_for_requests = 0.0004 * counters['total_files'].value() / 1000

        safe_print("Cost for data scanned: ${0:.2f}"
                   .format(price_for_bytes_scanned), file=sys.stderr,
                   status_refresh=False)
        safe_print("Cost for data returned: ${0:.2f}"
                   .format(price_for_bytes_returned), file=sys.stderr,
                   status_refresh=False)
        safe_print("Cost for SELECT requests: ${0:.2f}"
                   .format(price_for_requests), file=sys.stderr,
                   status_refresh=False)
        safe_print("Total cost: ${0:.2f}"
                   .format(price_for_bytes_scanned + price_for_bytes_returned +
                           price_for_requests), file=sys.stderr,
                   status_refresh=False)


if __name__ == "__main__":

    # Serialises all terminal output; acquired by safe_print.
    print_lock = threading.Lock()

    parser = argparse.ArgumentParser(
        description='s3select makes s3 select querying API much '
                    'easier and faster')
    parser.add_argument("-p", "--prefix",
                        help="S3 prefix beneath which all files are queried")
    parser.add_argument("-w", "--where",
                        help="WHERE part of the SQL query")
    parser.add_argument("-d", "--delim",
                        help="Delimiter to be used for CSV files. If specified CSV "
                             "parsing will be used. By default we expect JSON input")
    parser.add_argument("-l", "--limit", type=int, default=0,
                        help="Maximum number of results to return")
    parser.add_argument("-v", "--verbose", action='store_true',
                        help="Be more verbose")
    parser.add_argument("-D", "--disable_progress", action='store_true',
                        help="Turn off progress line")
    parser.add_argument("-c", "--count", action='store_true',
                        help="Only count records without printing them to stdout")
    parser.add_argument("-o", "--output_fields",
                        help="What fields or columns to output")
    parser.add_argument("-t", "--thread_count", type=int, default=200,
                        help="How many threads to use when executing s3_select api "
                             "requests. Default of 200 seems to be max that doesn't "
                             "cause throttling on AWS side")
    parser.add_argument("--profile",
                        help="Use a specific profile from your credential file.")

    args = parser.parse_args()

    # Shared progress/result counters, updated from the lister and scanner
    # threads.
    counters = dict((counter_name, AtomicInteger()) for counter_name in (
        "total_files",
        "matches_count",
        "files_processed",
        "bytes_returned",
        "bytes_scanned",
        "failed_requests",
    ))

    if args.prefix is None:
        parser.print_help()
        sys.exit(1)

    # A "\t" typed on the command line arrives as backslash + "t"; replace
    # the whole delimiter with a real tab character.
    if args.delim is not None and "\\t" in args.delim:
        args.delim = '\t'

    if args.profile is not None:
        boto3.setup_default_session(profile_name=args.profile)

    s3 = boto3.client('s3')

    select()
