#!python

import argparse
import sys

import matplotlib.pyplot as plt
import numpy as np
import simplejson
import tensorcom
from matplotlib import cm

# Command-line interface. Positional arguments are the endpoints to connect to.
parser = argparse.ArgumentParser("show tensor inputs")
parser.add_argument("input", nargs="*")
# -b: each received item is a single sample rather than a batch.
parser.add_argument(
    "-b",
    "--unbatched",
    action="store_true",
)
# -m: display mode; the script handles "imgclass" and "img2img".
parser.add_argument(
    "-m",
    "--mode",
    default="imgclass",
)
# -B: upper bound on how many items of each batch are displayed.
parser.add_argument(
    "-B",
    "--perbatch",
    type=int,
    default=1,
)
# -C: path of a JSON file that is loaded into `classtable` below.
parser.add_argument(
    "-C",
    "--classes",
    default=None,
)
# -t: timeout handed to plt.waitforbuttonpress between displayed items.
parser.add_argument(
    "-t",
    "--timeout",
    type=float,
    default=-1,
)
args = parser.parse_args()

# Without explicit endpoints, fall back to four local zsub ports (7880-7883).
if not args.input:
    args.input = ["zsub://127.0.0.1:%d" % port for port in range(7880, 7884)]

# Table used to translate class labels for display; empty when -C is absent.
if args.classes is None:
    classtable = {}
else:
    with open(args.classes) as handle:
        classtable = simplejson.load(handle)


def isimage(a, batched=True, minsize=24):
    """Return True if `a` looks like an image (or a batch of images).

    An array counts as an image when at least two of its dimensions are
    `minsize` or larger.

    Args:
        a: Object to test; anything that is not a numpy array is rejected.
        batched: If true, the leading axis is treated as the batch axis and
            only the remaining axes are examined.
        minsize: Minimum extent a dimension needs to count as spatial.

    Returns:
        bool: True when the (per-item) shape has two or more dimensions of
        size `minsize` or more, False otherwise. Empty batches and 0-d
        arrays are reported as False instead of raising.
    """
    if not isinstance(a, np.ndarray):
        return False
    if batched:
        # There is no first item to inspect in a 0-d array or an empty batch.
        if a.ndim == 0 or a.shape[0] == 0:
            return False
        # Shape of a[0], obtained without actually indexing the array.
        shape = a.shape[1:]
    else:
        shape = a.shape
    return sum(1 for d in shape if d >= minsize) >= 2


def smartshow(img, batch_index=None, cmap=cm.viridis):
    """Draw an image array on the current matplotlib axes.

    The array is converted to float32. uint8 data is scaled by 1/255; any
    other data whose values fall outside [0, 1] is shifted and scaled so
    that its minimum becomes 0 and its maximum becomes 1.

    Args:
        img: Array holding one image, or a batch of images when
            `batch_index` is given.
        batch_index: If not None, `img[batch_index]` is displayed instead
            of `img` itself.
        cmap: Colormap used for single-channel images.

    Supported layouts after batch selection:
        * 3-D with 3 or 4 leading channels: shown as RGB (first 3 channels).
        * 3-D with 1 leading channel: shown with `cmap`.
        * 3-D with 3 or 4 trailing channels: shown as RGB (first 3 channels).
        * 3-D with 1 trailing channel: shown with `cmap`.
        * 2-D: shown with `cmap`.
    Arrays of any other layout are not drawn.
    """
    if batch_index is not None:
        img = img[batch_index]
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0
    # astype returns a copy, so the in-place rescaling below never touches
    # the caller's array.
    img = img.astype(np.float32)
    if np.amin(img) < 0 or np.amax(img) > 1:
        img -= np.amin(img)
        peak = np.amax(img)
        # A constant image is all zeros after the shift; dividing by its
        # maximum would turn every pixel into NaN.
        if peak > 0:
            img /= peak
    if img.ndim == 3:
        if img.shape[0] in [3, 4]:
            # Channels-first colour image: move channels last for imshow.
            plt.imshow(img.transpose(1, 2, 0)[..., :3])
        elif img.shape[0] == 1:
            plt.imshow(img[0], cmap=cmap)
        elif img.shape[-1] in [3, 4]:
            plt.imshow(img[..., :3])
        elif img.shape[-1] == 1:
            plt.imshow(img[..., 0], cmap=cmap)
    elif img.ndim == 2:
        plt.imshow(img, cmap=cmap)


def info(x):
    """Print a one-line summary of `x`.

    For a numpy array the line holds dtype, shape, minimum and maximum.
    For anything else it holds the first 50 characters of `str(x)`
    followed by a space.
    """
    if not isinstance(x, np.ndarray):
        text = str(x)
        print("{} ".format(text[:50]))
        return
    summary = (x.dtype, x.shape, np.amin(x), np.amax(x))
    print("{} {} {} {}".format(*summary))


# Interactive mode keeps plt.show() from blocking, so the loops below can
# advance on each button press / timeout.
plt.ion()
source = tensorcom.Connection(device=None)
for c in args.input:
    print(c)
    source.connect(c)

if args.mode == "imgclass":
    # Each received item is an (image, class label) pair.
    for img, cls in source.items():
        print(img.shape, img.dtype, np.amin(img), np.amax(img))
        print(cls.shape, cls.dtype, np.amin(cls), np.amax(cls))
        if args.unbatched:
            # Give the single sample a batch axis so the loop below can
            # index it like a batch.
            img = np.array([img])
            # The label must be indexable the same way: a 0-d label becomes
            # a 1-element array, an already 1-d label is left unchanged.
            cls = np.atleast_1d(cls)
        for index in range(min(len(img), args.perbatch)):
            plt.clf()
            smartshow(img, index)
            c = cls[index]
            # A table loaded from JSON has string keys, so try the label's
            # string form first, then the raw label, and finally fall back
            # to showing the label itself.
            c = classtable.get(str(c), classtable.get(c, c))
            plt.title("[{}] class = {}".format(index, c))
            plt.show()
            plt.waitforbuttonpress(timeout=args.timeout)
            # Quit once the user has closed the figure window.
            if not plt.fignum_exists(1):
                sys.exit(0)
elif args.mode == "img2img":
    # Each received item is a pair of images, shown side by side.
    for img, img2 in source.items():
        print(img.shape, img.dtype, np.amin(img), np.amax(img))
        print(img2.shape, img2.dtype, np.amin(img2), np.amax(img2))
        if args.unbatched:
            img = np.array([img])
            img2 = np.array([img2])
        for index in range(min(len(img), args.perbatch)):
            plt.clf()
            plt.subplot(121)
            smartshow(img, index)
            plt.subplot(122)
            smartshow(img2, index)
            plt.title("[{}]".format(index))
            plt.show()
            plt.waitforbuttonpress(timeout=args.timeout)
            # Quit once the user has closed the figure window.
            if not plt.fignum_exists(1):
                sys.exit(0)

else:
    print("{}: unknown mode".format(args.mode))
    sys.exit(1)
