#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import sys

from tarproclib import reader

# Command-line interface.  The summary sentence is passed as ``description``:
# the first positional parameter of ArgumentParser is ``prog``, so passing the
# sentence positionally would make it appear as the program name in the usage
# line and leave the help text without a description.
parser = argparse.ArgumentParser(
    description="Extract textual information from tar training files.")
# Whitespace-separated column specs; each spec is a ';'-separated list of
# alternative sample keys (see parse_field_spec and getfirst below).
parser.add_argument("-f", "--fields", default="",
                    help="fields")
# The default is large enough that, in practice, no record limit applies.
parser.add_argument("-c", "--count", type=int, default=10000000000,
                    help="maximum number of output records")
parser.add_argument("--nokey", action="store_true",
                    help="don't output the key field")
parser.add_argument("--noheader", action="store_true",
                    help="don't output a header")
# Handed unchanged to reader.TarIterator; "-" presumably selects standard
# input -- TODO confirm against tarproclib.
parser.add_argument("input", default="-", nargs="?")
args = parser.parse_args()


def dprint(*args, **kw):
    """Print like the built-in ``print``, but to standard error.

    Positional and keyword arguments are forwarded to ``print`` unchanged.
    ``file`` is already supplied here, so passing ``file=`` raises TypeError.
    """
    stream = sys.stderr
    print(*args, file=stream, **kw)


def getfirst(a, keys, default=None):
    """Return the first non-None value found in ``a`` under any of ``keys``.

    ``a`` is any object with a ``get`` method (e.g. a dict).  ``keys`` is
    either an iterable of keys or a single string of ``;``-separated keys.
    Keys are tried in order; ``default`` is returned when every lookup
    yields None.
    """
    candidates = keys.split(";") if isinstance(keys, str) else keys
    looked_up = (a.get(name) for name in candidates)
    # A stored None counts as "missing", so skip it and try the next key.
    return next((value for value in looked_up if value is not None), default)


def parse_field_spec(fields):
    """Turn a field specification into a list of alternative-key lists.

    ``fields`` is either a string of whitespace-separated specs or an
    iterable of spec strings.  Every spec is split on ``;``, so
    ``"png;jpg cls"`` becomes ``[["png", "jpg"], ["cls"]]``.
    """
    specs = fields.split() if isinstance(fields, str) else fields
    parsed = []
    for spec in specs:
        parsed.append(spec.split(";"))
    return parsed


# Output columns, one list of alternative keys per column.  Unless --nokey
# was given, the sample key ("__key__") is the first column.
key_column = [] if args.nokey else [["__key__"]]
fields = key_column + parse_field_spec(args.fields)


def unicode(s):
    """Convert ``s`` to text.

    ``bytes`` are decoded as UTF-8, ``str`` values are returned unchanged,
    and anything else (including None) goes through ``str()``.
    """
    if isinstance(s, bytes):
        return s.decode("utf-8")
    return s if isinstance(s, str) else str(s)


# Header row: label each column with the first alternative of its field spec.
if not args.noheader:
    print("\t".join(alternatives[0] for alternatives in fields))

# One tab-separated row per sample, stopping once args.count rows were written.
for i, sample in enumerate(reader.TarIterator(args.input)):
    if not i < args.count:
        break
    assert isinstance(sample, dict), sample
    result = []
    for k in fields:
        # A column with no matching key yields None, which prints as "None".
        result.append(unicode(getfirst(sample, k)))
    print("\t".join(result))
