#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import random

import braceexpand
import numpy as np
import yaml

from tarproclib import proc, reader, writer

# One-line summary shown at the top of the --help output.
description = "Randomly mix data sources to standard out."

# Long help text shown after the option list. It documents the YAML format;
# attributes the program rejects at startup are marked as not implemented so
# the example is not mistaken for a runnable configuration.
epilog = """
The mix is specified by a YAML file that looks something like this:

```YAML
    shuffle: 1000
    sources:
      - shards: "gs://bucket1/data-{000..999}.tar"
        weight: 1
        probability: 0.9
        shuffle: 10000
      - shards: "gs://bucket3/data-{000..100}.tar"
        weight: 1.7
        probability: 0.5
        nstreams: 10
        nrepeats: 5
        convert:          # not implemented yet; remove before running
          - from: page.jpg
            to: png
```

The source attributes have the following meaning:

    - shards: actual shard spec
    - weight: streams are chosen for sampling with a probability proportional to weight
    - probability: samples are output from a stream with this probability
    - nstreams: for multi-sharded streams, tries to keep this many streams open in parallel
    - nrepeats: repeats the source this often
    - convert: perform a limited number of sample format conversions (not implemented yet)
    - rename: rename sample components (rename to "" to delete) (not implemented yet)
    - shuffle: shuffle the files and samples for this stream
    - allow_missing: permit input files to be missing for the stream

A source that specifies an attribute marked "not implemented yet" is rejected at startup.

Global parameters:

    - shuffle: shuffle the output samples before writing (use external program for very large shuffles)
"""

# Command-line interface. The raw formatter keeps the hand-formatted YAML
# example in the epilog intact when --help is printed.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# Each entry is (flags, add_argument keyword settings); the order here is the
# order in which the options are registered.
for _option_flags, _option_settings in (
    (("-v", "--verbose"), {"action": "store_true"}),
    (("-c", "--count"), {"type": int, "default": 1000000000}),
    (("-o", "--output"), {"default": "-"}),
    (("--output-mode",), {"default": "random"}),
    (("--skip",), {"type": int, "default": 0}),
    # A negative value means "take the shuffle size from the YAML file".
    (("--shuffle",), {"type": int, "default": -1}),
    (("--eof",), {"action": "store_true"}),
    # Positional argument: path of the YAML mix specification.
    (("yamlspec",), {}),
):
    parser.add_argument(*_option_flags, **_option_settings)

args = parser.parse_args()

# Read the mix specification. safe_load is used because calling yaml.load()
# without an explicit Loader is deprecated in PyYAML 5.1 and fails in
# PyYAML 6; the spec only needs plain YAML scalars, lists and mappings.
with open(args.yamlspec, "r") as stream:
    yamlspec = yaml.safe_load(stream)

# A negative --shuffle (the default) defers to the global "shuffle" entry of
# the YAML file, falling back to no shuffling.
if args.shuffle < 0:
    args.shuffle = yamlspec.get("shuffle", 0)

# Parallel per-source lists, filled in while the "sources" entries are processed.
sources = []
weights = []
probabilities = []
accum = []


def repeat(source, n):
    """Yield the samples of `source` `n` times in a row.

    `source` may be:

    - a re-iterable object (e.g. a list, or any object whose ``__iter__``
      starts over), which is iterated anew on every pass; or
    - a zero-argument callable returning a fresh iterable, which is called
      once per pass.

    A one-shot iterator such as a generator is exhausted after the first
    pass, so later passes contribute nothing; wrap it in a factory (or a
    re-iterable object) when it really has to be repeated.
    """
    for _ in range(n):
        # Obtain a fresh iterable per pass when a factory was supplied.
        current_pass = source() if callable(source) else source
        for sample in current_pass:
            yield sample


def multistream(files, n, **kw):
    """Yield samples from `files`, reading up to `n` of them concurrently.

    Up to `n` files are kept open through `reader.TarIterator` (extra
    keyword arguments are forwarded to it). Each step picks one of the open
    streams uniformly at random and yields its next sample; a stream that is
    exhausted is dropped and replaced by the next unopened file. Iteration
    ends when no stream is open and no file is left to open.

    The caller's `files` sequence is left untouched.
    """
    # Private copy, reversed so that taking the next file is an O(1) pop()
    # from the end while still visiting the files in their original order.
    pending = list(reversed(files))
    active = []
    while True:
        # Top up the set of open streams.
        while len(active) < n and pending:
            active.append(iter(reader.TarIterator(pending.pop(), **kw)))
        if not active:
            return
        index = np.random.randint(0, len(active))
        try:
            sample = next(active[index])
        except StopIteration:
            # This stream is finished; the next round refills the slot.
            del active[index]
        else:
            yield sample


class _ReplayableSource:
    """Re-iterable view of a shard list that opens a fresh multistream() per pass.

    multistream() returns a one-shot generator, so handing that generator to
    repeat() would produce the data only once no matter how many repetitions
    were requested. Iterating this object instead starts a new multistream()
    every time, which is what repeat() needs.
    """

    def __init__(self, files, nstreams, **kw):
        self.files = list(files)
        self.nstreams = nstreams
        self.kw = kw

    def __iter__(self):
        # Hand out a copy: multistream() may consume the list it is given.
        return multistream(list(self.files), self.nstreams, **self.kw)


# Build one sample iterator per entry of the YAML "sources" list and record
# its sampling weight and output probability.
for spec in yamlspec["sources"]:
    # Explicit checks rather than assert, so they are not stripped by -O.
    if "convert" in spec:
        raise NotImplementedError("convert unimplemented")
    if "rename" in spec:
        raise NotImplementedError("rename unimplemented")
    shuffle = spec.get("shuffle", 0)
    nstreams = spec.get("nstreams", 1)
    nrepeats = spec.get("nrepeats", 1)
    allow_missing = spec.get("allow_missing", False)
    files = list(braceexpand.braceexpand(spec["shards"]))
    if shuffle > 0:
        random.shuffle(files)
    if nrepeats > 1:
        source = repeat(
            _ReplayableSource(files, nstreams, allow_missing=allow_missing), nrepeats
        )
    else:
        source = multistream(files, nstreams, allow_missing=allow_missing)
    if shuffle > 0:
        source = proc.ishuffle(source, shuffle)
    sources.append(source)
    weights.append(float(spec.get("weight", 1.0)))
    probabilities.append(float(spec.get("probability", 1.0)))

if not sources:
    raise ValueError("no sources specified")

# Cumulative, normalised weights: a uniform random number in [0, 1) is mapped
# to a source index by searching this array.
accum = np.array(weights, "d")
if (accum < 0.0).any() or not np.sum(accum) > 0.0:
    raise ValueError("source weights must be non-negative and sum to a positive value")
accum /= np.sum(accum)
accum = np.add.accumulate(accum)
# Rounding can leave the last entry slightly below 1.0, which would let a
# random draw fall past the end of the array; pin it to exactly 1.0.
accum[-1] = 1.0
probabilities = np.array(probabilities)

if not ((probabilities >= 0.0).all() and (probabilities <= 1.0).all()):
    raise ValueError("source probabilities must lie in [0, 1]")


def sampling():
    """Yield samples drawn at random from the module-level `sources`.

    For every step a source is chosen by looking up a uniform random number
    in the cumulative weights `accum`; the next sample of that source is
    fetched and then emitted with probability `probabilities[index]`
    (otherwise it is discarded and a new step begins).

    The generator stops as soon as the chosen source is exhausted, even if
    other sources still have data.
    """
    last_index = len(sources) - 1
    while True:
        r = np.random.uniform()
        # searchsorted returns len(accum) when r is larger than the final
        # cumulative weight (possible through rounding); clamp to the last
        # source instead of indexing out of range.
        index = min(int(np.searchsorted(accum, r)), last_index)
        try:
            sample = next(sources[index])
        except StopIteration:
            return
        if np.random.uniform() > probabilities[index]:
            continue
        yield sample


# Final pipeline: weighted random mix, optional global shuffle, tar output.
source = sampling()

if args.shuffle > 0:
    source = proc.ishuffle(source, args.shuffle)

sink = writer.TarWriter(args.output, keep_meta=True, output_mode=args.output_mode)

# Close the sink even when reading or writing fails part-way through.
try:
    # n counts every sample taken from the mix, including skipped ones, so
    # --count bounds the samples read and --skip drops the first of those.
    n = 0
    for sample in source:
        if n >= args.count:
            break
        if n >= args.skip:
            sink.write(sample)
        n += 1

    # The EOF marker is only sent after a complete, successful run.
    if args.eof:
        sink.send_eof()
finally:
    sink.close()
