from os import listdir
from os.path import isfile

def getConditions():
    """Return the set of distinct DGE conditions assigned to the configured entries.

    Reads ``config['entries']`` and collects the value stored under
    ``['dge_analysis']['condition']`` of every entry.
    """
    conditions = set()
    for entry in config['entries'].values():
        conditions.add(entry['dge_analysis']['condition'])
    return conditions

def list_of_all_features(gff_path):
    """Return the feature types of a GFF file that should get a count table.

    The third tab-separated column of every annotation line is collected and
    then restricted either to the comma-separated list that replaces the
    ADDITIONAL_FEATCOUNTS_TABLES placeholder (user selection) or, when that
    list is empty, to a built-in list of common feature types.

    Args:
        gff_path: Path to the annotation file; a path ending in ".gz" is read
            through gzip.

    Returns:
        Alphabetically sorted list of the selected feature types that occur
        in the file.
    """
    # Local import: gzip is not among the imports at the top of this file.
    import gzip

    opener = gzip.open if gff_path.endswith(".gz") else open
    features = set()
    with opener(gff_path, "rt") as gff_file:
        for line in gff_file:
            # GFF3 may embed the sequences after this directive; nothing that
            # follows is an annotation line.
            if line.startswith("##FASTA"):
                break
            # Skip comments/directives and blank lines.
            if line.startswith("#") or line.isspace():
                continue
            columns = line.strip().split("\t")
            # Anything without a third column cannot name a feature type.
            if len(columns) < 3:
                continue
            features.add(columns[2])

    # Features to check for: either user selected or Curare pre-selected
    user_selection = "%%ADDITIONAL_FEATCOUNTS_TABLES%%"
    if user_selection:
        # Tolerate blanks around the commas ("CDS, gene") and empty entries.
        wanted = {name.strip() for name in user_selection.split(",") if name.strip()}
    else:
        wanted = {
            "CDS", "exon", "intron", "gene", "mRNA", "tRNA", "rRNA", "ncRNA",
            "operon", "snRNA", "snoRNA", "miRNA", "pseudogene",
            "small regulatory ncRNA", "rasiRNA", "guide RNA", "siRNA", "stRNA",
            "sRNA",
        }

    # Sorted so that the list (and everything expanded from it) is identical
    # between runs.
    return sorted(features & wanted)

def get_main_feature_name():
    use_parent_as_id = %%USE_PARENT_INSTEAD_OF_ID%%
    if use_parent_as_id:
        return "%%GFF_FEATURE_PARENT%%"
    else:
        return "%%GFF_FEATURE_NAME%%"


# Rules listed here are executed directly by the Snakemake host process instead
# of being submitted as cluster/cloud jobs.
# NOTE(review): the listed name carries a "dge_analysis__" prefix, whereas the
# rule defined in this file is called "create_conditions" - presumably the
# prefix is added when the module is assembled into the final workflow; confirm.
localrules: dge_analysis__create_conditions


# Default target of this module. Requests:
#   - the raw and the normalized count table and the DESeq2 R data dump,
#   - the directory with the per-feature assignment plots,
#   - one summary workbook (.xlsx) per condition returned by getConditions(),
#   - the GenExVis condition list and the report page of this module.
rule all:
    input:
        "analysis/dge_analysis/counts.txt",
        "analysis/dge_analysis/counts_normalized.txt",
        "analysis/dge_analysis/deseq2.RData",
        'analysis/dge_analysis/visualization/feature_assignments',
        expand('analysis/dge_analysis/summary/{COND}.xlsx', COND=getConditions()),
        "analysis/dge_analysis/genexvis_conditions.txt",
        ".report/modules/dge_analysis.html"

# Convert the per-condition summary table (.tsv) into an Excel workbook (.xlsx)
# by calling lib/deseq2_summary_tsv_to_xlsx.py.
# The script additionally receives the GFF path, the identifier attribute and
# the attribute columns (all template placeholders).
# params.number_conditions is the number of distinct conditions minus one -
# presumably the number of comparison partners of each condition; confirm
# against the script.
# NOTE(review): --identifier always receives the GFF_FEATURE_NAME placeholder,
# while rule count_reads selects its attribute through get_main_feature_name();
# confirm that this is intended when the parent attribute is in use.
rule summary_tsv_to_xslx:
    input:
        'analysis/dge_analysis/summary/{COND}.tsv'
    output:
        'analysis/dge_analysis/summary/{COND}.xlsx'
    conda:
        "../lib/conda_env.yaml"
    params:
        number_conditions = len(getConditions())-1
    group:
        "dge_analysis"
    shell:
        "python3 lib/deseq2_summary_tsv_to_xlsx.py --tsv \"{input}\" --conditions \"{params.number_conditions}\" --gff \"%%GFF_PATH%%\" --identifier \"%%GFF_FEATURE_NAME%%\" "
        "--attributes \"%%ATTRIBUTE_COLUMNS%%\" --output \"{output}\""

# Count read pairs for ONE GFF feature type (wildcard {feature}) over all
# mapped samples with featureCounts, grouped by the GFF_FEATURE_NAME attribute.
# featureCounts writes "<table>.summary" next to the file given to -o, which is
# why the summary is declared as an output without appearing in the command.
# Feature types may contain blanks (e.g. "guide RNA", "small regulatory ncRNA"),
# so the wildcard, the log path derived from it and the annotation path are
# quoted in the command line.
# NOTE(review): params.output is not referenced by the command; left as is.
rule make_count_tables:
    input:
        mappings=expand("mapping/{A}.bam", A=config['entry_order'])
    output:
        featcounts="analysis/dge_analysis/count_tables/{feature}.txt",
        featcounts_summary="analysis/dge_analysis/count_tables/{feature}.txt.summary"
    params:
        output="analysis/dge_analysis/count_tables/"
    conda:
        "../lib/conda_env.yaml"
    group:
        "count_reads"
    threads:
        1
    log:
        log="analysis/dge_analysis/logs/count_tables/{feature}.log"
    shell:
        "featureCounts -p --countReadPairs -T {threads} %%STRAND_SPECIFICITY%% %%ADDITIONAL_FEATCOUNTS_OPTIONS%% -t '{wildcards.feature}' -g '%%GFF_FEATURE_NAME%%' -a '%%GFF_PATH%%' -o '{output.featcounts}' {input.mappings} 2>&1 | "
        "tee '{log.log}';"

# Request one count table per feature type returned by list_of_all_features()
# and record their paths, one per line, in the marker file ".count_tables".
# The ":q" format flag shell-quotes every path and "$i" is quoted as well, so a
# feature type containing blanks (e.g. "guide RNA") stays one path / one line.
# NOTE(review): a log file is declared but the command does not write to it.
rule collect_count_tables:
    input:
        count_tables=expand("analysis/dge_analysis/count_tables/{feature}.txt", feature = list_of_all_features("%%GFF_PATH%%"))
    output:
        "analysis/dge_analysis/count_tables/.count_tables"
    log:
        log="analysis/dge_analysis/logs/count_tables/collect_count_tables.log"
    group:
        "count_reads"
    shell:
        "for i in {input.count_tables:q}; do echo \"$i\" >> {output:q}; done"

# Create the feature-assignment plots with lib/visualize_assignments.py.
# The marker file in `input` only orders this rule after the count tables have
# been collected; the script itself is pointed at the count-table directory
# given in params.input and writes into the output directory, which is created
# beforehand.
rule visualize_assignments:
    input:
        "analysis/dge_analysis/count_tables/.count_tables"
    output:
        directory('analysis/dge_analysis/visualization/feature_assignments')
    params:
        input="analysis/dge_analysis/count_tables/"
    conda:
        "../lib/conda_env.yaml"
    shell:
        "mkdir -p {output}; python3 lib/visualize_assignments.py -i {params.input} -o {output}"

# Build the main count table used for the differential expression analysis:
# featureCounts counts read pairs of all mapped samples for the feature type
# given by the GFF_FEATURE_TYPE placeholder, grouped by the attribute returned
# by get_main_feature_name().
# featureCounts writes "<table>.summary" next to the file given to -o, which is
# why the summary is declared as an output without appearing in the command.
# Feature type, attribute name and annotation path are quoted so that values
# containing blanks reach featureCounts as single arguments (the same values
# are already quoted where other rules of this file pass them on).
rule count_reads:
    input:
        mappings=expand("mapping/{A}.bam", A=config['entry_order'])
    output:
        counts = "analysis/dge_analysis/counts.txt",
        summary = "analysis/dge_analysis/counts.txt.summary"
    params:
        feature_name = get_main_feature_name()
    conda:
        "../lib/conda_env.yaml"
    group:
        "count_reads"
    threads:
        1
    log:
        log="analysis/dge_analysis/logs/count_tables/counts.log"
    shell:
        "featureCounts -p --countReadPairs -T {threads} %%STRAND_SPECIFICITY%% -t '%%GFF_FEATURE_TYPE%%' -g '{params.feature_name}' -a '%%GFF_PATH%%' %%ADDITIONAL_FEATCOUNTS_OPTIONS%% -o '{output.counts}' {input.mappings} 2>&1 | "
        "tee {log.log};"

# Write the sample sheet for the DESeq2 scripts: one "<entry name>\t<condition>"
# line per entry, in the order given by config['entry_order'].
rule create_conditions:
    output:
        "analysis/dge_analysis/conditions.txt"
    run:
        rows = [
            '{}\t{}\n'.format(sample, config['entries'][sample]['dge_analysis']['condition'])
            for sample in config['entry_order']
        ]
        with open(output[0], 'w') as conditions_file:
            conditions_file.writelines(rows)


# Write every distinct condition exactly once, one per line.
# Conditions are listed in the order of their first appearance in
# config['entry_order'], so the file content is identical between runs
# (iterating over a set of strings does not guarantee a stable order).
rule genexvis_condition_file:
    output:
        "analysis/dge_analysis/genexvis_conditions.txt"
    run:
        # dict.fromkeys removes duplicates while keeping insertion order.
        conditions = dict.fromkeys(
            config['entries'][name]['dge_analysis']['condition']
            for name in config['entry_order']
        )
        with open(output[0], 'w') as output_file:
            for cond in conditions:
                output_file.write('{}\n'.format(cond))

# Run lib/deseq2_analysis_normalize_counts.R on the main count table and the
# conditions file.
# Only the normalized count table and the R data dump are passed to the script
# by name; the four SVG outputs are not - the script receives the visualization
# directory (params.vis_dir) instead and presumably writes the plots there under
# exactly these file names (confirm against the R script).
# stdout/stderr are appended (tee -a) to the log file.
rule dge_analysis_normalize_counts:
    input:
        count_table = "analysis/dge_analysis/counts.txt",
        conditions = "analysis/dge_analysis/conditions.txt",
        feature_counts_log = "analysis/dge_analysis/counts.txt.summary"
    output:
        dump="analysis/dge_analysis/deseq2.RData",
        correlation_heatmap="analysis/dge_analysis/visualization/correlation_heatmap.svg",
        pca="analysis/dge_analysis/visualization/pca.svg",
        assignment_statistics_rel="analysis/dge_analysis/visualization/counts_assignment_relative.svg",
        assignment_statistics_abs="analysis/dge_analysis/visualization/counts_assignment_absolute.svg",
        counts_normalized="analysis/dge_analysis/counts_normalized.txt",
    params:
        vis_dir='analysis/dge_analysis/visualization'
    conda:
        "../lib/conda_env.yaml"
    group:
        "dge_analysis"
    log:
        log="analysis/dge_analysis/logs/deseq2_normalize_counts.log"
    threads:
        1
    shell:
        "R --vanilla --file=lib/deseq2_analysis_normalize_counts.R --args --threads {threads} --count-table {input.count_table} --conditions {input.conditions} --output-vis {params.vis_dir} --output-count {output.counts_normalized} --r-data {output.dump} "
        "--featcounts-log {input.feature_counts_log} 2>&1 | tee -a {log.log}"

# Run lib/deseq2_analysis_create_comparison_files_and_summary.R on the R data
# dump produced by the normalization rule.
# Declared outputs are one summary table (.tsv) per condition returned by
# getConditions() and the directory "deseq2_comparisons".
# The script is only given the base output folder (params.output_folder); it
# presumably creates "summary/" and "deseq2_comparisons/" beneath it (confirm
# against the R script).
# stdout/stderr are appended (tee -a) to the log file.
rule deseq2_create_comparison_files:
    input:
        dump = "analysis/dge_analysis/deseq2.RData",
        feature_counts_log = "analysis/dge_analysis/counts.txt.summary"
    output:
        summary=expand('analysis/dge_analysis/summary/{COND}.tsv', COND=getConditions()),
        comparisons=directory("analysis/dge_analysis/deseq2_comparisons")
    conda:
        "../lib/conda_env.yaml"
    group:
        "dge_analysis"
    params:
        output_folder="analysis/dge_analysis/"
    log:
        log="analysis/dge_analysis/logs/deseq2.log"
    threads:
        1
    shell:
        "R --vanilla --file=lib/deseq2_analysis_create_comparison_files_and_summary.R --args --threads {threads} --output {params.output_folder} --r-data {input.dump} "
        "--featcounts-log {input.feature_counts_log} 2>&1 | tee -a {log.log}"

# Assemble the report section of this module under ".report/":
#   1. lib/generate_report_data.py turns the featureCounts summary, the DESeq2
#      comparison directory, the visualization directory and the count table
#      into the data file (.js); it is always called with --paired-end.
#   2. The static HTML/JS page templates are copied from lib/report/.
#   3. The four SVG plots are copied into the report image folder.
#   4. The feature-assignment image folder is created and every *.svg from the
#      feature-assignment directory is copied into it.
# All steps are chained with "&&", so the first failing command stops the rule.
rule generate_report_data:
    input:
        stats="analysis/dge_analysis/counts.txt.summary",
        comparisons="analysis/dge_analysis/deseq2_comparisons",
        img_assignment_rel="analysis/dge_analysis/visualization/counts_assignment_relative.svg",
        img_assignment_abs="analysis/dge_analysis/visualization/counts_assignment_absolute.svg",
        img_pca="analysis/dge_analysis/visualization/pca.svg",
        correlation_heatmap="analysis/dge_analysis/visualization/correlation_heatmap.svg",
        feature_assignment="analysis/dge_analysis/visualization/feature_assignments",
        count_table="analysis/dge_analysis/counts.txt"
    output:
        dge_analysis_data=".report/data/dge_analysis_data.js",
        dge_analysis_html=".report/modules/dge_analysis.html",
        dge_analysis_js=".report/js/modules/dge_analysis.js",
        dge_analysis_img_assignment_rel=".report/img/modules/dge_analysis/counts_assignment_relative.svg",
        dge_analysis_img_assignment_abs=".report/img/modules/dge_analysis/counts_assignment_absolute.svg",
        dge_analysis_img_pca='.report/img/modules/dge_analysis/pca.svg',
        dge_analysis_img_correlation='.report/img/modules/dge_analysis/correlation_heatmap.svg',
        dge_analysis_img_feature_assignment=directory('.report/img/modules/dge_analysis/feature_assignment')
    params:
        visualization="analysis/dge_analysis/visualization"
    conda:
        "../lib/conda_env.yaml"
    group:
        "dge_analysis_report"
    shell:
        "python3 lib/generate_report_data.py --fc_stats {input.stats} --fc_main_feature '%%GFF_FEATURE_TYPE%%' --comparison_dir {input.comparisons} --visualization {params.visualization} --output {output.dge_analysis_data} --counttable '{input.count_table}' --paired-end && "
        "cp lib/report/dge_analysis.html {output.dge_analysis_html} && "
        "cp lib/report/dge_analysis.js {output.dge_analysis_js} &&"
        "cp {input.img_assignment_rel} {output.dge_analysis_img_assignment_rel} &&"
        "cp {input.img_assignment_abs} {output.dge_analysis_img_assignment_abs} &&"
        "cp {input.img_pca} {output.dge_analysis_img_pca} &&"
        "cp {input.correlation_heatmap} {output.dge_analysis_img_correlation} &&"
        "mkdir -p {output.dge_analysis_img_feature_assignment} && cp {input.feature_assignment}/*.svg {output.dge_analysis_img_feature_assignment}"
