#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import os
import subprocess
import sys

from tarproclib import reader, writer

# Command-line interface.  The text is passed as `description=`; the first
# positional parameter of ArgumentParser is `prog`, so passing the sentence
# positionally would make it show up as the program name in the usage line.
parser = argparse.ArgumentParser(
    description="Split a tar file into shards based on size or number of samples."
)
# Parsed as float, so values written like 1e5 are accepted; the value is only
# ever compared against the per-shard sample counter.
parser.add_argument("-n", "--num-samples", default=100000, type=float,
                    help="start a new shard once the current one holds this many samples")
parser.add_argument("-s", "--max-size", default=1e9, type=float,
                    help="start a new shard once the summed lengths of sample keys and "
                         "values in the current one reach this value")
parser.add_argument("-v", "--verbose", action="store_true",
                    help="print the __key__ of every sample to stderr")
parser.add_argument("-C", "--command", default=None,
                    help="shell command to run after each shard is finished; formatted "
                         "with {shard}, {abspath}, {basename}, {dirname}, {base} and {ext}")
parser.add_argument("-o", "--output", default="temp",
                    help="output prefix, or a pattern containing {shard}; a resulting "
                         "name that starts with | is run as a shell command fed via stdin")
parser.add_argument("-O", "--open", default=None,
                    help="shell command that receives each shard on stdin; the shard "
                         "name is appended to it")
parser.add_argument("-z", "--compress", action="store_true",
                    help="ask the tar writer to compress and use a .tgz suffix for "
                         "generated shard names")
# NOTE(review): --start is accepted but not referenced anywhere in the visible
# script; confirm the intended meaning before wiring it up.
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--maxshards", default=1000000000, type=int,
                    help="stop after this many shards have been started")
parser.add_argument("--nodelete", action="store_true", help="don't delete after executing command")
parser.add_argument("input", default="-", nargs="?",
                    help="input passed to the tar reader (default: -)")
args = parser.parse_args()


def dprint(*values, **options):
    """Print *values* to standard error, forwarding keyword options to print()."""
    print(*values, file=sys.stderr, **options)


def sample_size(sample):
    """Return the summed lengths of every key and every value in *sample*.

    *sample* is a mapping whose keys and values all support len().
    """
    return sum(len(key) + len(value) for key, value in sample.items())


# Running totals over all shards that have already been finished.
total_count = total_size = 0

# Index and name of the shard currently being written.
shard = 0
shard_name = None

# Sample count and accumulated size of the current shard.
count = size = 0

# Tar writer for the current shard; None while no shard is open.
sink = None


def finish_shard():
    """Close the current shard and run the optional post-processing command.

    Does nothing when no shard is open (``sink`` is None).  Otherwise the
    writer is closed, the stream it was writing to is closed, and any
    subprocess attached to the writer or the stream is waited for (up to 60
    seconds).  If ``args.command`` is set, it is formatted with the fields
    ``shard``, ``abspath``, ``basename``, ``dirname``, ``base`` and ``ext``
    derived from ``shard_name`` and run through the shell; afterwards the
    shard file is removed unless ``args.nodelete`` is set.

    Raises:
        RuntimeError: the post-processing command returned a non-zero status.
        subprocess.TimeoutExpired: the attached subprocess did not exit in time.
    """
    global sink
    if sink is None:
        return
    # The main loop stores the Popen object as `.process` on the stream it
    # hands to the writer (module-level `sink_stream`); a `.process` attribute
    # on the writer itself is honored as well.
    stream = globals().get("sink_stream")
    process = getattr(sink, "process", None)
    if process is None:
        process = getattr(stream, "process", None)
    sink.close()
    sink = None
    # NOTE(review): whether the writer's close() also closes the stream it was
    # given is not visible here.  Closing an already-closed file object is
    # harmless, and doing it explicitly guarantees buffered data is flushed
    # (and a pipe sees EOF) before the command below runs.
    if stream is not None:
        stream.close()
    # Wait only after the child's stdin has been closed; a child that reads
    # until EOF would otherwise never exit and the wait would time out.
    if process is not None:
        process.wait(timeout=60.0)
    if args.command is None:
        return
    basename = os.path.basename(shard_name)
    base, ext = os.path.splitext(basename)
    kw = dict(
        shard=shard_name,
        abspath=os.path.abspath(shard_name),
        basename=basename,
        dirname=os.path.dirname(shard_name),
        base=base,
        ext=ext)
    cmd = args.command.format(**kw)
    print(f"# {cmd}", file=sys.stderr)
    status = os.system(cmd)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if status != 0:
        raise RuntimeError(f"command failed with status {status}: {cmd}")
    if not args.nodelete:
        print(f"# removing {shard_name}", file=sys.stderr)
        os.unlink(shard_name)


# An output value that already contains a format field is used verbatim;
# otherwise a zero-padded shard index and a suffix matching the compression
# setting are appended to it.
if "{" in args.output:
    output_pattern = args.output
else:
    output_pattern = args.output + ("-{shard:06d}.tgz" if args.compress else "-{shard:06d}.tar")

# Read samples from the input one at a time and distribute them over
# consecutive shards, starting a new shard whenever the current one is full.
for sample in reader.TarIterator(args.input):
    if args.verbose:
        dprint(sample.get("__key__"))
    shard_is_full = count >= args.num_samples or size >= args.max_size
    if sink is None or shard_is_full:
        # Fold the finished shard's counters into the running totals, then
        # reset them for the shard about to be opened.
        total_count, total_size = total_count + count, total_size + size
        count = size = 0
        finish_shard()
        if shard >= args.maxshards:
            break
        shard_name = output_pattern.format(shard=shard)
        dprint(f"# writing {shard_name} ({total_count}, {total_size})")
        # Decide where the shard goes: a shell pipeline named by the shard
        # itself ("|cmd"), the --open command with the shard name appended,
        # or a plain file.
        if shard_name[0] == "|":
            pipe_command = shard_name[1:]
        elif args.open is not None:
            pipe_command = args.open + " " + shard_name
        else:
            pipe_command = None
        if pipe_command is None:
            sink_stream = open(shard_name, "wb")
        else:
            process = subprocess.Popen(pipe_command, stdin=subprocess.PIPE, shell=True)
            sink_stream = process.stdin
            # Keep the process reachable from the stream handed to the writer.
            sink_stream.process = process
        sink = writer.TarWriter(sink_stream, compress=True if args.compress else None)
        shard += 1
    sink.write(sample)
    count += 1
    size += sample_size(sample)

# Close whichever shard is still open once the input is exhausted.
finish_shard()
