#!/usr/bin/python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#

import os
import pickle
import argparse
import sqlite3
import sys
import tarfile

from tarproclib import reader, writer

# Command-line interface.
# NOTE: the description must be passed by keyword; the first positional
# parameter of ArgumentParser is `prog`, so passing the sentence positionally
# made it show up as the program name in the usage line.
parser = argparse.ArgumentParser(
    description="Sort the samples inside a tar file.")
parser.add_argument(
    "-k", "--key", default="__key__",
    help="sample field recorded as the row key in the temporary database")
parser.add_argument(
    "-s", "--sortkey", default="__key__",
    help="sample field whose value determines the output order")
parser.add_argument(
    "-S", "--sorttype", default=None,
    help="sort key conversion: int, float, or shuffle (order by the MD5 "
         "digest of the key); by default the value is used as-is")
parser.add_argument(
    "-r", "--report", default=0, type=int,
    help="print progress to stderr every N samples (0 disables)")
parser.add_argument(
    "-t", "--tempfile", default="_tarsort-{pid}.db",
    help="path of the temporary SQLite database; {pid} is replaced by the "
         "process id")
parser.add_argument(
    "-o", "--output", default="-",
    help="output tar file, or - for standard output")
parser.add_argument(
    "--update", action="store_true",
    help="replace an earlier sample that has the same sort key "
         "(insert or replace) instead of using a plain insert")
parser.add_argument(
    "--keep", action="store_true",
    help="do not delete the temporary database when done")
parser.add_argument(
    "--commit", default=1000, type=int,
    help="commit to the temporary database every N samples")
parser.add_argument(
    "input", default="-", nargs="?",
    help="input tar source handed to the tar reader (default: -)")
args = parser.parse_args()

# Select the Python-side converter applied to each sort key (`sorttype`) and
# the SQLite column type declared for the sortkey column (`dbtype`).
if args.sorttype is None:
    def sorttype(x):
        """Return the sort key unchanged."""
        return x
    dbtype = "text"
elif args.sorttype == "int":
    sorttype = int
    dbtype = "integer"
elif args.sorttype == "float":
    sorttype = float
    dbtype = "real"
elif args.sorttype == "shuffle":
    # Imported once here rather than on every call of the helper.
    import hashlib

    def _md5_hexdigest(s):
        """Return the hex MD5 digest of *s* (str is UTF-8 encoded first).

        Ordering rows by this digest scrambles them relative to the original
        keys while staying deterministic for a given set of keys.
        """
        if isinstance(s, str):
            s = s.encode("utf-8")
        return hashlib.md5(s).hexdigest()

    # Named so that it does not shadow the builtin hash() at module scope.
    sorttype = _md5_hexdigest
    dbtype = "text"
else:
    sys.exit(f"{args.sorttype}: unknown sort type")


def dprint(*args, **kw):
    """Print to standard error.

    Positional and keyword arguments are forwarded to print(); the output
    stream is always sys.stderr, so callers must not pass ``file``.
    """
    err_stream = sys.stderr
    print(*args, file=err_stream, **kw)


# Resolve the scratch database path; "{pid}" in the template is replaced by
# this process's id.
tempfile = args.tempfile.format(pid=os.getpid())
# Refuse to reuse an existing file. This is an explicit check rather than an
# `assert` so that it still runs under `python -O`, where asserts are
# stripped and the script would otherwise write into the existing file.
if os.path.exists(tempfile):
    sys.exit(f"{tempfile}: temporary file already exists")
db = sqlite3.connect(tempfile)
# One row per sample: the converted sort key (primary key, so it must be
# unique), the sample's key, and the pickled sample itself.
db.execute(f"""
create table tarsort (
    sortkey {dbtype} not null primary key,
    key text not null,
    sample blob
)
""")

# Load every sample from the input into the scratch table together with its
# converted sort key; the ordering itself is done by SQLite when reading back.

# The statement is identical for every row, so build it once. With --update,
# a row whose primary key (the sort key) already exists is replaced.
cmd = "insert or replace" if args.update else "insert"
insert_sql = f"{cmd} into tarsort values (?,?,?)"

try:
    for i, sample in enumerate(reader.TarIterator(args.input)):
        if args.report > 0 and i % args.report == 0:
            dprint(">", i, sample.get("__key__"))
        # Samples lacking the sort field are keyed on "" (before conversion).
        sortkey = sorttype(sample.get(args.sortkey, ""))
        # Samples lacking the key field get a positional placeholder key.
        key = sample.get(args.key, "__{}__".format(i))
        db.execute(insert_sql, (sortkey, key, pickle.dumps(sample)))
        # Periodic commit; the `> 0` guard avoids a ZeroDivisionError for
        # --commit 0 (everything is then committed once, after the loop).
        if args.commit > 0 and i % args.commit == 0:
            db.commit()
except tarfile.ReadError as exc:
    # Best effort, as before: keep the samples read so far and carry on,
    # but tell the user instead of hiding the error.
    # NOTE(review): presumably this covers empty/truncated input - confirm.
    dprint(f"tarsort: stopped reading input: {exc}")

db.commit()

# Stream the samples back out in sort-key order, then clean up.
# Bound before the try so the cleanup code can tell whether opening the
# output succeeded.
stream = None
try:
    if args.output == "-":
        stream = sys.stdout.buffer
    else:
        stream = open(args.output, "wb")
    sink = writer.TarWriter(stream)
    cur = db.execute("select sample from tarsort order by sortkey")
    for i, (sample,) in enumerate(cur):
        # The blobs were pickled by this same process while loading, so
        # unpickling them here does not touch foreign data.
        sample = pickle.loads(sample)
        if args.report > 0 and i % args.report == 0:
            dprint("<", i, sample.get("__key__"))
        sink.write(sample)
    sink.close()
finally:
    # Close the output first so that a failure while removing the scratch
    # file cannot leave it unclosed.
    if stream is not None:
        try:
            stream.close()
        except (OSError, ValueError) as exc:
            # Best effort: report the problem but do not mask an exception
            # that may already be propagating.
            dprint(f"tarsort: error closing output: {exc}")
    # Release the database before deleting (or keeping) its file.
    db.close()
    if not args.keep:
        os.unlink(tempfile)
