#!/usr/bin/env python3

"""
This utility connects to a running TrustGraph through the API and creates
a knowledge core from the data streaming through the processing queues.
For completeness of data, tg-save-kg-core should be initiated before data
loading takes place.  The default output format, msgpack, should be used.
JSON output format is also available - msgpack produces a more compact
representation, which is also more performant to load.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal

class Running:
    """
    Mutable run/stop flag.  It starts out set and is cleared once by
    stop(); the worker loops in this file poll it through get().
    """

    def __init__(self):
        # Begin in the running state.
        self.running = True

    def get(self):
        """Return True until stop() has been called."""
        return self.running

    def stop(self):
        """Clear the flag so that subsequent get() calls return False."""
        self.running = False

async def fetch_de(running, queue, user, collection, url):
    """
    Stream document embeddings from the API websocket onto a queue.

    Connects to "{url}stream/document-embeddings" and, until the run flag
    is cleared or the socket errors or closes, turns every text frame into
    a ["de", {...}] record with abbreviated keys and puts it on the queue.

    running: flag object; the loop ends once running.get() is false.
    queue: asyncio queue which receives the records.
    user: if truthy, only messages whose metadata user matches are kept.
    collection: if truthy, only messages whose metadata collection matches
        are kept.
    url: API base URL.  The stream path is appended as-is, so the value
        must already end with "/".
    """

    de_url = f"{url}stream/document-embeddings"

    # Frame types which mean no further data will arrive on this socket.
    closed_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED,
    )

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(de_url) as ws:

            while running.get():

                # Wake up every second so the run flag gets re-checked
                # even when the stream is idle.  Only the timeout is
                # expected here; cancellation and genuine failures must
                # propagate rather than being silently retried.
                try:
                    msg = await asyncio.wait_for(ws.receive(), 1)
                except asyncio.TimeoutError:
                    continue

                if msg.type == aiohttp.WSMsgType.ERROR:
                    print("Error")
                    break

                if msg.type in closed_types:
                    # Without this a closed socket would be polled forever.
                    print("Connection closed")
                    break

                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                data = msg.json()
                meta = data["metadata"]

                if user and meta["user"] != user:
                    continue

                if collection and meta["collection"] != collection:
                    continue

                # Short keys keep the stored record compact.
                await queue.put([
                    "de",
                    {
                        "m": {
                            "i": meta["id"],
                            "m": meta["metadata"],
                            "u": meta["user"],
                            "c": meta["collection"],
                        },
                        "c": [
                            {
                                "c": chunk["chunk"],
                                "v": chunk["vectors"],
                            }
                            for chunk in data["chunks"]
                        ]
                    }
                ])

# Number of document-embedding records written so far.  Incremented by
# output() and reported by stats().
de_counts = 0

async def stats(running):
    """
    Print the document-embeddings record count every two seconds.

    running: flag object; the loop ends once running.get() is false.
    """

    # The counter is only read here, so no 'global' declaration is needed.
    while running.get():

        await asyncio.sleep(2)

        print(
            f"Document embeddings: {de_counts:10d}"
        )

async def output(running, queue, path, format):
    """
    Write queued records to a file until stopped, then flush the backlog.

    While the run flag is set, records are taken from the queue (waiting
    at most half a second at a time so the flag is re-checked) and
    appended to the file.  Once the flag is cleared, whatever is still
    queued is written without waiting and the file is closed.

    running: flag object polled through running.get().
    queue: asyncio queue of records; each record is a list whose first
        element is a type tag such as "de".
    path: output file path; the file is opened for binary writing.
    format: "msgpack" selects msgpack encoding, anything else JSON.
    """

    global de_counts

    with open(path, "wb") as f:

        while True:

            if running.get():
                try:
                    msg = await asyncio.wait_for(queue.get(), 0.5)
                except asyncio.TimeoutError:
                    # wait_for raises asyncio.TimeoutError on every Python
                    # version (it is an alias of the builtin from 3.11),
                    # so it can be matched without a bare except.
                    continue
            else:
                # Stopping: drain what is already queued so records
                # received before the interrupt are not lost.
                try:
                    msg = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

            if format == "msgpack":
                f.write(msgpack.packb(msg, use_bin_type=True))
            else:
                # NOTE(review): JSON records are written back-to-back
                # with no separator - confirm that is what readers expect.
                f.write(json.dumps(msg).encode("utf-8"))

            if msg[0] == "de":
                de_counts += 1

    print("Output file closed")

async def run(running, **args):
    """
    Start the fetch, output and stats tasks and wait for them to finish.

    running: flag object shared by all tasks; it is cleared on the way
        out so that no task is left polling after a failure.
    args: parsed command-line options; the keys read are "url", "user",
        "collection", "output_file" and "format".

    Any exception raised by one of the tasks is re-raised to the caller.
    """

    # Accept the base URL with or without a trailing slash.
    base_url = args["url"]
    if not base_url.endswith("/"):
        base_url += "/"

    q = asyncio.Queue()

    de_task = asyncio.create_task(
        fetch_de(
            running=running,
            queue=q, user=args["user"], collection=args["collection"],
            url=base_url + "api/v1/"
        )
    )

    output_task = asyncio.create_task(
        output(
            running=running, queue=q,
            path=args["output_file"], format=args["format"],
        )
    )

    stats_task = asyncio.create_task(stats(running))

    tasks = (output_task, de_task, stats_task)

    try:
        # gather() surfaces the first failure straight away instead of
        # blocking on an earlier task that would otherwise never finish.
        await asyncio.gather(*tasks)
    finally:
        # Let the remaining tasks notice the stop and wind down on their
        # own (closing the output file) before control leaves here.
        running.stop()
        await asyncio.gather(*tasks, return_exceptions=True)

    print("Exiting")

async def main(running):
    """
    Parse the command line and hand the options to run().

    running: flag object passed through to run().
    """

    parser = argparse.ArgumentParser(
        prog='tg-save-kg-core',
        description=__doc__,
    )

    # The environment variable, when set, replaces the built-in default.
    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-o', '--output-file',
        # Make it mandatory, difficult to over-write an existing file
        required=True,
        help='Output file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Output format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to filter on (default: no filter)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to filter on (default: no filter)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))

# Run flag handed to main() and cleared by the SIGINT handler below.
running = Running()

def interrupt(sig, frame):
    """Signal handler: clear the run flag and report the interrupt."""
    running.stop()
    print('Interrupt')

# Install the handler before the event loop starts so that Ctrl-C clears
# the flag rather than raising KeyboardInterrupt.
signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))

