"""Federated A/B simulation pipeline.

``site_run`` runs once per entry in ``SITES``, writing that site's metrics
file and receipt; ``aggregate_receipts`` then hands every site's metrics
file to the aggregation step, which writes one aggregated receipt.
"""

# Sites taking part in the run; each name becomes a ``{site}`` wildcard value.
SITES = ["site_a", "site_b"]

# Per-site input handle template; ``handle_path`` fills in ``{site}``.
SAMPLE_HANDLE = "app/data/samples/federated/{site}.dhandle.json"

# Per-site outputs, one path per entry of SITES (in SITES order).
SITE_METRICS = expand(
    "app/code/artifacts/federated/{site}_metrics.json",
    site=SITES,
)
SITE_RECEIPTS = expand(
    "app/code/receipts/federated/{site}.receipt.json",
    site=SITES,
)

# Single output of the aggregation step.
AGGREGATED_RECEIPT = "app/code/receipts/federated/aggregated.receipt.json"


def handle_path(wildcards):
    """Return the sample-handle path for the site named in ``wildcards``.

    Used as the ``handle`` input function of ``site_run``: it is called with
    the job's wildcards and substitutes ``wildcards.site`` into the
    ``{site}`` placeholder of ``SAMPLE_HANDLE``.
    """
    site_name = wildcards.site
    return SAMPLE_HANDLE.format(site=site_name)


rule all:
    """Request the pipeline's final outputs: every per-site receipt and the aggregated receipt."""
    input:
        site_receipts=SITE_RECEIPTS,
        aggregated_receipt=AGGREGATED_RECEIPT,


rule site_run:
    """Run the federated step for one site.

    Reads the site's sample handle (path from ``handle_path``) and writes
    that site's metrics file and receipt.
    """
    input:
        handle=handle_path,
    output:
        metrics="app/code/artifacts/federated/{site}_metrics.json",
        receipt="app/code/receipts/federated/{site}.receipt.json",
    # Restrict ``{site}`` to the configured sites. With the default ``.+``
    # pattern the receipt template also matches ``aggregated.receipt.json``
    # (site="aggregated"), making this rule a competing producer of the
    # aggregated receipt, and any mistyped name or one containing "/" would be
    # accepted as a site.
    wildcard_constraints:
        site="|".join(SITES),
    params:
        site=lambda wildcards: wildcards.site,
    shell:
        "uv run python -m app.code.lib.steps.federated_site --site {params.site} --handle {input.handle} "
        "--metrics-out {output.metrics} --receipt-out {output.receipt}"


rule aggregate_receipts:
    """Pass every site's metrics file to federated_aggregate, which writes the aggregated receipt."""
    input:
        metrics=SITE_METRICS,
    output:
        receipt=AGGREGATED_RECEIPT,
    # Fixed settings forwarded verbatim as CLI flags; their semantics live in
    # app.code.lib.steps.federated_aggregate (epsilon is presumably a privacy
    # budget and seed an RNG seed -- confirm against that module).
    params:
        epsilon=1.0,
        max_records=20,
        max_positive=10,
        seed=42,
    shell:
        "uv run python -m app.code.lib.steps.federated_aggregate"
        " --metrics {input.metrics}"
        " --out {output.receipt}"
        " --epsilon {params.epsilon}"
        " --max-records {params.max_records}"
        " --max-positive {params.max_positive}"
        " --seed {params.seed}"
