#!/usr/bin/env python3
import asyncio
import contextlib
import json
import logging
import os
import random
import sqlite3
import string
import sys
import tempfile
import time
from argparse import ArgumentParser

## Configuration
# Each setting can be overridden through an environment variable of the same
# name; by default the database and the log file live in the temp directory.
DATABASE_FILE = os.getenv(
    "DATABASE_FILE",
    os.path.join(tempfile.gettempdir(), "clustersubmit.sqlite3"),
)
LOG_FILE = os.getenv(
    "LOG_FILE",
    os.path.join(tempfile.gettempdir(), "clustersubmit.log"),
)
# Default time budget (seconds, parsed as float) of the scheduling loop.
TIMEOUT_SECONDS = float(os.getenv("TIMEOUT_SECONDS", "300"))

# Append every record (DEBUG and above) to LOG_FILE.
logging.basicConfig(
    filename=LOG_FILE,
    filemode="a",
    level=logging.DEBUG,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)-8s | %(message)s",
)
# Selects (jobid, jobscript) for every unfinished job that has no unfinished
# parent, i.e. no edge pointing at it from a job whose `done` flag is still 0.
# Every `done` reference is qualified with its table alias: both the outer
# query and the correlated subquery read from `jobs`, so an unqualified name
# only works through innermost-scope resolution and is easy to misread.
QUERY_RUNNABLE = """
    SELECT child.jobid, child.jobscript
    FROM jobs AS child
    WHERE child.done = 0 AND NOT EXISTS (
        SELECT 1 FROM jobs AS parents
        JOIN edges AS e ON e.parent = parents.jobid
        WHERE e.child = child.jobid AND parents.done = 0
    );
"""


## Arguments
# The text is passed as `description=`: the first positional parameter of
# ArgumentParser is `prog`, so passing it positionally would make the whole
# sentence appear as the program name in the usage line.
parser = ArgumentParser(
    description="Mock script representing cluster-managed DAG execution"
)
# Ids of the jobs that must finish before the submitted one may run.
parser.add_argument("-d", "--dependencies", help="Space-separated job ids", default="")
# Switches from "submit a job" mode to "run everything pending" mode.
parser.add_argument("--execute", action='store_true', help="Execute pending jobs and exits")
# Path of the jobscript to submit; optional because --execute does not need it.
parser.add_argument("jobscript", type=str, default=None, nargs='?')
args = parser.parse_args()


## Initialization
# Create the schema on first use and purge rows older than one hour.
# `contextlib.closing` releases the connection when the block ends; the
# connection's own context manager (the trailing `db`) only commits or rolls
# back the transaction and leaves the connection open.
with contextlib.closing(sqlite3.connect(DATABASE_FILE)) as db, db:
    # One row per submitted job; `done` is 0 until the job has been executed.
    db.execute("""
        CREATE TABLE IF NOT EXISTS jobs(
            jobid TEXT PRIMARY KEY,
            jobscript TEXT,
            done INT,
            created_at TEXT
        )
    """)

    # One row per dependency: `child` may only run once `parent` is done.
    db.execute("""
        CREATE TABLE IF NOT EXISTS edges(
            parent TEXT,
            child TEXT,
            created_at TEXT,
            FOREIGN KEY(parent) REFERENCES jobs(jobid),
            FOREIGN KEY(child) REFERENCES jobs(jobid)
        );
    """)

    db.execute("DELETE FROM jobs WHERE created_at <= datetime('now', '-1 hour')")
    db.execute("DELETE FROM edges WHERE created_at <= datetime('now', '-1 hour')")



## execution
# Serializes the "mark job as done" database updates issued by concurrently
# running job coroutines (acquired in execute_single_job).
db_lock = asyncio.Lock()

async def execute_single_job(jobid, jobscript):
    """
    Coroutine executing one job

    Runs `jobscript` through the shell and, on success, flags the job as done
    in the database.

    Args:
        jobid: Identifier of the job (primary key of the `jobs` table).
        jobscript: Shell script text handed to the subprocess.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The job is left with `done = 0` in that case.
    """
    logging.info(f"Executing {jobid}")
    # Both streams must be piped explicitly: without PIPE, communicate()
    # returns (None, None) and the failure diagnostics below never fire.
    proc = await asyncio.create_subprocess_shell(
        jobscript,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        logging.error(f"Job {jobid} failed with return code {proc.returncode}")
        # errors="replace" keeps non-UTF-8 job output from masking the real
        # failure with a UnicodeDecodeError.
        if stdout:
            logging.debug(f"Job {jobid} stdout: {stdout.decode(errors='replace')[:500]}...")
        if stderr:
            logging.error(f"Job {jobid} stderr: {stderr.decode(errors='replace')[:500]}...")

        raise RuntimeError(f"Failed executing job {jobid}")

    # Only one coroutine at a time writes to the database.
    async with db_lock:
        # closing() releases the connection; the inner `db` context commits.
        with contextlib.closing(sqlite3.connect(DATABASE_FILE, timeout=30)) as db, db:
            db.execute("PRAGMA journal_mode=WAL")
            db.execute("UPDATE jobs SET done = 1 WHERE jobid = ?", (jobid,))
    logging.info(f"Marked {jobid} as done")


async def execute_jobs(timeout_seconds=TIMEOUT_SECONDS):
    """
    Simplistic scheduler checking jobs that can run and executing them

    Repeatedly queries the database for runnable jobs (unfinished jobs with no
    unfinished parent), runs the whole batch concurrently, and returns once no
    runnable job is left.

    Args:
        timeout_seconds: Time budget for the scheduling loop. It is checked
            only between batches, so a batch that is already running is not
            interrupted when the budget runs out.

    Raises:
        TimeoutError: If the budget is exhausted while runnable jobs remain.
        ExceptionGroup: Propagated from the task group when a job in the
            batch fails.
    """
    start = time.perf_counter()
    while (time.perf_counter() - start) < timeout_seconds:
        # closing() ensures the per-iteration connection is released instead
        # of lingering until garbage collection.
        with contextlib.closing(sqlite3.connect(DATABASE_FILE)) as db:
            result = db.execute(QUERY_RUNNABLE).fetchall()
        if not result:
            logging.info("Could not find any other job to execute.")
            return

        # The task group waits for the whole batch before the next query.
        async with asyncio.TaskGroup() as tg:
            for row in result:
                tg.create_task(execute_single_job(*row))

    raise TimeoutError(f"Exceeded timeout of {timeout_seconds} seconds")
                

# Mode dispatch: with --execute run every pending job and stop; otherwise a
# jobscript is mandatory and the script continues with the submission.
if args.execute:
    asyncio.run(execute_jobs())
    # sys.exit instead of the bare `exit` helper, which is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)
elif args.jobscript is None:
    logging.critical(f"Failed to submit a job: missing jobscript. {sys.argv}")
    raise ValueError("Expected jobscript as the last positional argument")

    

## Submission
with open(args.jobscript) as input_file:
    jobscript_content = input_file.read()

# The job metadata is read from the first "# properties = <json>" line.
properties_prefix = "# properties = "
for line in jobscript_content.splitlines():
    if line.startswith(properties_prefix):
        properties = json.loads(line.removeprefix(properties_prefix))
        break
else:
    # Without this branch a jobscript lacking the header would only fail
    # later with an opaque NameError on `properties`.
    logging.critical(f"Failed to submit a job: no properties line in {args.jobscript}")
    raise ValueError(f"No '{properties_prefix}' line found in {args.jobscript}")

logging.info("---")
logging.info(f"Processing rule: {properties['rule']}")

# Job id: <rule>-<submission timestamp>-<five random lowercase letters>.
random_suffix = ''.join(random.choices(string.ascii_lowercase, k=5))
job_name = f"{properties['rule']}-{time.time()}-{random_suffix}"

dependencies = args.dependencies.split()
logging.info(f"Submitting `{job_name}` which depends on {dependencies}")

# Register the job and one edge per dependency in a single transaction.
# closing() releases the connection; the inner `db` context commits on
# success and rolls back if any insert fails.
with contextlib.closing(sqlite3.connect(DATABASE_FILE)) as db, db:
    db.execute(
        "INSERT INTO jobs(jobid, jobscript, done, created_at) VALUES (?, ?, 0, datetime('now'))",
        (job_name, jobscript_content)
    )

    db.executemany(
        "INSERT INTO edges(parent, child, created_at) VALUES (?, ?, datetime('now'))",
        [(dependency, job_name) for dependency in dependencies]
    )

# The new job id is written to stdout (presumably so the caller can pass it
# back via --dependencies for later jobs; confirm against the caller).
print(job_name)




