#!python

import argparse
import sys
import time
from itertools import islice
from math import inf

import numpy as np
import tensorcom

# NOTE: the text must be passed as ``description=``; the first positional
# parameter of ArgumentParser is ``prog``, which would put this whole blurb
# into the usage line as the program name.
parser = argparse.ArgumentParser(
    description="""
Compute statistics over a tensor in a tensorcom input stream.

Each item in a tensorcom input stream is usually a list of
tensors, each representing a batch.
"""
)
parser.add_argument(
    "input",
    nargs="*",
    help="source URL(s) to read from (default: zsub://127.0.0.1:7880)",
)
parser.add_argument(
    "-b",
    "--unbatched",
    action="store_true",
    help="accepted on the command line; not used by this script",
)
parser.add_argument(
    "-c",
    "--count",
    type=int,
    default=20,
    help="number of items to read from the stream (default: 20)",
)
parser.add_argument(
    "-r",
    "--raw",
    action="store_true",
    help="open the connection with raw=True and print only the source "
    "statistics, skipping per-tensor statistics",
)
args = parser.parse_args()

# Fall back to the default local endpoint when no source was given.
if not args.input:
    args.input = ["zsub://127.0.0.1:7880"]

source = tensorcom.Connection(args.input, device=None, raw=args.raw)

print("reading batches...\n")


class Stats:
    """Running summary statistics over a sequence of arrays.

    Arrays are folded in with ``stats += array``. The object tracks how many
    arrays were added, the overall minimum and maximum, and the sums needed
    to report the mean and standard deviation over all elements seen.

    Attributes:
        count: number of arrays added so far.
        lo, hi: smallest / largest element seen (``inf`` / ``-inf`` initially).
        sx, sx2: sum of elements and sum of squared elements.
        n: total number of elements seen.
    """

    def __init__(self):
        self.count = 0
        self.lo = inf
        self.hi = -inf
        self.sx = 0.0
        self.sx2 = 0.0
        self.n = 0

    def __iadd__(self, x):
        """Fold the array-like ``x`` into the running statistics.

        Returns ``self`` so that ``stats += x`` rebinds to the same object.
        """
        # Accumulate in float64 regardless of the input dtype: summing (and
        # especially squaring) in a narrower dtype loses precision and can
        # overflow, which corrupts the mean/std reported by summary().
        values = np.asarray(x, dtype=np.float64)
        self.count += 1
        if values.size == 0:
            # np.amin/np.amax raise ValueError on zero-size input; an empty
            # array still counts as "seen" but contributes no elements.
            return self
        self.lo = min(self.lo, float(values.min()))
        self.hi = max(self.hi, float(values.max()))
        self.sx += float(values.sum())
        self.sx2 += float(np.sum(values * values))
        self.n += int(values.size)
        return self

    def summary(self):
        """Return a one-line text summary: count, [lo hi], mean, std, n.

        When no elements have been accumulated, mean and std are reported
        as ``nan`` instead of raising ZeroDivisionError.
        """
        if self.n > 0:
            mean = self.sx / self.n
            # E[x^2] - E[x]^2 can come out marginally negative from rounding;
            # clamp so the square root stays real.
            variance = max(self.sx2 / self.n - mean * mean, 0.0)
            std = variance ** 0.5
        else:
            mean = std = float("nan")
        return "{:d} [{:.3g} {:.3g}] mean={:.3g} std={:.3g} n={:d}".format(
            self.count,
            self.lo,
            self.hi,
            mean,
            std,
            self.n,
        )


# Per-position accumulators: shapes[k] is the set of (dtype, *shape) tuples
# seen at position k of a batch, stats[k] the running statistics for that
# position. Both lists grow on demand so batches of any length are handled.
shapes = []
stats = []
ninputs = 0

# start/finish bracket the read loop; they are not reported below.
start = time.time()
# islice stops after exactly args.count items. An enumerate()+break loop
# would first have to pull one extra item from the stream before it could
# notice the limit, which can block on a live source.
for batch in islice(source.items(), max(args.count, 0)):
    if args.raw:
        # In raw mode only the source's own statistics are reported.
        continue
    ninputs = max(ninputs, len(batch))
    while len(stats) < ninputs:
        stats.append(Stats())
        shapes.append(set())
    for slot, a in enumerate(batch):
        if not isinstance(a, np.ndarray):
            # Non-array entries keep count == 0 and are reported as such.
            continue
        shapes[slot].add((str(a.dtype),) + tuple(a.shape))
        stats[slot] += a.astype(np.float32)

finish = time.time()

if args.raw:
    print(source.stats.summary())
    sys.exit(0)

print("Source:")
print(source.stats.summary())
print()

for i in range(ninputs):
    print("=== Input {} ===\n".format(i))
    if stats[i].count == 0:
        print("not a tensor")
    else:
        print(stats[i].summary())
        print(shapes[i])
    print()
