#!python

import argparse
import time
import tensorcom

# Command-line interface: zero or more input endpoints plus reporting options.
parser = argparse.ArgumentParser("show tensor inputs")
parser.add_argument(
    "input",
    nargs="*",
)
parser.add_argument(
    "-R",
    "--raw",
    action="store_true",
)
parser.add_argument(
    "-r",
    "--report",
    type=int,
    default=10,
)
parser.add_argument(
    "-c",
    "--count",
    type=int,
    default=999999999,
)
args = parser.parse_args()

# With no positional inputs given, fall back to a single local endpoint.
args.input = args.input or ["zsub://127.0.0.1:7880"]


def make_source():
    """Build a tensorcom connection attached to every configured input.

    Prints the full input list, then each entry right before connecting
    to it. The connection is created with ``device=None`` and the ``raw``
    flag taken from the command line.

    Returns:
        The ``tensorcom.Connection`` after ``connect`` has been called for
        each entry of ``args.input``.
    """
    endpoints = args.input
    print("input:", endpoints)
    connection = tensorcom.Connection(device=None, raw=args.raw)
    for endpoint in endpoints:
        print(endpoint)
        connection.connect(endpoint)
    return connection


# Number of batches received so far, across reconnects.
index = 0
# Number of batches received since the last throughput report.
total = 0

# Open a fresh connection whenever the item stream ends, and stop once
# --count batches have been received. Checking before connecting means
# a count of zero (or less) never opens a connection at all.
while index < args.count:
    source = make_source()
    for batch in source.items():
        if index == 0:
            # The very first batch starts the measurement clock.
            print("connected")
            last = time.monotonic()
        index += 1
        total += 1
        bs = len(batch[0])
        # A non-positive --report disables reporting instead of raising
        # ZeroDivisionError in the modulo below.
        if args.report > 0 and index % args.report == 0:
            # Monotonic clock: interval measurements must not be affected
            # by wall-clock adjustments.
            now = time.monotonic()
            # Guard against a zero-length interval on coarse clocks.
            delta = max(now - last, 1e-9)
            rate = total / delta
            # samples/s is estimated from the size of the most recent
            # batch only.
            print(
                "{:20d} {:8.3f} batches/s {:8.3f} samples/s (batchsize: {:d})".format(
                    index, rate, rate * bs, bs
                )
            )
            total = 0
            last = now
        # index was already incremented for this batch, so >= stops after
        # exactly --count batches (the previous > consumed one extra).
        if index >= args.count:
            break
