#!/usr/bin/env python3

"""
Uses the knowledge service to fetch a knowledge core which is saved
to a local file in msgpack format.
"""

import argparse
import os
import textwrap
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack

# Base URL of the API gateway; the TRUSTGRAPH_URL environment variable, when
# set, takes precedence over the built-in local address.
default_url = os.environ.get("TRUSTGRAPH_URL", "ws://localhost:8088/")

# User ID used when -U/--user is not supplied on the command line.
default_user = "trustgraph"

def write_triple(f, data):
    """Append one triples message to the output file as a msgpack record.

    The record is the pair ``("t", {"m": <metadata>, "t": <triples>})``,
    where the metadata keys are shortened to single letters
    (id -> "i", metadata -> "m", user -> "u", collection -> "c").

    Args:
        f: Writable binary file object.
        data: Mapping with a "metadata" mapping (keys "id", "metadata",
            "user", "collection") and a "triples" entry.
    """
    meta = data["metadata"]

    # Compact, single-letter field names for the on-disk form.
    header = {
        "i": meta["id"],
        "m": meta["metadata"],
        "u": meta["user"],
        "c": meta["collection"],
    }

    record = ("t", {"m": header, "t": data["triples"]})
    packed = msgpack.packb(record, use_bin_type=True)
    f.write(packed)

def write_ge(f, data):
    """Append one graph-embeddings message to the output file as msgpack.

    The record is the pair ``("ge", {"m": <metadata>, "e": <entities>})``.
    Metadata keys are shortened to single letters (id -> "i",
    metadata -> "m", user -> "u", collection -> "c") and each entity is
    stored as ``{"e": <entity>, "v": <vectors>}``.

    Args:
        f: Writable binary file object.
        data: Mapping with a "metadata" mapping (keys "id", "metadata",
            "user", "collection") and an "entities" iterable whose items
            have "entity" and "vectors" keys.
    """
    meta = data["metadata"]

    # Compact, single-letter field names for the on-disk form.
    header = {
        "i": meta["id"],
        "m": meta["metadata"],
        "u": meta["user"],
        "c": meta["collection"],
    }

    entities = []
    for item in data["entities"]:
        entities.append({"e": item["entity"], "v": item["vectors"]})

    record = ("ge", {"m": header, "e": entities})
    f.write(msgpack.packb(record, use_bin_type=True))

async def fetch(url, user, id, output):
    """Stream a knowledge core from the gateway socket into a msgpack file.

    Connects to ``<url>/api/v1/socket``, sends one ``get-kg-core`` request
    addressed to the ``knowledge`` service, then writes every ``triples``
    and ``graph-embeddings`` payload received to ``output`` until a message
    with a truthy ``eos`` flag arrives. A summary of the number of messages
    written is printed at the end.

    Args:
        url: Base websocket URL; a trailing "/" is appended if missing.
        user: User ID placed in the request.
        id: Knowledge core ID placed in the request.
        output: Path of the file to write; opened in binary write mode, so
            any existing content is replaced.

    Raises:
        RuntimeError: If a received message has no "response" field, or
            if the response carries an "error" field (the error payload is
            the exception argument).
    """

    if not url.endswith("/"):
        url += "/"

    url = url + "api/v1/socket"

    # Unique ID for this request.
    mid = str(uuid.uuid4())

    # The context manager closes the connection on exit, including when an
    # exception propagates, so no explicit close() call is needed.
    async with connect(url) as ws:

        req = json.dumps({
            "id": mid,
            "service": "knowledge",
            "request": {
                "operation": "get-kg-core",
                "user": user,
                "id": id,
            }
        })

        await ws.send(req)

        # Counters for the summary line: graph-embeddings and triples
        # messages written.
        ge = 0
        t = 0

        with open(output, "wb") as f:

            while True:

                msg = await ws.recv()

                obj = json.loads(msg)

                if "response" not in obj:
                    raise RuntimeError("No response?")

                response = obj["response"]

                # The error lives inside the response body, not on the outer
                # envelope; reading it from `obj` would raise KeyError and
                # hide the real failure.
                if "error" in response:
                    raise RuntimeError(response["error"])

                # End-of-stream marker: stop once it is present and truthy.
                if response.get("eos"):
                    break

                if "triples" in response:
                    t += 1
                    write_triple(f, response["triples"])

                if "graph-embeddings" in response:
                    ge += 1
                    write_ge(f, response["graph-embeddings"])

        print(f"Got: {t} triple, {ge} GE messages.")

def main():
    """Parse command-line options and run the knowledge core download.

    Options: -u/--url (gateway URL), -U/--user (user ID), --id/--identifier
    (knowledge core ID, required) and -o/--output (output file, required).
    Any exception raised while fetching is caught and printed rather than
    propagated.
    """

    cli = argparse.ArgumentParser(prog='tg-get-kg-core', description=__doc__)

    cli.add_argument(
        '-u', '--url',
        help=f'API URL (default: {default_url})',
        default=default_url,
    )
    cli.add_argument(
        '-U', '--user',
        help=f'User ID (default: {default_user})',
        default=default_user,
    )
    cli.add_argument(
        '--id', '--identifier',
        help=f'Knowledge core ID',
        required=True,
    )
    cli.add_argument(
        '-o', '--output',
        help=f'Output file',
        required=True,
    )

    opts = cli.parse_args()

    try:
        # Drive the async download to completion on a fresh event loop.
        asyncio.run(
            fetch(
                url=opts.url,
                user=opts.user,
                id=opts.id,
                output=opts.output,
            )
        )
    except Exception as err:
        # Report the failure on stdout instead of showing a traceback.
        print("Exception:", err, flush=True)

# Run the CLI only when this file is executed as a script, so that importing
# it (for tests or tooling) does not parse sys.argv or start a download.
if __name__ == "__main__":
    main()

