#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import random
import sys

import braceexpand

from tarproclib import gopen, proc, reader, writer

# Command-line interface.
#
# The summary is passed as ``description``: the first positional parameter of
# ``ArgumentParser`` is ``prog``, so passing the sentence positionally would
# make it the program name in the usage line and leave --help without a
# description.
parser = argparse.ArgumentParser(
    description="Concatenate tar files sequentially to standard out."
)
# NOTE(review): --verbose is parsed but never read anywhere in this script.
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
    "-T",
    "--filelist",
    default=None,
    help="file listing the inputs to read, one name per line",
)
parser.add_argument(
    "-b",
    "--braceexpand",
    action="store_true",
    help="brace-expand the single input argument into the list of inputs",
)
parser.add_argument(
    "-s",
    "--skip",
    type=int,
    default=0,
    help="number of leading samples to read but not write",
)
parser.add_argument(
    "-c",
    "--count",
    type=int,
    default=1000000000,
    help="upper bound on the number of samples processed, skipped ones included",
)
parser.add_argument(
    "-o",
    "--output",
    default="-",
    help="output destination handed to the tar writer",
)
parser.add_argument(
    "--output-mode",
    default="random",
    help="output mode handed to the tar writer",
)
parser.add_argument(
    "--shuffle",
    type=int,
    default=0,
    help="if > 0, shuffle the input order and the samples of each input; "
    "the value is passed on to the sample shuffler",
)
parser.add_argument(
    "--eof",
    action="store_true",
    help="send an explicit EOF through the writer when done",
)
parser.add_argument(
    "--nodata",
    action="store_true",
    help="read no inputs at all",
)
parser.add_argument("input", nargs="*", help='inputs to read (default: "-")')
args = parser.parse_args()


def dprint(*args, **kw):
    """Print a diagnostic message to standard error.

    Accepts the same positional and keyword arguments as the built-in
    ``print``; the destination stream is always ``sys.stderr`` as it is
    bound at the time of the call.
    """
    stderr = sys.stderr
    print(*args, **kw, file=stderr)


def read_filelist():
    """Yield the input names listed in the ``--filelist`` file.

    The file named by ``args.filelist`` is opened with ``gopen.gopen`` and
    read line by line. Surrounding whitespace (including the newline) is
    stripped from every entry. Lines that are empty after stripping are
    skipped, so a blank line in the list does not become an empty input name.

    Yields:
        Each non-blank, stripped line of the file list, in file order.
    """
    with gopen.gopen(args.filelist, "r") as stream:
        for line in stream:
            name = line.strip()
            if name:
                yield name


# Decide which inputs to read. Precedence: --nodata, then --filelist, then
# --braceexpand, then the positional inputs, and finally the single input "-"
# (presumably standard input; that is decided by the reader, not here).
if args.nodata:
    filelist = []
elif args.filelist is not None:
    filelist = list(read_filelist())
elif args.braceexpand:
    # Validate explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would turn a missing input into an
    # IndexError and silently ignore any extra inputs.
    if len(args.input) != 1:
        parser.error(
            f"--braceexpand requires exactly one input pattern, got {args.input}"
        )
    filelist = list(braceexpand.braceexpand(args.input[0]))
elif args.input:
    filelist = args.input
else:
    filelist = ["-"]


# Report how many inputs were selected, except for the lone "-" default.
if filelist != ["-"]:
    dprint(f"# got {len(filelist)} files")

# Copy the samples of every input, in order, into one output tar stream.
#
# ``n`` counts every sample taken from the inputs, including the ones dropped
# by --skip, so --count bounds the sample index rather than the number of
# samples actually written.
n = 0
sink = writer.TarWriter(args.output, keep_meta=True, output_mode=args.output_mode)
try:
    if args.shuffle > 0:
        # Randomize the order in which the inputs are visited.
        random.shuffle(filelist)
    for fname in filelist:
        # Stop before opening another input once the limit is reached; this
        # also keeps ``--count 0`` from touching any input.
        if n >= args.count:
            break
        if fname != "-":
            dprint(f"# {n} {fname}")
        source = reader.TarIterator(fname, braceexpand=False)
        if args.shuffle > 0:
            # Also shuffle the samples within this input.
            source = proc.ishuffle(iter(source), args.shuffle)
        for sample in source:
            # Record where the sample came from unless it already says so.
            if "__source__" not in sample:
                sample["__source__"] = fname
            if n >= args.skip:
                sink.write(sample)
            n += 1
            # Check right after counting so that no extra sample is pulled
            # from the input only to be discarded.
            if n >= args.count:
                break
    if args.eof:
        sink.send_eof()
finally:
    # Close the writer even if reading or writing failed part-way through.
    sink.close()
