#!/usr/bin/env python3

"""This utility takes a knowledge core and loads it into a running TrustGraph
through the API.  The knowledge core should be in msgpack format, which is the
default format produced by tg-save-kg-core.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os

def _metadata_json(m):
    """Expand a record's compact metadata dict into the websocket form.

    Records in the knowledge-core file store metadata under one-letter keys
    ("i", "m", "u", "c"); the load/* messages sent below use the long names
    "id", "metadata", "user" and "collection".
    """
    return {
        "id": m["i"],
        "metadata": m["m"],
        "user": m["u"],
        "collection": m["c"],
    }

def _ge_message(record):
    """Build the load/graph-embeddings message for a "ge" record payload."""
    return {
        "metadata": _metadata_json(record["m"]),
        "vectors": record["v"],
        "entity": record["e"],
    }

def _triples_message(record):
    """Build the load/triples message for a "t" record payload."""
    return {
        "metadata": _metadata_json(record["m"]),
        "triples": record["t"],
    }

async def _forward(queue, endpoint, convert):
    """Send convert(record) over a websocket for every record on *queue*.

    Runs until it takes the None end-of-stream marker that loader() puts on
    the queue after the last record, then leaves the ``async with`` blocks,
    closing the websocket and the HTTP session.
    """
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(endpoint) as ws:
            while True:
                record = await queue.get()
                if record is None:
                    # End of stream: everything queued before it was sent.
                    break
                await ws.send_json(convert(record))

async def load_ge(queue, url):
    """Stream graph-embedding records from *queue* to {url}load/graph-embeddings.

    *url* is the API base and must end with "/".  Returns after the None
    end-of-stream marker is received.
    """
    await _forward(queue, f"{url}load/graph-embeddings", _ge_message)

async def load_triples(queue, url):
    """Stream triples records from *queue* to {url}load/triples.

    *url* is the API base and must end with "/".  Returns after the None
    end-of-stream marker is received.
    """
    await _forward(queue, f"{url}load/triples", _triples_message)

# Number of records of each kind queued so far; incremented by loader() and
# printed by stats() / run().
ge_counts = 0
t_counts = 0

def _report():
    """Print the current per-kind record counts."""
    print(
        f"Graph embeddings: {ge_counts:10d}  Triples: {t_counts:10d}"
    )

async def stats():
    """Print the record counts every 5 seconds until cancelled."""
    while True:
        await asyncio.sleep(5)
        _report()

def _override_metadata(payload, user, collection):
    """Apply --user / --collection overrides to a record payload in place.

    The payload's compact metadata lives under "m", with the user at "u" and
    the collection at "c".  A falsy override leaves the stored value as-is.
    Returns *payload* for convenience.
    """
    if user:
        payload["m"]["u"] = user
    if collection:
        payload["m"]["c"] = collection
    return payload

async def loader(ge_queue, t_queue, path, format, user, collection):
    """Read the knowledge core at *path* and queue its records for sending.

    Each msgpack record is a ``[kind, payload]`` pair.  Triples ("t") go to
    *t_queue* and graph embeddings ("ge") to *ge_queue*; other kinds are
    skipped.  Once the whole file has been read, a None end-of-stream marker
    is put on both queues so the senders can finish.

    Raises:
        RuntimeError: if *format* is "json" (not implemented).
    """
    global t_counts
    global ge_counts

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:
        for unpacked in msgpack.Unpacker(f, raw=False):
            kind = unpacked[0]
            if kind not in ("t", "ge"):
                continue

            # Overrides target the record's own metadata, not the list
            # wrapper (indexing the list with a string raised TypeError).
            payload = _override_metadata(unpacked[1], user, collection)

            if kind == "t":
                await t_queue.put(payload)
                t_counts += 1
            else:
                await ge_queue.put(payload)
                ge_counts += 1

    # Only reached on success: tell both senders there is nothing more.
    await ge_queue.put(None)
    await t_queue.put(None)

async def run(**args):
    """Load a knowledge core into TrustGraph and return once it is all sent.

    Expects the keyword arguments built by main()'s argument parser: url,
    input_file, format, user and collection.  Re-raises the first failure
    from the reader or either websocket sender.
    """
    # Accept the base URL with or without a trailing slash.
    url = args["url"].rstrip("/") + "/api/v1/"

    # Bounded queues apply back-pressure: the reader blocks once 500 records
    # are waiting, so tg-load-kg-core doesn't grow to eat all memory.
    ge_q = asyncio.Queue(maxsize=500)
    t_q = asyncio.Queue(maxsize=500)

    workers = [
        asyncio.create_task(
            loader(
                ge_queue=ge_q, t_queue=t_q,
                path=args["input_file"], format=args["format"],
                user=args["user"], collection=args["collection"],
            )
        ),
        asyncio.create_task(load_ge(queue=ge_q, url=url)),
        asyncio.create_task(load_triples(queue=t_q, url=url)),
    ]
    stats_task = asyncio.create_task(stats())

    try:
        # Finish when all workers are done, or stop at the first failure;
        # otherwise a dead sender would leave the reader blocked forever
        # on its full queue.
        done, _pending = await asyncio.wait(
            workers, return_when=asyncio.FIRST_EXCEPTION
        )
        for task in done:
            task.result()  # re-raises the failure, if there was one
    finally:
        stats_task.cancel()
        for task in workers:
            task.cancel()  # no-op for tasks that already finished
        await asyncio.gather(stats_task, *workers, return_exceptions=True)

    _report()

async def main():
    """Parse the command line and load the given knowledge core."""

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        help='Input file (knowledge core to load)',
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)',
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)',
    )

    args = parser.parse_args()

    await run(**vars(args))

# Only run when executed as a script, so importing this module has no
# side effects.
if __name__ == "__main__":
    asyncio.run(main())

