#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import io
import re
import select
import sys
import time

import numpy as np

from tarproclib import reader


def input_with_timeout(prompt="", timeout=1e9):
    """Read one line from standard input, giving up after `timeout` seconds.

    The prompt is written to standard output without a trailing newline and
    flushed immediately. Returns the line with trailing newline characters
    removed, or None when stdin did not become readable in time.
    """
    print(prompt, end="", flush=True)
    readable = select.select([sys.stdin], [], [], timeout)[0]
    if not readable:
        return None
    # stdin is expected to be line-buffered, so a complete line is available
    line = sys.stdin.readline()
    return line.rstrip('\n')


# Command-line interface. The summary is passed as `description=`: the first
# positional parameter of ArgumentParser is `prog`, which would otherwise
# put the sentence into the usage line in place of the program name.
parser = argparse.ArgumentParser(description="Show data inside a tar file.")
parser.add_argument("-f", "--field", default=None,
                    help="field to be viewed")
parser.add_argument("-c", "--count", type=int, default=10000000000,
                    help="number of records to display")
parser.add_argument("-N", "--normalize", action="store_true",
                    help="normalize images before display")
parser.add_argument("-C", "--cmap", default="gray",
                    help="color map for images")
# The delay is used as a timeout in seconds while waiting for a mouse click
# or a line of keyboard input before moving on to the next record.
parser.add_argument("-d", "--delay", type=float, default=100000,
                    help="delay (in seconds) between displayed records")
parser.add_argument("--silent", action="store_true",
                    help="less output")
# Without this flag, record keys are lowercased before they are used.
parser.add_argument("--verbatim-keys", action="store_true",
                    help="compare keys verbatim "
                         "(rather than case insensitively)")
parser.add_argument("--use-keyboard", action="store_true",
                    help="use the keyboard rather than the mouse for input")
parser.add_argument("input", default="-", nargs="?", help="tar file")
args = parser.parse_args()

# The plotting and image libraries are only imported when a field is
# actually going to be displayed.
if args.field is not None:
    import matplotlib.pylab as plt
    # `import PIL` alone does not load the Image submodule, so
    # `PIL.Image.open` would only work if some other module had already
    # imported it. Importing the submodule explicitly still binds `PIL`.
    import PIL.Image
    plt.ion()
    # Whitespace separates the fields to show side by side; within one
    # field, "," or ";" separates alternative key names for that field.
    fields = [re.split("[,;]", f) for f in args.field.split()]
else:
    fields = None

# Destination for the textual record listing.
output = sys.stdout

# Sample keys are lowercased unless --verbatim-keys is given, so the
# requested field names must be lowercased as well or they can never match.
if fields is not None and not args.verbatim_keys:
    wanted = [[name.lower() for name in alternatives]
              for alternatives in fields]
else:
    wanted = fields

# Walk the records of the tar file, printing each one and, when fields were
# requested, showing the corresponding images in a single row of subplots.
for i, sample in enumerate(reader.TarIterator(args.input)):
    if i >= args.count:
        break
    try:
        if not args.verbatim_keys:
            sample = {k.lower(): v for k, v in sample.items()}
        if not args.silent:
            # One line per key, with the value truncated to 60 characters.
            for k, v in sorted(sample.items()):
                print(f"{k:20s}\t{str(v)[:60]}", file=output)
        if wanted is not None:
            plt.clf()
            # `col` is kept distinct from the record index `i`.
            for col, alternatives in enumerate(wanted):
                # The last alternative present in the sample wins.
                image = None
                for name in alternatives:
                    if name in sample:
                        image = sample[name]
                if image is None:
                    continue
                image = PIL.Image.open(io.BytesIO(image))
                if not args.silent:
                    k = "__image__"
                    print(f"{k:20s}\t{image}", file=output)
                if len(wanted) > 1:
                    plt.subplot(1, len(wanted), col + 1)
                image = np.asarray(image)
                if args.normalize:
                    # Shift the minimum to 0 and scale the maximum to 1.
                    image = image.astype(float)
                    image -= np.amin(image)
                    peak = np.amax(image)
                    # A constant image has peak 0 after the shift; skip
                    # the division instead of filling the image with NaN.
                    if peak > 0:
                        image /= peak
                if image.ndim == 2:
                    # Attribute lookup instead of eval(): the color map
                    # name comes straight from the command line and must
                    # not be executed as code.
                    cmap = getattr(plt.cm, args.cmap)
                    plt.imshow(image, cmap)
                else:
                    plt.imshow(image)
            plt.show()
            if not args.use_keyboard:
                # Mouse mode: wait for a click, at most `delay` seconds.
                plt.waitforbuttonpress(timeout=max(1e-3, args.delay))
            else:
                # Keyboard mode: only give the figure a moment to redraw;
                # the actual wait happens on stdin below.
                plt.waitforbuttonpress(timeout=0.001)
        if not args.silent:
            print(file=output)
        if args.use_keyboard:
            input_with_timeout(timeout=args.delay)
    except Exception as e:
        # Best effort: report the problem, pause briefly, and carry on
        # with the next record rather than aborting the whole run.
        print(e)
        time.sleep(1)
