from os import listdir
from os.path import isfile, splitext

# File name of the reference genome: everything after the last "/" of the
# templated path, or the whole value when it contains no "/".
genome_base_name = "%%GENOME_FASTA%%".rpartition("/")[2]

def get_reads_file_path(sample, mate):
    """Return the preprocessed FASTQ path for one mate of a sample.

    The path is ``preprocessing/<sample>_R<mate>.fastq`` with a ``.gz``
    suffix appended when the matching flag in
    ``config['entries'][sample]['main']`` is truthy. ``mate == 1`` reads
    ``forward_reads_gzipped``; any other value reads
    ``reverse_reads_gzipped``.
    """
    flag_key = 'forward_reads_gzipped' if mate == 1 else 'reverse_reads_gzipped'
    is_gzipped = config['entries'][sample]['main'][flag_key]
    suffix = ".gz" if is_gzipped else ""
    return "preprocessing/{}_R{}.fastq{}".format(sample, mate, suffix)

rule all:
    # Default target: the flagstat spreadsheet, one BAM per configured sample
    # (sample names taken in sorted order), and the minimap2 report page.
    input:
        "mapping/statistics/flagstats/mapping_stats.xlsx",
        expand("mapping/{A}.bam", A=sorted(config['entries'])),
        ".report/modules/minimap2.html"


rule summarize_flagstat:
    # Collects the `samtools flagstat` reports of all configured samples into
    # one tab-separated table with a header line and one row per sample.
    #
    # Only the first number of each flagstat line (the QC-passed count) is
    # used. The three percentage columns are recomputed from those counts
    # rather than copied from the report.
    input:
        expand("mapping/statistics/{NAME}_flagstat.txt", NAME=config['entries'].keys())
    output:
        "mapping/statistics/flagstat_summary.tsv"
    group:
        "bwa_mapping"
    run:
        # Imported locally so the rule does not depend on `re` happening to
        # be present in the Snakefile namespace (this file never imports it).
        import re

        # (key, pattern) pairs in the order the lines are tested. Patterns
        # are applied with match(), i.e. anchored at the start of the line.
        count_patterns = [
            ("total", re.compile(r"(\d+) \+ (\d+) in total \(QC-passed reads \+ QC-failed reads\)")),
            ("secondary", re.compile(r"(\d+) \+ (\d+) secondary")),
            ("supplementary", re.compile(r"(\d+) \+ (\d+) supplementary")),
            ("duplicates", re.compile(r"(\d+) \+ (\d+) duplicates")),
            ("mapped", re.compile(r"(\d+) \+ (\d+) mapped \((.+):(.+)\)")),
            ("paired_in_sequencing", re.compile(r"(\d+) \+ (\d+) paired in sequencing")),
            ("read1", re.compile(r"(\d+) \+ (\d+) read1")),
            ("read2", re.compile(r"(\d+) \+ (\d+) read2")),
            ("properly_paired", re.compile(r"(\d+) \+ (\d+) properly paired \((.+):(.+)\)")),
            ("with_itself_and_mate", re.compile(r"(\d+) \+ (\d+) with itself and mate mapped")),
            ("singletons", re.compile(r"(\d+) \+ (\d+) singletons \((.+):(.+)\)")),
        ]

        header_columns = [
            "sample", "total alignments", "secondary", "supplementary",
            "duplicates", "mapped", "mapped[%]", "paired in sequencing",
            "read1", "read2", "properly paired", "properly paired[%]",
            "with itself and mate mapped", "singletons", "singletons[%]",
        ]

        def percentage(part, whole):
            # A zero denominator (no alignments, or no reads flagged as
            # paired) would otherwise abort the whole summary.
            return part / whole * 100 if whole else 0.0

        # NOTE(review): "{:2f}" means minimum width 2 with the default six
        # decimals; "{:.2f}" may have been intended. Kept unchanged so the
        # table stays identical for existing consumers.
        percent_format = "{:2f}"

        with open(output[0], "w") as summary_file:
            summary_file.write("\t".join(header_columns) + "\n")
            for flagstat_path in input:
                # Start every sample from zero so a line missing from one
                # report can never inherit the previous sample's value.
                counts = {key: 0 for key, _ in count_patterns}
                with open(flagstat_path, "r") as flagstat_file:
                    for line in flagstat_file:
                        for key, pattern in count_patterns:
                            hit = pattern.match(line)
                            if hit:
                                counts[key] = int(hit.group(1))
                                break

                sample_name = flagstat_path.replace("_flagstat.txt", "").replace("mapping/statistics/", "")
                row = [
                    sample_name,
                    str(counts["total"]),
                    str(counts["secondary"]),
                    str(counts["supplementary"]),
                    str(counts["duplicates"]),
                    str(counts["mapped"]),
                    percent_format.format(percentage(counts["mapped"], counts["total"])),
                    str(counts["paired_in_sequencing"]),
                    str(counts["read1"]),
                    str(counts["read2"]),
                    str(counts["properly_paired"]),
                    percent_format.format(percentage(counts["properly_paired"], counts["paired_in_sequencing"])),
                    str(counts["with_itself_and_mate"]),
                    str(counts["singletons"]),
                    percent_format.format(percentage(counts["singletons"], counts["paired_in_sequencing"])),
                ]
                summary_file.write("\t".join(row) + "\n")


rule samtools_flagstat:
    # Runs `samtools flagstat` on one sample's SAM file and stores the
    # plain-text report under mapping/statistics/.
    input:
        "mapping/sam/{name}.sam"
    output:
        "mapping/statistics/{name}_flagstat.txt"
    conda:
        "../lib/conda_env.yaml"
    # Same group name as summarize_flagstat.
    group:
        "bwa_mapping"
    shell:
        "samtools flagstat {input} > {output}"


rule mapping_stats_xlsx:
    # Converts the flagstat summary table into an XLSX workbook by calling
    # lib/pe_mapping_stats_tsv_to_xlsx.py with the TSV, the XLSX path and the
    # plot directory.
    input:
        "mapping/statistics/flagstat_summary.tsv"
    output:
        xlsx="mapping/statistics/flagstats/mapping_stats.xlsx",
        # The two SVG paths are declared as outputs but are not passed to the
        # script; presumably it writes them into plot_dir under exactly these
        # names -- verify against the script.
        plot_alignment_absolute="mapping/statistics/flagstats/alignment_stats.svg",
        plot_alignment_relative="mapping/statistics/flagstats/alignment_stats_relative.svg"
    params:
        plot_dir="mapping/statistics/flagstats/"
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_statistics"
    shell:
        "python3 lib/pe_mapping_stats_tsv_to_xlsx.py {input} {output.xlsx} {params.plot_dir}"


rule minimap2_index:
    # Builds the minimap2 index (.mmi) for the reference genome. The preset
    # and extra index options are substituted into the command from the
    # %%...%% template placeholders.
    input:
        genome="%%GENOME_FASTA%%"
    output:
        "mapping/reference/" + genome_base_name + ".mmi"
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_index"
    log:
        "mapping/logs/minimap2_index.log"
    threads:
        8
    # `-d {output}` writes the index file; stdout and stderr are merged and
    # piped through tee, so they reach both the console and the log file.
    shell:
        "minimap2 %%MINIMAP2_PRESET%% %%ADDITIONAL_INDEX_OPTIONS%% -t {threads} -d {output} {input.genome} 2>&1 |"
        "tee {log}"


rule minimap2_mapping:
    # Aligns the paired-end reads of one sample against the prebuilt minimap2
    # index and writes the alignments as SAM (`-a`).
    #
    # The preset placeholder is passed once. The original command repeated it
    # after `-a`, which is redundant.
    input:
        genome_index="mapping/reference/" + genome_base_name + ".mmi",
        reads="preprocessing/{sample}_R1.fastq.gz",
        reads_reverse="preprocessing/{sample}_R2.fastq.gz"
    output:
        "mapping/sam/{sample}.sam"
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_mapping"
    log:
        "mapping/logs/minimap2_mapping.{sample}.log"
    threads:
        2
    # Redirection order matters: `2>&1 > {output}` first points stderr at the
    # pipe (so minimap2's messages go through tee into the log) and then sends
    # stdout, i.e. the SAM records, to the output file. The flagstat report of
    # the fresh SAM file is appended to the same log afterwards.
    shell:
        "minimap2 %%MINIMAP2_PRESET%% %%ADDITIONAL_ALIGNMENT_OPTIONS%% -t {threads} -a {input.genome_index} {input.reads} {input.reads_reverse} 2>&1 > {output} |"
        "tee {log};"
        "samtools flagstat {output} >> {log}"


rule index_bam:
    # Splits one sample's SAM file into three coordinate-sorted BAM files:
    #   * bam           - records with neither flag 4 nor flag 8 set
    #                     (`-F 12`), indexed as CSI (`samtools index -c`)
    #   * bam_singleton - records where exactly one of flags 4 / 8 is set
    #   * bam_unmapped  - records with both flags 4 and 8 set (`-f 12`)
    input:
        "mapping/sam/{sample}.sam"
    output:
        bam="mapping/{sample}.bam",
        csi="mapping/{sample}.bam.csi",
        bam_unmapped="mapping/unmapped/{sample}_unmapped.bam",
        bam_singleton="mapping/singleton/{sample}_singletons.bam"
    params:
        # Scratch SAM file, created and removed by the shell command itself.
        # It is a plain string: temp() only applies to output files.
        tmp_singleton="mapping/singleton/{sample}_tmp.sam"
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_mapping"
    threads:
        2
    # Every step is chained with `&&`, so the first failure stops the job and
    # its exit status is returned. With `;` between `&&` lists, a failure in
    # any but the last command of a list is ignored even under `set -e`, and
    # a later successful step could mask it.
    shell:
        "samtools view -F 12 -Shb {input} | samtools sort -@ {threads} -o {output.bam} - && "
        "samtools index -c {output.bam} && "

        # First pass keeps the header (-h); the second appends records only.
        "samtools view -f 4 -F 8 -Sh {input} > {params.tmp_singleton} && "
        "samtools view -f 8 -F 4 -S {input} >> {params.tmp_singleton} && "
        "samtools view -Shb {params.tmp_singleton} | samtools sort -@ {threads} -o {output.bam_singleton} - && "
        "rm {params.tmp_singleton} && "

        "samtools view -f 12 -Shb {input} | samtools sort -@ {threads} -o {output.bam_unmapped} -"


rule write_settings:
    # Records the minimap2 version and the templated preset / extra options
    # in a small YAML file that the report step reads.
    output:
        settings="mapping/settings.yaml"
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_report"
    # `set +e` turns off exit-on-error for this script, so a failing command
    # does not abort it (presumably to tolerate a non-zero exit of
    # `minimap2 --version` -- verify).
    # The first write truncates the file (`>`), so the result never depends on
    # a stale settings file having been removed beforehand. The remaining
    # lines append.
    shell:
        """
        set +e
        m2_version=$(minimap2 --version)
        echo "minimap2_version: \\"$m2_version\\"" > {output.settings};
        echo "preset: '%%MINIMAP2_PRESET%%'" >> {output.settings};
        echo 'additional_index_parameters: "%%ADDITIONAL_INDEX_OPTIONS%%"' >> {output.settings};
        echo 'additional_alignment_parameters: "%%ADDITIONAL_ALIGNMENT_OPTIONS%%"' >> {output.settings};
        """


rule generate_report_data:
    # Assembles the minimap2 section of the report: generates the data file
    # from the flagstat summary and the settings file, then copies the static
    # HTML/JS/CSS templates from lib/report/ and the two plots into .report/.
    input:
        stats_tsv="mapping/statistics/flagstat_summary.tsv",
        images=["mapping/statistics/flagstats/alignment_stats.svg", "mapping/statistics/flagstats/alignment_stats_relative.svg"],
        settings="mapping/settings.yaml"
    output:
        mm2_data=".report/data/minimap2_data.js",
        mm2_html=".report/modules/minimap2.html",
        mm2_js=".report/js/modules/minimap2.js",
        mm2_css=".report/css/modules/minimap2.css",
        mm2_images=directory(".report/img/modules/minimap2/")
    conda:
        "../lib/conda_env.yaml"
    group:
        "minimap2_report"
    # All steps are chained with `&&`, so a failing step stops the rest.
    # The image directory is created explicitly before the plots are copied.
    shell:
        "python3 lib/generate_report_data.py --stats {input.stats_tsv} --output {output.mm2_data} --settings {input.settings} --paired-end && "
        "cp lib/report/minimap2.html {output.mm2_html} && "
        "cp lib/report/minimap2.js {output.mm2_js} && "
        "cp lib/report/minimap2.css {output.mm2_css} &&"
        "mkdir -p {output.mm2_images} && cp {input.images} {output.mm2_images}"
