#!/usr/bin/env python3

"""
This utility connects to a running TrustGraph through the API and creates
a knowledge core from the data streaming through the processing queues.
For completeness of data, tg-save-kg-core should be started before data
loading begins.  The default output format, msgpack, is recommended.
JSON output is also available, but msgpack produces a more compact
representation that is also faster to load.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os

async def fetch_ge(queue, user, collection, url):
    """Stream graph-embedding messages from the API onto *queue*.

    Connects to the ``stream/graph-embeddings`` websocket under *url*. For
    every text frame it applies the optional user/collection filters, then
    puts a ``("ge", record)`` tuple on *queue*. *record* uses the compact
    single-letter keys that output() serializes.

    Args:
        queue: asyncio.Queue that receives ``(kind, record)`` tuples.
        user: If truthy, only messages whose ``metadata.user`` equals it
            are kept.
        collection: If truthy, only messages whose ``metadata.collection``
            equals it are kept.
        url: API base URL. The endpoint path is appended directly, so this
            is expected to end with ``/``.
    """
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:

                    data = msg.json()
                    meta = data["metadata"]

                    if user and meta["user"] != user:
                        continue

                    if collection and meta["collection"] != collection:
                        continue

                    # Tuple, matching fetch_triples; msgpack and json both
                    # encode it as an array, so the file format is unchanged.
                    await queue.put((
                        "ge",
                        {
                            "m": {
                                "i": meta["id"],
                                "m": meta["metadata"],
                                "u": meta["user"],
                                "c": meta["collection"],
                            },
                            "v": data["vectors"],
                            "e": data["entity"],
                        },
                    ))

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    # Report the underlying cause rather than a bare "Error".
                    print(
                        f"Error on graph-embeddings stream: {ws.exception()}",
                        file=sys.stderr,
                    )
                    break

async def fetch_triples(queue, user, collection, url):
    """Stream triples messages from the API onto *queue*.

    Connects to the ``stream/triples`` websocket under *url*. For every
    text frame it applies the optional user/collection filters, then puts
    a ``("t", record)`` tuple on *queue*. *record* uses the compact
    single-letter keys that output() serializes.

    Args:
        queue: asyncio.Queue that receives ``(kind, record)`` tuples.
        user: If truthy, only messages whose ``metadata.user`` equals it
            are kept.
        collection: If truthy, only messages whose ``metadata.collection``
            equals it are kept.
        url: API base URL. The endpoint path is appended directly, so this
            is expected to end with ``/``.
    """
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}stream/triples") as ws:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:

                    data = msg.json()
                    meta = data["metadata"]

                    if user and meta["user"] != user:
                        continue

                    if collection and meta["collection"] != collection:
                        continue

                    await queue.put((
                        "t",
                        {
                            "m": {
                                "i": meta["id"],
                                "m": meta["metadata"],
                                "u": meta["user"],
                                "c": meta["collection"],
                            },
                            "t": data["triples"],
                        },
                    ))

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    # Report the underlying cause rather than a bare "Error".
                    print(
                        f"Error on triples stream: {ws.exception()}",
                        file=sys.stderr,
                    )
                    break

# Running totals of records written by output(); stats() reads them to
# print a periodic progress line.
ge_counts = 0
t_counts = 0

async def stats():
    """Print the running graph-embedding and triple counts every 5 seconds.

    Loops forever; it only stops when its task is cancelled.
    """
    # Read-only access to the module counters needs no `global` statement.
    while True:
        await asyncio.sleep(5)
        print(
            f"Graph embeddings: {ge_counts:10d}  Triples: {t_counts:10d}"
        )

async def output(queue, path, format):
    """Write every ``(kind, record)`` message taken from *queue* to *path*.

    The file is truncated on open. Each message is encoded as msgpack when
    *format* is ``"msgpack"``. Any other format is written as one JSON
    document per line (JSON Lines). The module counters ``t_counts`` and
    ``ge_counts`` are incremented for ``"t"`` and ``"ge"`` messages.

    Loops forever; the file is closed when the task is cancelled.

    Args:
        queue: asyncio.Queue producing ``(kind, record)`` sequences.
        path: Output file path.
        format: ``"msgpack"`` or ``"json"``.
    """
    global t_counts
    global ge_counts

    with open(path, "wb") as f:

        while True:

            msg = await queue.get()

            if format == "msgpack":
                f.write(msgpack.packb(msg, use_bin_type=True))
            else:
                # Newline-delimit the documents. Written back-to-back, they
                # run together ("[..][..]") and cannot be read back with a
                # line-by-line JSON reader.
                f.write(json.dumps(msg).encode("utf-8") + b"\n")

            kind = msg[0]
            if kind == "t":
                t_counts += 1
            elif kind == "ge":
                ge_counts += 1

async def run(**args):
    """Start the stream readers, the file writer and the progress reporter.

    *args* is the parsed command line as a dict. The keys ``url``, ``user``,
    ``collection``, ``output_file`` and ``format`` are read.

    Runs until interrupted or until any task raises. On exit, every task is
    cancelled; cancelling output() lets it close the output file.
    """
    q = asyncio.Queue()

    # Accept a base URL given without a trailing slash. Otherwise the API
    # path would be glued onto the host:port ("...:8088api/v1/").
    base_url = args["url"]
    if not base_url.endswith("/"):
        base_url += "/"
    api_url = base_url + "api/v1/"

    tasks = [
        asyncio.create_task(
            fetch_ge(
                queue=q, user=args["user"], collection=args["collection"],
                url=api_url,
            )
        ),
        asyncio.create_task(
            fetch_triples(
                queue=q, user=args["user"], collection=args["collection"],
                url=api_url,
            )
        ),
        asyncio.create_task(
            output(
                queue=q, path=args["output_file"], format=args["format"],
            )
        ),
        asyncio.create_task(stats()),
    ]

    try:
        # gather() re-raises the first exception from any task (e.g. a
        # refused websocket connection). Sequential awaits starting with
        # the never-ending output task would hide it and hang.
        await asyncio.gather(*tasks)
    finally:
        for task in tasks:
            task.cancel()

async def main():
    """Parse the command line and run the knowledge-core capture."""
    parser = argparse.ArgumentParser(
        prog='tg-save-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-o', '--output-file',
        # No default: requiring an explicit path makes it less likely that
        # an existing file is overwritten by accident (the file is opened
        # in truncating "wb" mode).
        required=True,
        help='Output file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Output format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to filter on (default: no filter)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to filter on (default: no filter)'
    )

    args = parser.parse_args()

    await run(**vars(args))

# Run only when executed as a script, so importing the module has no side
# effects.
if __name__ == "__main__":
    asyncio.run(main())

