#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import multiprocessing as mp
import queue as mpq
import random
import sys

import braceexpand

from tarproclib import gopen, proc, reader, writer

# Command-line interface.  Arguments are parsed at import time and the result
# is kept in the module-level `args`, which the rest of the script reads.
parser = argparse.ArgumentParser(
    description="Read, shuffle, and combine multiple shards in parallel.")
# -v: accepted, but not referenced anywhere else in the visible code.
parser.add_argument("-v", "--verbose", action="store_true")
# -T: path of a file listing the input shards, one per line; when given it
# takes precedence over the positional inputs.
parser.add_argument("-T", "--filelist", default=None)
# -b: brace-expand the single positional input into the list of shards.
parser.add_argument("-b", "--braceexpand", action="store_true")
# -c: sample count after which the output loop stops.
parser.add_argument("-c", "--count", type=int, default=1000000000)
# -s: when > 0, the shard list is shuffled and the samples are passed through
# proc.ishuffle with this value as its second argument.
parser.add_argument("-s", "--shuffle", type=int, default=0)
# -p: number of reader processes to start.
parser.add_argument("-p", "--workers", type=int, default=8)
# -o: destination handed to writer.TarWriter.
parser.add_argument("-o", "--output", default="-")
# --dummy: print each sample's key and source to stderr instead of writing it.
parser.add_argument("--dummy", action="store_true")
# Positional inputs: the shards to read (or a single brace pattern with -b).
parser.add_argument("input", nargs="*")
args = parser.parse_args()


def dprint(*args, **kw):
    """Print `args` to standard error, forwarding keyword options to print()."""
    stream = sys.stderr
    print(*args, file=stream, **kw)


def read_filelist(filelist):
    """Yield the shard names listed in `filelist`, one per line.

    The list is opened with gopen.gopen in mode "r".  Surrounding whitespace
    is stripped from every line, and lines that are empty after stripping
    are skipped so that blank or trailing lines do not turn into empty
    shard names.
    """
    with gopen.gopen(filelist, "r") as stream:
        for line in stream:
            name = line.strip()
            if name:
                yield name


def reader_proc(file_queue, sample_queue):
    """Worker loop: read the shards named on `file_queue` into `sample_queue`.

    Shard names are taken from `file_queue` one at a time until no name
    arrives within 5 seconds, at which point the worker returns.  Every
    sample produced by reader.TarIterator for a shard is put on
    `sample_queue`; samples that do not already carry a "__source__" entry
    get the shard name stored under that key.

    Args:
        file_queue: queue of shard names to read.
        sample_queue: queue receiving the decoded samples.
    """
    # Keep pulling shards until the queue runs dry; taking only a single
    # shard per worker would leave all but the first `workers` shards unread.
    while True:
        try:
            fname = file_queue.get(timeout=5.0)
        except mpq.Empty:
            # Nothing arrived within the timeout: treat the queue as drained.
            return
        print(f"# opening {fname}", file=sys.stderr)
        for sample in reader.TarIterator(fname, braceexpand=False):
            if "__source__" not in sample:
                sample["__source__"] = fname
            sample_queue.put(sample)
        print(f"# done {fname}", file=sys.stderr)


# Build the list of input shards: an explicit list file wins, then a single
# brace pattern, otherwise the positional arguments as given.
if args.filelist is not None:
    filelist = list(read_filelist(args.filelist))
elif args.braceexpand:
    # Report bad usage through argparse instead of `assert`, which is
    # stripped when Python runs with -O.
    if len(args.input) != 1:
        parser.error(
            f"--braceexpand requires exactly one input, got {args.input}")
    filelist = list(braceexpand.braceexpand(args.input[0]))
else:
    filelist = args.input

dprint(f"# got {len(filelist)} files")

# NOTE(review): `n` is not referenced anywhere in the visible code.
n = 0

if args.shuffle > 0:
    random.shuffle(filelist)

# The file queue is sized to hold every shard name, so filling it below
# never blocks on capacity.
file_queue = mp.Queue(len(filelist) + 10)
sample_queue = mp.Queue(10000)

for fname in filelist:
    file_queue.put(fname)


def parallel_source():
    """Yield samples produced by `args.workers` reader processes.

    Starts the worker processes (each running reader_proc on the shared
    `file_queue` / `sample_queue`), then yields samples from `sample_queue`
    for as long as any worker is alive.  Once every worker has exited, the
    samples still sitting in the queue are yielded as well.  When the
    generator is closed early or fails, the remaining workers are killed
    and joined.

    Yields:
        The samples put on `sample_queue` by the workers.
    """
    jobs = [
        mp.Process(target=reader_proc, args=(file_queue, sample_queue))
        for _ in range(args.workers)
    ]

    for job in jobs:
        job.start()

    dprint(f"# started {len(jobs)} jobs")

    try:
        while jobs:
            try:
                sample = sample_queue.get(timeout=5.0)
            except mpq.Empty:
                dprint("timeout")
            else:
                yield sample
            # Rebuild the list of live workers from scratch on every pass.
            # The list must be created once, outside the loop over jobs,
            # or only the last job's state is kept.
            alive = []
            for job in jobs:
                if job.is_alive():
                    alive.append(job)
                else:
                    job.join()
            jobs = alive

        # Every worker has exited, so nothing new can arrive; hand out the
        # samples that were queued before the last worker finished.
        while True:
            try:
                sample = sample_queue.get(timeout=0.1)
            except mpq.Empty:
                break
            yield sample
    finally:
        # Reached with live workers only on early close or error.
        for job in jobs:
            job.kill()
            job.join()


# NOTE(review): there is no `if __name__ == "__main__":` guard, so worker
# processes started with the "spawn" method would re-execute this module;
# the script appears to rely on the "fork" start method -- confirm.

# Keep a handle on the underlying generator so that its cleanup can be run
# explicitly even after `source` is rebound to the shuffling wrapper.
samples = parallel_source()
source = samples

if args.shuffle > 0:
    source = proc.ishuffle(source, args.shuffle)

# NOTE(review): `sink` is never explicitly closed or finalized here; check
# whether writer.TarWriter needs that to complete the archive.
sink = writer.TarWriter(args.output, keep_meta=True)
total = 0
try:
    # A non-positive count emits nothing.
    if args.count > 0:
        for sample in source:
            total += 1
            if args.dummy:
                dprint(
                    sample.get("__key__", total),
                    sample.get("__source__", None),
                )
            else:
                sink.write(sample)
            # `>=` so that exactly `count` samples are emitted, not count + 1.
            if total >= args.count:
                break
finally:
    # Closing the generator runs its cleanup, which kills and joins any
    # workers still running.  Without it, workers blocked on a full sample
    # queue would keep the script from exiting after an early stop.
    samples.close()
