#!python

import argparse
import multiprocessing
import sys

import numpy as np
from tensorcom import zcom
from webdataset import WebDataset
from torch.utils.data import DataLoader

# Command-line interface.
#
# The long help text is passed as ``description``: the first positional
# parameter of ArgumentParser is ``prog``, which would splice the whole text
# into the usage line and into every error message.
parser = argparse.ArgumentParser(
    description="""
Serve the Imagenet dataset for training.

By default, data is served as tuples (img, cls), where
img is a (batch, h, w, channel) array of type uint8
and cls is a (batch,) array of type int32.

The batch size can be adjusted using the `-b` argument.

This program reads sharded tar files from any URL.
""",
    # Keep the paragraph breaks of the description exactly as written.
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Zero or more addresses to serve on; an empty list is replaced by a default
# after parsing.
parser.add_argument("service_address", nargs="*")
parser.add_argument(
    "-u",
    "--url",
    default="http://storage.googleapis.com/lpr-imagenet-augmented/imagenet_train-{0000..0147}-{000..019}.tgz",
    help="source shard URL(s); defaults to Imagenet shards in a Google Cloud Storage bucket",
)
parser.add_argument("-b", "--batch-size", type=int, default=32, help="batch the input")
parser.add_argument(
    "-r", "--report", type=int, default=10, help="report on progress this frequently"
)
parser.add_argument(
    "-B",
    "--benchmark",
    action="store_true",
    help="eliminate I/O overhead by just preloading and serving one sample",
)
parser.add_argument(
    "-p",
    "--parallel",
    type=int,
    default=0,
    help="spawn multiple subprocesses for parallel I/O",
)
parser.add_argument("-S", "--shuffle", action="store_true", help="shuffle the data")
parser.add_argument(
    "-N", "--num-workers", type=int, default=0, help="num_workers for DataLoader"
)
args = parser.parse_args()

# Validate with parser.error() rather than assert: assertions are stripped
# under ``python -O`` and would otherwise surface as bare tracebacks instead
# of a usage message.
if not 0 < args.batch_size < 100000:
    parser.error(
        "batch size must be between 1 and 99999, got {}".format(args.batch_size)
    )


# No address given on the command line: publish on a local socket.
if not args.service_address:
    args.service_address = ["zpub://127.0.0.1:7880"]


if args.parallel > 0:
    # Parallel serving replicates one address across all worker processes and
    # is only accepted for the zpush / zrpub connection schemes.
    if len(args.service_address) != 1:
        parser.error("--parallel requires exactly one service address")
    if not args.service_address[0].startswith(("zpush", "zrpub")):
        parser.error("--parallel requires a zpush or zrpub service address")
    # One list entry per worker process to be started.
    args.service_address = args.service_address * args.parallel


def fixtype(a):
    """Narrow 64-bit numpy arrays to their 32-bit counterparts.

    int64 arrays are converted to int32 and float64 arrays to float32.
    Anything else (non-array values and arrays of other dtypes) is returned
    unchanged, as the very same object.
    """
    if not isinstance(a, np.ndarray):
        return a
    conversions = ((np.int64, np.int32), (np.float64, np.float32))
    for wide, narrow in conversions:
        if a.dtype == wide:
            return a.astype(narrow)
    return a


def infinite(source):
    """Yield the samples of `source` over and over again.

    `source` is re-iterated from the start after each complete pass. If a
    pass produces no sample at all (an empty source, or a one-shot iterator
    that is already exhausted), the generator stops instead of spinning
    forever without yielding; this mirrors `itertools.cycle` on an empty
    iterable.
    """
    while True:
        produced = False
        for sample in source:
            produced = True
            yield sample
        if not produced:
            # Nothing came out of this pass, so another pass would only
            # busy-loop.
            return


def start_server(con, report=args.report, benchmark=args.benchmark):
    """Serve batches of ``[img, cls]`` arrays on `con`; does not return normally.

    Samples are read from the shards at `args.url`, batched by a DataLoader
    and sent with `img` cast to uint8 and `cls` cast to int32.

    Args:
        con: address passed to `zcom.Connection`.
        report: print the connection statistics every `report` batches;
            values <= 0 disable the progress report.
        benchmark: when true, load one batch up front and send that same
            batch repeatedly, so that no further input I/O takes place.

    Raises:
        RuntimeError: in benchmark mode, if the data source yields no batch.
    """
    print("serving {}".format(con))
    serve = zcom.Connection(con)
    dataset = WebDataset(
        args.url,
        extensions="png;jpg;jpeg;ppm cls",
        decoder="rgb8",
        shuffle=int(args.shuffle > 0),
    )
    # NOTE(review): torch's DataLoader rejects shuffle=True for iterable-style
    # datasets; confirm that --shuffle works with the WebDataset version in use.
    source = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
    )

    def as_arrays(img, cls):
        # Both modes send the same dtypes: uint8 images, int32 classes.
        return img.numpy().astype(np.uint8), cls.numpy().astype(np.int32)

    def progress(count):
        # Guard against report <= 0, which would otherwise divide by zero.
        if report > 0 and count % report == 0:
            print(count, serve.stats.summary())
            sys.stdout.flush()

    if not benchmark:
        for i, (img, cls) in enumerate(infinite(source)):
            progress(i)
            img, cls = as_arrays(img, cls)
            serve.send([img, cls])
    else:
        first = next(iter(source), None)
        if first is None:
            raise RuntimeError("no data available from {}".format(args.url))
        # Convert exactly once: the tensors must not be converted again on
        # every send, since the results are already numpy arrays.
        img, cls = as_arrays(*first)
        i = 0
        while True:
            progress(i)
            serve.send([img, cls])
            i += 1


# Script entry point. The guard matters for the multiprocessing "spawn" start
# method, where every worker re-imports this module: without it each worker
# would run the dispatch below again instead of just defining start_server.
if __name__ == "__main__":
    if len(args.service_address) == 1:
        # Single address: serve in this process.
        start_server(args.service_address[0])
    else:
        # One worker process per address; map() blocks while the servers run.
        # NOTE(review): Pool workers are daemonic and cannot start child
        # processes, so combining this path with --num-workers > 0 is expected
        # to fail inside DataLoader; confirm before relying on it.
        nproc = len(args.service_address)
        with multiprocessing.Pool(nproc) as pool:
            print(pool)
            pool.map(start_server, args.service_address)
