
from snakemake.utils import Paramspace, validate, min_version
import pandas as pd
from os import makedirs

# Refuse to run on a Snakemake release older than 6.10.0.
min_version("6.10.0")
# Check the user-supplied configuration against the schema before using it.
validate(config, "config.schema.yml")

# Working directory of the run; defaults to the current directory.
workdir: config.get("workdir", ".")

# `beast` and `beast_nosingularity` declare the same output file, so the rule
# order decides which one is used: the Singularity variant is preferred when
# `singularity` is set in the config, the plain `beast` call otherwise.
if config.get("singularity"):
    ruleorder: beast > beast_nosingularity 
else: 
    ruleorder: beast_nosingularity > beast

# Helper scripts and reference tables that ship with the workflow, resolved
# through `workflow.source_path` so the rules can pass them to shell commands.
XMLgenerator = workflow.source_path("aux/methylationBetas2xml.py")
# The result is discarded on purpose.
# NOTE(review): presumably these two files are imported by
# methylationBetas2xml.py and are resolved here only so that they are
# available alongside it -- confirm.
_ = workflow.source_path("aux/createXML.py"), workflow.source_path("aux/readInputMethylation.py")
readIDAT = workflow.source_path("aux/readIDAT.R")
selectCpGs = workflow.source_path("aux/select_cpgs.py")
# Default array manifest; rule select_cpgs lets the config override it.
manifest = workflow.source_path("data/MethylationEPIC_v-1-0_B4.csv")
crossreactivefile = workflow.source_path("data/13059_2016_1066_MOESM1_ESM.csv")

# Column names of the patient sheet; both can be overridden in the config.
patient_col = config.get("patient_col", "Patient")
age_col = config.get("age_col", "Age")

# Patient sheet. AGES maps every patient to the first age listed for it.
df = pd.read_csv(config.get("patientInfo"))
_first_age_per_patient = df.groupby(patient_col)[age_col].first()
AGES = _first_age_per_patient.to_dict()

# `stemCells` is given as "start-stop-step"; as with range(), stop is excluded.
stemCellsSweep = config.get("stemCells").split("-")
_sweep_start = int(stemCellsSweep[0])
_sweep_stop = int(stemCellsSweep[1])
_sweep_step = int(stemCellsSweep[2])
STEMCELLS = list(range(_sweep_start, _sweep_stop, _sweep_step))


# The set of patients (and which getXML variant wins) depends on the run mode.
if config.get("mode") == "trees":
    # One ready-made CSV per patient is expected in input_dir; the patient
    # names are taken from the file names.
    PATIENTS = glob_wildcards(config["input_dir"] + "/{patient}.csv").patient
    ruleorder: getXML_short > getXML
else:
    # Sample name = last path component of the "Basename" column.
    df["sample_basename"] = df["Basename"].str.split("/").str.get(-1)
    PATIENTS = df[patient_col].unique()
    # Per patient, the list of its sample names (used to split the beta table).
    SAMPLES = df.groupby(patient_col)["sample_basename"].agg(list)
    ruleorder: getXML > getXML_short

# wildcard_constraints:
#     patient='|'.join([x for x in PATIENTS]),
#     stemcell='|'.join([str(x) for x in STEMCELLS]),

# Default target: one tree file per (patient, stem-cell count) combination.
rule all:
    input:
        expand(
            "results/{patient}/{patient}.{stemcell}cells.trees",
            patient=PATIENTS,
            stemcell=STEMCELLS,
        )

# Run aux/readIDAT.R on the configured input directory and collect its results
# in the `idat_processed` directory (rule select_cpgs reads the
# <name>.betas.csv, <name>.M.csv and <name>.U.csv tables from there).
#
# Every interpolated value except the optional flag is shell-quoted with `:q`
# so that a regex pattern containing backslashes, `|`, `*`, ... or a path
# containing spaces reaches Rscript unchanged instead of being interpreted by
# the shell. `genome_plot` stays unquoted: it is either a literal flag or empty.
rule readIDAT:
    input:
        dir = config.get("input_dir"),
    output:
        directory("idat_processed"),
    params:
        pattern = config.get("pattern", "csv$"),
        genome_plot = "--genome_plot" if config.get("genome_plot") else "",
        name = config.get("output"),
        readIDAT = readIDAT
    threads:
        config.get("threads", 1),
    shell:
        """
        Rscript {params.readIDAT:q} -i {input.dir:q} -o {output:q} -n {params.name:q} -p {params.pattern:q} {params.genome_plot} -c {threads}
        """


# Run aux/select_cpgs.py on the beta/M/U tables written by readIDAT and write
# the selected CpGs to "<output>.fluCpGs.csv".
#
# All script inputs are passed through `params`; the cross-reactive probe list
# is now read from `params.crossreactivefile` as well, instead of bypassing the
# declared parameter and reaching for the module-level variable.
rule select_cpgs:
    input: "idat_processed"
    output: expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output"))
    params:
        manifest = config.get("manifest", manifest),
        crossreactivefile = crossreactivefile,
        percent = config.get("percent", 5),
        patientInfo = config.get("patientInfo"),
        name = config.get("output"),
        patient_col = patient_col,
        selectCpGs = selectCpGs
    shell:
        """
        python3 {params.selectCpGs} -e {params.manifest} -p {params.percent} -c {params.crossreactivefile} {output} {input}/{params.name}.betas.csv {input}/{params.name}.M.csv {input}/{params.name}.U.csv {params.patientInfo} -P {params.patient_col} --basename
        """

# Split the selected-CpG table into one CSV per patient, each keeping only the
# columns named after that patient's samples. Declared as a checkpoint so that
# rules depending on the per-patient files are resolved after it has run.
checkpoint split_flucpgs:
    input:
        rules.select_cpgs.output
    output:
        temp(directory("miniBetas"))
    run:
        betas = pd.read_csv(input[0])
        out_dir = output[0]
        makedirs(out_dir, exist_ok=True)
        for patient in PATIENTS:
            patient_table = betas.filter(items=SAMPLES[patient])
            patient_table.to_csv(f"{out_dir}/{patient}.csv")


def aggregate_input(wildcards):
    """Return the per-patient CSV produced by the `split_flucpgs` checkpoint.

    Used as the input function of rule `getXML`. Asking for the checkpoint
    ties the rule to it, so the input is only resolved once the checkpoint
    has written its directory.

    Args:
        wildcards: wildcards of the requesting job; `patient` selects the file.

    Returns:
        Path of `<checkpoint output directory>/<patient>.csv`.
    """
    split_dir = checkpoints.split_flucpgs.get(**wildcards).output[0]
    # Build the path from the checkpoint's own output rather than repeating
    # the directory name, so the two cannot drift apart.
    return f"{split_dir}/{wildcards.patient}.csv"

# Generate the XML file for one patient and one stem-cell count by calling
# aux/methylationBetas2xml.py from inside the output directory.
#
# Model and sampler settings come from the config, with the defaults listed
# under `params`. The optional switches (stripRownames, mle_ps, mle_ss, hme)
# expand to their flag or to nothing.
#
# Notes on the command:
#   * the input path is made absolute first because the script is run after
#     `cd`-ing into the output directory;
#   * `mkdir -p` replaces the test-then-mkdir form, which could fail when two
#     jobs of this rule start at the same time;
#   * the generator script is taken from `params.XMLgenerator`, the parameter
#     declared for it, rather than from the module-level variable.
rule getXML:
    input: aggregate_input
    output: "results/{patient}.{stemcell}cells.xml"
    params:
        age = lambda wildcards: AGES.get(wildcards.patient),
        delta = config.get("delta", .2),
        eta = config.get("eta", .7),
        kappa = config.get("kappa", 50),
        mu = config.get("mu", .1),
        gamma = config.get("gamma", .1),
        lam = config.get("lam", 1),
        iterations = config.get("iterations", 750_000),
        precision = config.get("precision", 6),
        sampling = config.get("sampling", 75),
        screenSampling = config.get("screenSampling", 500),
        stripRownames = "--stripRownames" if config.get("stripRownames") else "",
        mle_ps = "--mle-ps" if config.get("mle_ps") else "",
        mle_ss = "--mle-ss" if config.get("mle_ss") else "",
        hme = "--hme" if config.get("hme") else "",
        mle_steps = config.get("mle_steps", 100),
        mle_iterations = config.get("mle_iterations", int(config.get("iterations", 750_000)/config.get("mle_steps", 100))),
        mle_sampling = config.get("mle_sampling", int(config.get("iterations", 750_000)/1000)),
        outdir="results",
        XMLgenerator = XMLgenerator
    shell:
        """
        input=$(realpath {input})
        mkdir -p {params.outdir}
        cd {params.outdir}
        name=$(basename {output})
        name=${{name%.xml}}
        python3 {params.XMLgenerator} --age {params.age} --input $input --output $name --stemCells {wildcards.stemcell} --delta {params.delta} --eta {params.eta} --kappa {params.kappa} --mu {params.mu} --gamma {params.gamma} --lambda {params.lam} --iterations {params.iterations} --precision {params.precision} --sampling {params.sampling} --screenSampling {params.screenSampling} {params.stripRownames} {params.mle_ps} {params.mle_ss} {params.hme} --mle-steps {params.mle_steps} --mle-iterations {params.mle_iterations} --mle-sampling {params.mle_sampling}
        """

# Variant of getXML that reads the per-patient table straight from input_dir
# instead of going through the split_flucpgs checkpoint; everything else is
# inherited from getXML.
use rule getXML as getXML_short with:
    input:
        "/".join([config["input_dir"], "{patient}.csv"])

# Run BEAST on a generated XML through the Singularity image.
#
# The rule declares a log file, but Snakemake only reserves that path; the
# command has to write to it. stdout and stderr are therefore piped through
# `tee`, which keeps the console output and fills the log.
#
# The final line refreshes the timestamp of the tree file when it exists; when
# it does not, the test fails and, being the last command, fails the job.
#
# NOTE(review): the bind-mount source (~/dev/beast-icr/pisca-branch) is a
# machine-specific path -- consider moving it to the config.
rule beast:
    input:
        "results/{patient}.{stemcell}cells.xml"
    output:
        "results/{patient}/{patient}.{stemcell}cells.trees"
    log: 
        "beast_logs/{patient}/{stemcell}cells/{patient}.run.log"
    shell:
        """
        singularity run -B ~/dev/beast-icr/pisca-branch:/mnt docker://rachelicr/pisca-branch-master {input} 2>&1 | tee {log}
        [[ -f {output} ]] && touch {output}
        """

# Run a locally installed BEAST (BEAGLE disabled) on a generated XML.
#
# The rule declares a log file, but Snakemake only reserves that path; the
# command has to write to it. stdout and stderr are therefore piped through
# `tee`, which keeps the console output and fills the log.
#
# The final line refreshes the timestamp of the tree file when it exists; when
# it does not, the test fails and, being the last command, fails the job.
rule beast_nosingularity:
    input:
        "results/{patient}.{stemcell}cells.xml"
    output:
        "results/{patient}/{patient}.{stemcell}cells.trees"
    log: 
        "beast_logs/{patient}/{stemcell}cells/{patient}.run.log"
    shell:
        """
        beast -beagle_off {input} 2>&1 | tee {log}
        [[ -f {output} ]] && touch {output}
        """
