#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import argparse
import codecs
import csv
import os
import sys

from tarproclib import paths, writer

# Help-text epilog: explains how the plan's column headers are interpreted.
epilog = """
The column headers contain the output filename extensions.  Each column
contains either data or a filename.  Headers starting with "@" denote
that the column contains actual file names.  If there is a __key__ column,
it is used as the key, otherwise records are numbered sequentially.
"""

# The summary must be passed as `description`: the first positional
# parameter of ArgumentParser is `prog`, so passing the sentence positionally
# would replace the program name in the usage line and leave the description
# empty.
parser = argparse.ArgumentParser(
    description="Create a tar file for a csv/tsv plan.",
    epilog=epilog,
)
# An empty delimiter means "not given"; the script later infers it from the
# plan's file extension.
parser.add_argument("-f", "--delim", default="",
                    help="delimiter in csv/tsv file")
# Format string applied with `record=<row index>` when a row has no __key__.
parser.add_argument("-k", "--key", default="{record:09d}",
                    help="output format for record numbers")
parser.add_argument("-C", "--dir", default=None,
                    help="change into directory before executing")
parser.add_argument("-o", "--output", default="-",
                    help="output file (default: stdout)")
parser.add_argument("plan")
# Parse the command line once at module level; the rest of the script reads
# its settings from `args`.
args = parser.parse_args(sys.argv[1:])


def dprint(*items, **options):
    """Print ``items`` to standard error.

    Takes the same keyword arguments as ``print`` except ``file``, which is
    always ``sys.stderr``.
    """
    print(*items, **options, file=sys.stderr)


# No delimiter given on the command line: use a comma for ".csv" plans and a
# tab for everything else.
if not args.delim:
    args.delim = "," if args.plan.endswith(".csv") else "\t"

# First pass: validate the whole plan before the output is opened, so a
# malformed plan or a missing input file is reported before anything is
# written.
# The csv module expects its input file to be opened with newline="", and the
# with-block closes the plan as soon as validation is done.
with open(args.plan, "r", encoding="utf-8", newline="") as stream:
    plan = csv.reader(stream, delimiter=args.delim)
    header = next(plan, None)
    if header is None:
        raise ValueError("{}: plan is empty (no header row)".format(args.plan))
    dprint(header)
    for row in plan:
        # Explicit raises rather than assert: assert statements are removed
        # under `python -O`, which would silently disable the validation.
        if len(row) != len(header):
            raise ValueError(
                "{}:{}: expected {} columns, got {}".format(
                    args.plan, plan.line_num, len(header), len(row)))
        for h, v in zip(header, row):
            # startswith() is safe for an empty header cell, where h[0]
            # would raise IndexError.
            if h.startswith("@"):
                if args.dir:
                    v = os.path.join(args.dir, v)
                if not os.path.exists(v):
                    raise FileNotFoundError(v)

# Choose the binary output stream: "-" selects stdout's underlying byte
# buffer, anything else is opened (and truncated) as a file.
output = sys.stdout.buffer if args.output == "-" else open(args.output, "wb")
# All samples are written through this tar writer.
sink = writer.TarWriter(output)

# Second pass: re-read the plan and write one sample per data row.
# The csv module expects its input file to be opened with newline="", and the
# with-block closes the plan once every row has been consumed.
with open(args.plan, "r", encoding="utf-8", newline="") as stream:
    plan = csv.reader(stream, delimiter=args.delim)
    header = next(plan)
    for record, row in enumerate(plan):
        sample = {}
        for h, v in zip(header, row):
            # startswith() is safe for an empty header cell, where h[0]
            # would raise IndexError.
            if h.startswith("@"):
                # "@ext" column: the cell names a file whose contents become
                # the "ext" entry of the sample.
                if args.dir:
                    v = os.path.join(args.dir, v)
                sample[h[1:]] = paths.read_binary(v)
            else:
                # Plain column: the cell text itself is the entry, as UTF-8.
                sample[h] = v.encode("utf-8")
        # NOTE(review): a "__key__" column value ends up as bytes, while the
        # generated key below is str -- confirm TarWriter accepts both.
        if "__key__" not in sample:
            # No key column: derive the key from the zero-based row index.
            sample["__key__"] = args.key.format(record=record)
        sink.write(sample)
sink.close()
