#!python

import argparse
import logging
import multiprocessing
import os
import sys

import numpy as np
from tensorcom import zcom
import torch
from torchvision import datasets, transforms

# Install a handler on the root logger so the INFO-level progress messages
# in this script are actually emitted; with no handler configured, logging
# falls back to a last-resort handler that only shows WARNING and above.
# basicConfig is a no-op if a handler has already been installed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# The text is passed as `description` (the first positional parameter of
# ArgumentParser is `prog`), and the raw formatter keeps its line breaks.
parser = argparse.ArgumentParser(
    description="""
Serve the Imagenet dataset for training.

By default, data is served as tuples (img, cls), where
img is a (batch, h, w, channel) array of type uint8
and cls is a (batch,) array of type int32.

The batch size can be adjusted using the `-b` argument.

This program uses the PyTorch data loader, but still
uses NumPy conventions for actually serving the data.
""",
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "service_address",
    nargs="*",
    help="address(es) to serve on (default: zpub://127.0.0.1:7880)",
)
parser.add_argument(
    "-d",
    "--dir",
    default="./imagenet",
    help="directory containing the ImageNet dataset",
)
parser.add_argument(
    "-b",
    "--batch-size",
    type=int,
    default=32,
    help="number of samples per served batch (default: 32)",
)
parser.add_argument(
    "-r", "--report", type=int, default=10, help="report on progress this frequently"
)
# NOTE(review): --benchmark is parsed but not referenced by any code visible
# in this file.
parser.add_argument(
    "-B",
    "--benchmark",
    action="store_true",
    help="eliminate I/O overhead by just preloading and serving one sample",
)
parser.add_argument(
    "-w", "--workers", type=int, default=0, help="number of DataLoader workers"
)
parser.add_argument(
    "-P",
    "--parallel",
    type=int,
    default=0,
    help="spawn multiple subprocesses for parallel I/O",
)
# store_true so that `args.no_shuffle` is False unless -S is given; the
# data loader is created with `shuffle=not args.no_shuffle`, so shuffling
# is on by default and -S turns it off.
parser.add_argument(
    "-S",
    "--no-shuffle",
    action="store_true",
    help="serve the samples in dataset order instead of shuffling them",
)
parser.add_argument(
    "-n",
    "--normalize",
    action="store_true",
    help="apply per-channel mean/std normalization (serves float images)",
)
args = parser.parse_args()

if not args.service_address:
    args.service_address = ["zpub://127.0.0.1:7880"]

if args.parallel > 0:
    # Parallel mode replicates a single address once per subprocess.
    # Validate with parser.error rather than assert, which is stripped
    # when Python runs with -O.
    if len(args.service_address) != 1:
        parser.error("--parallel requires exactly one service address")
    if not args.service_address[0].startswith(("zpush", "zrpub")):
        parser.error("--parallel requires a zpush or zrpub service address")
    args.service_address = args.service_address * args.parallel

logger.info("service: %s", args.service_address)


def fixtype(a):
    """Narrow 64-bit NumPy arrays to their 32-bit counterparts.

    An ``np.ndarray`` whose dtype equals ``int64`` is returned as an
    ``int32`` copy, and one whose dtype equals ``float64`` as a ``float32``
    copy. Every other value (scalars, strings, arrays of other dtypes, any
    other object) is returned unchanged.
    """
    # Only arrays are ever converted; everything else passes straight through.
    if not isinstance(a, np.ndarray):
        return a
    conversions = ((np.int64, np.int32), (np.float64, np.float32))
    for wide, narrow in conversions:
        if a.dtype == wide:
            return a.astype(narrow)
    return a


def _to_uint8(x):
    """Multiply an image tensor by 255 and cast it to ``torch.uint8``."""
    return (255 * x).type(torch.uint8)


def start_server(con, report=args.report):
    """Serve batches of the training split over a tensorcom connection.

    Loads ``<args.dir>/train`` with ``torchvision.datasets.ImageFolder``,
    applies a random resized crop to 224 and a random horizontal flip, and
    sends every batch produced by the data loader as ``[img, cls]`` NumPy
    arrays, with ``img`` permuted so that the channel axis comes last.

    Args:
        con: Address passed to ``zcom.Connection``.
        report: Print the connection statistics every ``report`` batches;
            a value of zero or less disables the periodic report.
    """
    logger.info("starting server")

    serve = zcom.Connection(con)

    logger.info("loading dataset")

    traindir = os.path.join(args.dir, "train")

    # Both modes share the augmentation pipeline and differ only in the
    # last step: mean/std normalization (float output) or conversion to
    # uint8. A named function is used instead of a lambda so the transform
    # can be pickled when DataLoader workers are started by spawning.
    if args.normalize:
        final_step = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    else:
        final_step = _to_uint8

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                final_step,
            ]
        ),
    )

    logger.info("creating dataloader")

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=not args.no_shuffle,
        num_workers=args.workers,
        pin_memory=False,
    )

    for i, (img, cls) in enumerate(train_loader):
        # Guard against report <= 0, which would otherwise raise
        # ZeroDivisionError on the modulo.
        if report > 0 and i % report == 0:
            print(i, serve.stats.summary(), flush=True)
        # Move the channel axis last before handing the batch to NumPy.
        img = img.permute(0, 2, 3, 1).numpy()
        # NOTE(review): ImageFolder assigns class indices starting at 0, so
        # this offset yields labels starting at -1 -- confirm that consumers
        # expect it.
        cls = cls.type(torch.int32).numpy() - 1
        if i == 0:
            # One-time sanity print of the shape, dtype and value range.
            print(img.shape, img.dtype, np.amin(img), np.amax(img))
            print(cls.shape, cls.dtype)
        serve.send([img, cls])


# Guarded so that worker processes started with the "spawn" method, which
# re-import this file, do not launch servers of their own.
if __name__ == "__main__":
    if len(args.service_address) == 1:
        # Single address: serve directly in this process.
        start_server(args.service_address[0])
    else:
        # One worker process per address; the context manager releases the
        # pool once map() has returned.
        nproc = len(args.service_address)
        with multiprocessing.Pool(nproc) as pool:
            print(pool)
            pool.map(start_server, args.service_address)
