# Default target: requests one score table per configured activation plus the
# combined prior-test table.
rule all:
    input:
        score_tables=expand(
            "results/{activation}.tbl",
            activation=config["activation"],
        ),
        prior_table="results/test_prior.tbl",


# Runs ../PretrainScoringModel.py for one (lm, dimension, activation)
# combination and stores the script's stdout as the rule's output file.
rule train_scoring_model:
    output:
        "train_scoring_model_outputs/{lm}_{dimension}_{activation}.out"
    threads: 4
    resources:
        mem_mb = 100000,
        gpu = 1,
        runtime = "3d",
        partition = "vision"
    # The script is started from the parent directory inside a subshell. The
    # redirection is attached to the subshell, so {output} is opened relative
    # to the workflow's working directory instead of relying on that directory
    # being named "train_scoring_models".
    shell:
        """
        ( cd .. && python3 PretrainScoringModel.py \
            --lm {wildcards.lm} \
            --dim {wildcards.dimension} \
            --activation {wildcards.activation} \
            --lr 0.1 ) > {output}
        """


# Runs ../TestScoringModel.py for one (lm, dimension, activation) combination
# and stores the script's stdout as a single row file under results/rows/.
rule test_scoring_model:
    # The training output is not read by the command below; listing it as
    # input only forces the matching train_scoring_model job to run first.
    input:
        "train_scoring_model_outputs/{lm}_{dimension}_{activation}.out"
    output:
        "results/rows/{lm}_{dimension}_{activation}"
    threads: 4
    resources:
        mem_mb = 100000,
        gpu = 1,
        runtime = "1d",
        partition = "vision"
    # The script is started from the parent directory inside a subshell. The
    # redirection is attached to the subshell, so {output} is opened relative
    # to the workflow's working directory instead of relying on that directory
    # being named "train_scoring_models".
    shell:
        """
        ( cd .. && python3 TestScoringModel.py \
            --lm {wildcards.lm} \
            --dim {wildcards.dimension} \
            --activation {wildcards.activation} ) > {output}
        """


# Collects every results/rows/<lm>_<dimension>_<activation> file of one
# activation into results/<activation>.tbl, one ';'-separated row per file:
#   <lm>;<dimension>;<activation>;<contents of the row file>
rule concat_results_scores:
    input:
        expand("results/rows/{lm}_{dimension}_{{activation}}",
                lm=config["language_models"],
                dimension=config["dimensions"])
    output:
        "results/{activation}.tbl"
    threads: 1
    resources:
        mem_mb = 1000,
        partition = "batch,snowball,pinky"
    run:
        # Opening in "w" mode truncates the table, so it is created (empty)
        # even when there are no row files.
        with open(output[0], "w") as table:
            for row_file in input:
                row_name = row_file.split("/")[-1]
                # Split from the right so that underscores inside the
                # language-model name stay part of the first field.
                lm, dimension, activation = row_name.rsplit("_", 2)
                with open(row_file) as handle:
                    # Collapse all whitespace (including newlines) to single
                    # spaces so each row file yields exactly one table line.
                    scores = " ".join(handle.read().split())
                table.write(f"{lm};{dimension};{activation};{scores}\n")


# Runs ../PretrainMvnPrior.py for one (lm, dimension, activation,
# num_components) combination and stores the script's stdout as the rule's
# output file.
rule train_prior:
    # The scoring-model training output is not read by the command below;
    # listing it as input only forces that training job to run first.
    input:
        "train_scoring_model_outputs/{lm}_{dimension}_{activation}.out"
    output:
        "train_prior_outputs/{lm}_{dimension}_{activation}_{num_components}.out"
    threads: 4
    resources:
        mem_mb = 100000,
        gpu = 1,
        runtime = "2d",
        partition = "vision"
    # The script is started from the parent directory inside a subshell. The
    # redirection is attached to the subshell, so {output} is opened relative
    # to the workflow's working directory instead of relying on that directory
    # being named "train_scoring_models".
    shell:
        """
        ( cd .. && python3 PretrainMvnPrior.py \
            --lm {wildcards.lm} \
            --reduced_dim {wildcards.dimension} \
            --activation {wildcards.activation} \
            --components {wildcards.num_components} \
            --unscaled ) > {output}
        """
        
        
# Runs ../TestPrior.py for one (lm, dimension, activation, num_components)
# combination and stores the script's stdout as a single row file under
# results/test_prior_rows/.
rule test_prior:
    # The prior-training output is not read by the command below; listing it
    # as input only forces the matching train_prior job to run first.
    input:
        "train_prior_outputs/{lm}_{dimension}_{activation}_{num_components}.out"
    output:
        "results/test_prior_rows/{lm}_{dimension}_{activation}_{num_components}"
    threads: 4
    resources:
        mem_mb = 32000,
        gpu = 1,
        runtime = "1h",
        partition = "vision"
    # The script is started from the parent directory inside a subshell. The
    # redirection is attached to the subshell, so {output} is opened relative
    # to the workflow's working directory instead of relying on that directory
    # being named "train_scoring_models".
    shell:
        """
        ( cd .. && python3 TestPrior.py \
            --lm {wildcards.lm} \
            --reduced_dim {wildcards.dimension} \
            --activation {wildcards.activation} \
            --components {wildcards.num_components} \
            --unscaled ) > {output}
        """


# Collects every results/test_prior_rows/ file into results/test_prior.tbl,
# one ';'-separated row per file:
#   <lm>;<dimension>;<activation>;<num_components>;<contents of the row file>
rule concat_prior_tests:
    input:
        expand("results/test_prior_rows/{lm}_{dimension}_{activation}_{num_components}",
                lm=config["language_models"],
                dimension=config["dimensions"],
                activation=config["activation"],
                num_components=config["num_components"])
    output:
        "results/test_prior.tbl"
    threads: 1
    resources:
        mem_mb = 1000,
        partition = "batch,snowball,pinky"
    run:
        # Opening in "w" mode truncates the table, so it is created (empty)
        # even when there are no row files.
        with open(output[0], "w") as table:
            for row_file in input:
                row_name = row_file.split("/")[-1]
                # Split from the right so that underscores inside the
                # language-model name stay part of the first field.
                lm, dimension, activation, num_components = row_name.rsplit("_", 3)
                with open(row_file) as handle:
                    # Collapse all whitespace (including newlines) to single
                    # spaces so each row file yields exactly one table line.
                    result = " ".join(handle.read().split())
                table.write(
                    f"{lm};{dimension};{activation};{num_components};{result}\n"
                )