from os import listdir
from os.path import isfile, splitext

# Default target rule: requests the mapping statistics workbook, the STAR report
# module and, for every key of config['entries'], the BAM file plus its CSI index.
rule all:
    input:
        "mapping/stats/mapping_stats.xlsx",
        ".report/modules/star.html",
        expand("mapping/{sample}.bam", sample=sorted(config['entries'])),
        expand("mapping/{sample}.bam.csi", sample=sorted(config['entries']))


# Produces mapping_stats.xlsx from mapping_stats.tsv by running
# lib/se_mapping_stats_tsv_to_xlsx.py inside the shared conda environment.
rule mapping_stats_xlsx:
    conda:
        "../lib/conda_env.yaml"
    input:
        "mapping/stats/mapping_stats.tsv"
    output:
        "mapping/stats/mapping_stats.xlsx"
    shell:
        "python3 lib/se_mapping_stats_tsv_to_xlsx.py "
        "{input} {output}"


# Collects the read counts from the STAR "Log.final.out" file of every sample
# into one tab-separated table: one row per sample, each count followed by its
# share of the input reads in percent.
#
# Raises ValueError if a log file has no "Number of input reads" line.
# A sample with zero input reads gets 0.00 for every share instead of failing
# with a division by zero.
rule mapping_stats_tsv:
    input:
        expand("mapping/raw_star_output/{sample}_Log.final.out", sample=config['entries'].keys())
    output:
        "mapping/stats/mapping_stats.tsv"
    run:
        log_suffix = "_Log.final.out"

        # Count columns in output order; each one is followed by its "[%]" column.
        count_columns = ["aligned_0_times", "aligned_1_time", "aligned_more_than_1_times",
                         "aligned_too_many_times", "chimeric_reads"]
        header = ["sample", "reads", "reads[%]"]
        for column in count_columns:
            header += [column, column + "[%]"]

        # Log labels whose value is stored unchanged in the given column.
        assigned_labels = {
            "Uniquely mapped reads number": "aligned_1_time",
            "Number of reads mapped to multiple loci": "aligned_more_than_1_times",
            "Number of reads mapped to too many loci": "aligned_too_many_times",
            "Number of chimeric reads": "chimeric_reads",
        }
        # Log labels whose values are summed up into "aligned_0_times".
        unmapped_labels = {
            "Number of reads unmapped: too many mismatches",
            "Number of reads unmapped: too short",
            "Number of reads unmapped: other",
        }

        # sample, reads, reads[%], then one (count, share) pair per count column.
        row_template = "\t".join(["{}", "{}", "{:.2f}"] + ["{}", "{:.2f}"] * len(count_columns)) + "\n"

        with open(output[0], 'w') as f_out:
            f_out.write("\t".join(header) + "\n")
            for file_path in input:
                # The sample name is the file name without the STAR log suffix.
                sample_name = file_path.split("/")[-1][:-len(log_suffix)]
                counts = dict.fromkeys(count_columns, 0)
                reads = None

                with open(file_path) as f_in:
                    for line in f_in:
                        # Relevant lines have the form "<label> | <value>".
                        fields = [x.strip() for x in line.split("|")]
                        label = fields[0]
                        if label == "Number of input reads":
                            reads = int(fields[1])
                        elif label in assigned_labels:
                            counts[assigned_labels[label]] = int(fields[1])
                        elif label in unmapped_labels:
                            counts["aligned_0_times"] += int(fields[1])

                if reads is None:
                    raise ValueError(
                        "No 'Number of input reads' entry found in {}".format(file_path))

                # The input reads are the reference value, hence always 100 %.
                row = [sample_name, reads, 100]
                for column in count_columns:
                    # Avoid a ZeroDivisionError for samples without any reads.
                    share = counts[column] / reads * 100 if reads else 0.0
                    row += [counts[column], share]
                f_out.write(row_template.format(*row))


# Builds the STAR genome index for the reference FASTA.
#
# genomeSAindexNbases is derived from the number of characters on the
# non-header lines of the FASTA as min(14, int(log2(length) / 2 - 1)); the value
# used is written to the index_settings output.
# The optional annotation flags are only added when the corresponding
# %%...%% value is non-empty (the placeholders are presumably substituted
# before this file is executed -- TODO confirm).
rule star_index:
    input:
        genome="%%GENOME_FASTA%%"
    output:
        index_settings="mapping/index_parameters.txt",
        index_dir=directory("mapping/reference_index/")
    params:
        gff_file=lambda wildcards: "" if len("%%GENOME_ANNOTATION%%") == 0 else " --sjdbGTFfile %%GENOME_ANNOTATION%%",
        gff_parent_keyword=lambda wildcards: "" if len("%%GFF_PARENT_KEYWORD%%") == 0 else " --sjdbGTFtagExonParentGene %%GFF_PARENT_KEYWORD%%",
        gff_id_keyword=lambda wildcards: "" if len("%%GFF_ID_KEYWORD%%") == 0 else " --sjdbGTFtagExonParentTranscript %%GFF_ID_KEYWORD%%",
        gff_feature_type=lambda wildcards: "" if len("%%GFF_FEATURE_TYPE%%") == 0 else " --sjdbGTFfeatureExon %%GFF_FEATURE_TYPE%%"
    log:
        "mapping/logs/star_index.log"
    conda:
        "../lib/conda_env.yaml"
    threads:
        16
    shell:
        "genomeLength=`grep -v '^>' {input.genome}| wc -m`;\n"
        "genomeSAindexNbases=`python3 -c \"import math; print(min(14, int(math.log2($genomeLength)/2-1)))\"`;\n"
        "mkdir -p {output.index_dir};\n"
        "STAR --runMode genomeGenerate --runThreadN {threads} --genomeSAindexNbases $genomeSAindexNbases --genomeDir {output.index_dir} --genomeFastaFiles {input.genome} "
        "{params.gff_file}{params.gff_parent_keyword}{params.gff_id_keyword}{params.gff_feature_type} %%ADDITIONAL_STAR_INDEX_OPTIONS%% 2>&1 |"
        # Write to the declared output instead of repeating its path, so the
        # command cannot drift away from the output definition.
        "tee {log}; echo \"genomeSAindexNbases: $genomeSAindexNbases\" > \"{output.index_settings}\";"


# Maps the gzipped reads of one sample against the STAR index and writes a
# coordinate-sorted BAM file; unmapped reads are additionally emitted by STAR
# (--outReadsUnmapped Fastx). STAR's console output is copied to the log file.
# The genome FASTA is listed as an input but is not referenced by the command.
rule star_mapping:
    input:
        genome="%%GENOME_FASTA%%",
        genome_index_dir="mapping/reference_index/",
        reads="preprocessing/{sample}.fastq.gz"
    output:
        mapping="mapping/raw_star_output/{sample}_Aligned.sortedByCoord.out.bam",
        stats="mapping/raw_star_output/{sample}_Log.final.out"
    params:
        output_prefix="mapping/raw_star_output/{sample}_"
    log:
        "mapping/logs/star_mapping.{sample}.log"
    conda:
        "../lib/conda_env.yaml"
    threads:
        4
    shell:
        """
        STAR --runThreadN {threads} --readFilesCommand "gunzip -c" --genomeDir {input.genome_index_dir} --readFilesIn {input.reads}\
        --outFileNamePrefix {params.output_prefix} --outSAMtype BAM SortedByCoordinate --outBAMsortingThreadN {threads} --outReadsUnmapped Fastx \
        %%ADDITIONAL_STAR_MAPPING_OPTIONS%% 2>&1 | tee {log}
        """


# Splits the STAR alignment of one sample into two sorted BAM files:
# reads without SAM flag 4 (-F 4) go to mapping/{sample}.bam, reads with
# flag 4 set (-f 4) go to mapping/unmapped/{sample}_unmapped.bam.
rule sam_to_bam:
    threads:
        4
    input:
        "mapping/raw_star_output/{sample}_Aligned.sortedByCoord.out.bam"
    output:
        bam="mapping/{sample}.bam",
        bam_unmapped="mapping/unmapped/{sample}_unmapped.bam"
    shell:
        "samtools view -F 4 -Shb {input} "
        "| samtools sort -@ {threads} -o {output.bam} -;"
        "samtools view -f 4 -Shb {input} "
        "| samtools sort -@ {threads} -o {output.bam_unmapped} -"

# Creates the CSI index (samtools index -c) next to the BAM file of one sample.
# The threads reserved for this rule are handed to samtools via -@, in the same
# way the sam_to_bam rule does, so the reserved cores are actually used.
rule bam_index:
    input:
        bam="mapping/{sample}.bam"
    output:
        csi="mapping/{sample}.bam.csi",
    threads:
        4
    shell:
        "samtools index -c -@ {threads} {input.bam}"


# Writes mapping/settings.yaml: the output of `STAR --version` followed by the
# reference genome / annotation file names and the annotation and additional
# STAR options. Empty option values are recorded as "-".
# `set +e` keeps the shell going even if a command (e.g. the STAR call) fails.
rule write_settings:
    output:
        settings="mapping/settings.yaml"
    conda:
        "../lib/conda_env.yaml"
    params:
        # Only the file name (text after the last "/") is recorded for paths.
        ref_genome=lambda wildcards: "%%GENOME_FASTA%%".rsplit("/", 1)[-1],
        genome_annotation=lambda wildcards: "%%GENOME_ANNOTATION%%".rsplit("/", 1)[-1] if "%%GENOME_ANNOTATION%%" else "-",
        gff_parent_keyword=lambda wildcards: "%%GFF_PARENT_KEYWORD%%" or "-",
        gff_id_keyword=lambda wildcards: "%%GFF_ID_KEYWORD%%" or "-",
        gff_feature_type=lambda wildcards: "%%GFF_FEATURE_TYPE%%" or "-",
        additional_index_parameter=lambda wildcards: "%%ADDITIONAL_STAR_INDEX_OPTIONS%%" or "-",
        additional_mapping_parameter=lambda wildcards: "%%ADDITIONAL_STAR_MAPPING_OPTIONS%%" or "-"
    shell:
        """
        set +e
        star_version=$(STAR --version 2>&1);
        echo "star_version: \\"$star_version\\"" > {output.settings};
        echo 'reference_genome: "{params.ref_genome}"' >> {output.settings};
        echo 'genome_annotation: "{params.genome_annotation}"' >> {output.settings};
        echo 'annotation_parent_keyword: "{params.gff_parent_keyword}"' >> {output.settings}; 
        echo 'annotation_id_keyword: "{params.gff_id_keyword}"' >> {output.settings}; 
        echo 'annotation_feature_type: "{params.gff_feature_type}"' >> {output.settings}; 
        echo 'additional_star_index_parameters: "{params.additional_index_parameter}"' >> {output.settings}; 
        echo 'additional_star_mapping_parameters: "{params.additional_mapping_parameter}"' >> {output.settings}; 
        """

# Generates the data file of the STAR report module by running
# lib/generate_report_data.py on the statistics table and the settings file,
# then copies the module's HTML, JS and CSS files from lib/report/ into .report/.
# The steps are chained with &&, so a failing step stops the remaining ones.
rule generate_report_data:
    conda:
        "../lib/conda_env.yaml"
    input:
        stats_tsv="mapping/stats/mapping_stats.tsv",
        settings="mapping/settings.yaml"
    output:
        star_data=".report/data/star_data.js",
        star_html=".report/modules/star.html",
        star_js=".report/js/modules/star.js",
        star_css=".report/css/modules/star.css"
    shell:
        "python3 lib/generate_report_data.py"
        " --stats {input.stats_tsv}"
        " --output {output.star_data}"
        " --settings {input.settings}"
        " && cp lib/report/star.html {output.star_html}"
        " && cp lib/report/star.js {output.star_js}"
        " && cp lib/report/star.css {output.star_css}"