#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import atexit
import glob
import itertools
import os
import re
import shlex
import shutil
import sys

from tarproclib import paths, reader, writer

# Free-form text shown after the option list in --help.
epilog = """
Run a command line tool over all samples.

Each sample is extracted into its own directory with
a common basename (default=sample) and the extensions from the sample.

Example:

    tarproc -I png -c 'convert sample.jpg sample.png' inputs.tar -o outputs.tar
"""

# RawDescriptionHelpFormatter keeps the line breaks of the epilog example.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Run commands over all samples.",
    epilog=epilog)

# Reporting.
parser.add_argument("-v", "--verbose", action="store_true",
                    help="output extra information")
parser.add_argument("-q", "--silent", action="store_true",
                    help="extra quiet")

# What to run for each sample (exactly one of the two must be given).
parser.add_argument("-c", "--command", default=None,
                    help="command to run for each sample (working dir = sample)")
parser.add_argument("-S", "--script", default=None,
                    help="script to run for each sample (working dir = sample)")

# Where and how samples are unpacked.
parser.add_argument("-w", "--working_dir", default="__{pid}__",
                    help="temporary working dir")
parser.add_argument("-b", "--base", default="sample",
                    help="base to substitute for __key__ (default=\"sample\")")
# The value is split on commas when it is parsed, so the help says so.
parser.add_argument("-f", "--fields", default=None,
                    help="fields to run on (default=all, comma separated)")
parser.add_argument("-F", "--fieldmode", default="ignore",
                    help="how to handle missing fields (error or ignore)")

# Execution control.
parser.add_argument("-p", "--parallel", default=0, type=int,
                    help="execute scripts in parallel")
parser.add_argument("-e", "--error-handling", default="skip",
                    help="how to handle errors in scripts (ignore, skip, abort)")

# Selection and naming of the files collected after the command ran.
parser.add_argument("-E", "--exclude", default=None,
                    help="exclude anything matching this from output")
parser.add_argument("-I", "--include", default=None,
                    help="include only files matching this in output")
parser.add_argument("-s", "--separator", default="",
                    help="separator between key and new file bases")

parser.add_argument("--interpreter", default="bash",
                    help="interpreter used for script argument")
parser.add_argument("--count", type=int, default=1000000000,
                    help="stop after processing this many samples")
parser.add_argument("-o", "--output", default=None,
                    help="output tar file (default: write no output)")
parser.add_argument("input", default="-", nargs="?",
                    help="input tar file (default=\"-\")")
# Parse the process command line (sys.argv) once, at import time; the rest of
# the script reads its configuration from this module-level namespace.
args = parser.parse_args()


def dprint(*args, **kw):
    """Print to standard error, forwarding every argument to ``print``."""
    stream = sys.stderr
    print(*args, file=stream, **kw)


# Resolve the per-sample command: exactly one of --script / --command is
# required. Explicit exits are used instead of assert so the checks also
# run under "python -O".
if args.script and args.command:
    sys.exit("cannot provide both --command and --script")
if args.script:
    # A script that exists on disk is run through the interpreter by its
    # absolute path, because the command is executed from inside each
    # sample's directory, where a relative path would no longer resolve.
    # Anything else is passed through unchanged as the command.
    path = os.path.abspath(args.script)
    if os.path.exists(path):
        # shlex.quote keeps paths with spaces or quote characters intact.
        args.command = f"{args.interpreter} {shlex.quote(path)}"
    else:
        args.command = args.script
elif not args.command:
    sys.exit("must provide either --command or --script")

# Optional set of field names that processing is restricted to. Names may
# be separated by commas and/or whitespace; empty entries are dropped.
if args.fields is not None:
    fields = set(args.fields.replace(",", " ").split())
else:
    fields = None


def proc_sample(sample, index=0, fields=None, separator="", fieldmode="ignore"):
    """Run the configured command over one sample and collect its outputs.

    The sample's entries are written as files into a new subdirectory of
    ``args.working_dir``, ``args.command`` is executed with that subdirectory
    as the current directory, and the files found afterwards are grouped by
    base name into new samples. The subdirectory is removed before returning,
    whether the command succeeded, was skipped, or aborted.

    Args:
        sample: dict mapping field names to file contents (written with
            ``paths.write_binary``, so presumably bytes -- confirm). Keys
            starting with "_" are treated as metadata and written under
            their own name rather than ``<base>.<key>``.
        index: number used to name the scratch subdirectory.
        fields: optional collection of field names to restrict the sample
            to; ``None`` keeps every field.
        separator: string placed between the original ``__key__`` and the
            suffix of an output base name that extends ``args.base``.
        fieldmode: with ``fields`` given, "ignore" drops fields that are
            not listed and tolerates listed fields that are absent, while
            "error" raises ``KeyError`` when a listed field is absent.

    Returns:
        A list of dicts, one per distinct output base name, each holding
        ``__key__`` plus one entry per collected file. The list is empty
        when the command failed and error handling is "skip".

    Raises:
        KeyError: a required field is missing, or the sample has no
            ``__key__`` although output files were produced.
        AssertionError: the command failed and error handling is neither
            "ignore" nor "skip".
    """
    assert isinstance(sample, dict)

    # if there are fields, we limit processing to those fields
    if fields is not None:
        if fieldmode == "ignore":
            sample = {k: v for k, v in sample.items()
                      if k in fields or k.startswith("_")}
        elif fieldmode == "error":
            # Indexing raises KeyError for a missing field, as requested.
            selected = {k: sample[k] for k in fields}
            # Metadata such as __key__ must survive the selection; the
            # output samples below are keyed from it.
            selected.update((k, v) for k, v in sample.items() if k.startswith("_"))
            sample = selected

    old_sample = sample

    # process in a subdirectory
    dirname = os.path.join(args.working_dir, "_%08d" % index)
    os.mkdir(dirname)
    try:
        with paths.ChDir(dirname):

            # write the sample out as files
            for k, v in sample.items():
                fname = k if k.startswith("_") else args.base + "." + k
                paths.write_binary(fname, v)

            # execute the command and handle errors
            status = os.system(args.command)
            if status != 0:
                key = sample.get("__key__", "?")
                if args.error_handling == "ignore":
                    if not args.silent:
                        dprint("exit status (ignore):", status, key)
                elif args.error_handling == "skip":
                    if not args.silent:
                        dprint("exit status (skip)", status, key)
                    return []
                else:
                    if not args.silent:
                        dprint("exit status (abort)", status, key)
                    # Raised explicitly (not via assert) so that aborting
                    # still works when Python runs with -O.
                    raise AssertionError(status)

            # gather up the result files; "*.*" only matches names that
            # contain a dot, so metadata files written under a bare key
            # such as __key__ are not collected
            files = sorted(fname for fname in glob.glob("*.*") if os.path.isfile(fname))
            if args.exclude is not None:
                if args.verbose:
                    print(f"< exclude {files}")
                files = [fname for fname in files if not re.search(args.exclude, fname)]
                if args.verbose:
                    print(f"> exclude {files}")
            if args.include is not None:
                if args.verbose:
                    print(f"< include {files}")
                files = [fname for fname in files if re.search(args.include, fname)]
                if args.verbose:
                    print(f"> include {files}")

            # processing may have produced multiple outputs; gather them up separately
            bases = sorted(set(map(paths.filebase, files)))
            samples = []
            for base in bases:
                matching = [fname for fname in files if fname.startswith(base + ".")]
                # The part of the base name beyond args.base distinguishes
                # several outputs produced from the same input sample.
                extra_key = base
                if extra_key.startswith(args.base):
                    extra_key = extra_key[len(args.base):]
                new_sample = {}
                if extra_key != "":
                    new_sample["__key__"] = old_sample["__key__"] + separator + extra_key
                else:
                    new_sample["__key__"] = old_sample["__key__"]
                for fname in matching:
                    new_sample[paths.fullext(fname)] = paths.read_binary(fname)
                samples.append(new_sample)
    finally:
        # Runs after the ChDir block has been left, so the scratch directory
        # is removed on the skip and abort paths as well.
        shutil.rmtree(dirname)
    return samples


def proc_sample1(arg):
    """Process one ``(index, sample)`` pair using the script-wide settings.

    Single-argument adapter around ``proc_sample`` that fills in the field
    selection, field mode and separator from the parsed command line.
    Returns the resulting list of sample dicts.
    """
    index, record = arg
    outputs = proc_sample(
        record,
        index=index,
        fields=fields,
        fieldmode=args.fieldmode,
        separator=args.separator,
    )
    # Sanity-check the shape of the result before handing it on.
    assert isinstance(outputs, list)
    for item in outputs:
        assert isinstance(item, dict), item
    return outputs


# Substitute the current process id for the {pid} placeholder in the
# working directory name.
args.working_dir = args.working_dir.format(pid=str(os.getpid()))

# The working directory must not exist yet; create it and schedule its
# removal for interpreter exit.
assert not os.path.exists(args.working_dir)
os.mkdir(args.working_dir)
atexit.register(lambda: shutil.rmtree(args.working_dir))

# Open the output archive only when one was requested. atexit runs handlers
# in reverse registration order, so the sink is closed before the working
# directory is removed.
sink = writer.TarWriter(args.output) if args.output is not None else None
if sink is not None:
    atexit.register(lambda: sink.close())


def handle_result(new_samples):
    """Report and store the samples produced from one input sample.

    With --verbose, each sample's key and its non-underscore field names are
    printed to stderr. When an output sink is open, every sample is written
    to it. ``new_samples`` must be a list; whenever the samples are reported
    or written, each must be a dict containing ``"__key__"``.
    """
    assert isinstance(new_samples, list), new_samples
    if args.verbose:
        for entry in new_samples:
            assert isinstance(entry, dict), entry
            assert "__key__" in entry, entry.keys()
            visible = " ".join(name for name in entry.keys() if name[0] != "_")
            dprint(entry.get("__key__"), visible)
    if sink is None:
        return
    for entry in new_samples:
        assert isinstance(entry, dict), entry
        assert "__key__" in entry, entry.keys()
        sink.write(entry)


# Main loop: feed the samples of the input archive through proc_sample1,
# either one after another or in a pool of worker processes.
if args.parallel < 0:
    # Previously a negative value silently processed nothing.
    sys.exit("--parallel must be >= 0")

# Apply --count to the input stream rather than to the results, so that no
# sample beyond the limit is read or run (a negative count processes none).
indexed_samples = itertools.islice(
    enumerate(reader.TarIterator(args.input)), max(args.count, 0))

if args.parallel == 0:
    for i, sample in indexed_samples:
        assert isinstance(sample, dict)
        handle_result(proc_sample1((i, sample)))
else:
    from multiprocessing import get_context
    # The workers depend on module-level state (args, fields) and this
    # script has no __main__ guard, so they must be forked; the spawn and
    # forkserver start methods would re-run the script inside each worker.
    # get_context("fork") raises ValueError on platforms without fork.
    with get_context("fork").Pool(processes=args.parallel) as pool:
        # Results arrive in completion order, not input order.
        for new_samples in pool.imap_unordered(proc_sample1, indexed_samples):
            handle_result(new_samples)
