import hashlib
import os
import shutil
import subprocess
import sys

from snakemake.logging import logger


def get_conda_env():
    """
    Return the path to the conda environment definition file.

    A user-supplied path from ``config["environment"]`` takes precedence;
    otherwise the ``environment.yml`` shipped alongside the Snakefile in the
    hundo package is used. Exits the process if the file cannot be found.
    """
    env_yml = config.get("environment")
    if not env_yml:
        snakefile_dir = os.path.dirname(os.path.abspath(workflow.snakefile))
        env_yml = os.path.join(snakefile_dir, "environment.yml")
    if not os.path.exists(env_yml):
        sys.exit("Unable to locate the environmental definitions file; tried %s" % env_yml)
    return env_yml


def _make_gen(reader):
    """https://stackoverflow.com/a/27518377"""
    b = reader(1024 * 1024)
    while b:
        yield b

        b = reader(1024 * 1024)


def count_sequences(filename):
    """
    Count the number of FASTQ records (newline count / 4) in ``filename``.

    Used to count sequences within files of very low size. If protocol starts
    with a file that does not have reads that will pass the sample_group
    stage, the protocol will fail as a subsequent rule will not have an input
    file.

    Parameters
    ----------
    filename : str
        Path to a FASTQ file; gzip compression is assumed when the name
        ends with ".gz".

    Returns
    -------
    float
        The sequence count (kept as a float division for backward
        compatibility with existing "%d" formatting of the result).

    Source: https://stackoverflow.com/a/27518377
    """
    chunk_size = 1024 * 1024
    if filename.endswith(".gz"):
        import gzip

        fh = gzip.open(filename, "rb")
        read = fh.read
    else:
        fh = open(filename, "rb")
        # read from the raw buffer to skip the buffered-IO layer
        read = fh.raw.read
    newlines = 0
    # bug fix: the original implementation never closed the file handle
    try:
        buf = read(chunk_size)
        while buf:
            newlines += buf.count(b"\n")
            buf = read(chunk_size)
    finally:
        fh.close()
    return newlines / 4


def get_sample_files(fastq_dirs, prefilter_file_size):
    """
    Looks into folders and file wildcard patterns to grab all relevant files
    and their paired sequence. Sample names are determined from the file name.

    Parameters
    ----------
    fastq_dirs : str
        A directory, a glob pattern (single-quoted so the shell does not
        expand it), or a comma separated list mixing both.
    prefilter_file_size : int
        Paired files where either member is smaller than this size (bytes)
        are recorded in ``omitted`` instead of ``samples``.

    Returns
    -------
    tuple(dict, dict)
        ``samples`` maps sample ID -> {"r1": path, "r2": path};
        ``omitted`` maps "<sample> **OMITTED**" -> sequence count.

    Calls ``sys.exit`` on unusable input or when a mate file is missing.
    """
    if not fastq_dirs:
        logger.error(
            (
                "'fastq_dir' has not been set -- this directory "
                "should contain your input FASTQs"
            )
        )
        sys.exit(1)
    if os.path.isfile(fastq_dirs):
        logger.error(
            (
                "'fastq_dir' must be a directory, a file pattern, or a comma "
                "separated list of both of those things "
                "e.g. '*.fastq', with single quotes surrounding the pattern."
            )
        )
        sys.exit(1)
    # grab all of the possible fastq file paths
    # user is permitted to send a single path, a single pattern, or a
    # comma separated list of paths and patterns
    raw_fastq_files = list()
    for fastq_dir in fastq_dirs.split(","):
        if "*" in fastq_dir:
            from glob import glob
            logger.info("Finding samples matching %s" % fastq_dir)
            raw_fastq_files.extend(glob(fastq_dir))
        else:
            logger.info("Finding samples in %s" % fastq_dir)
            raw_fastq_files.extend([os.path.join(fastq_dir, i) for i in os.listdir(fastq_dir)])
    samples = dict()
    seen = set()
    omitted = dict()
    for fq_path in raw_fastq_files:
        fname = os.path.basename(fq_path)
        fastq_dir = os.path.dirname(fq_path)
        # skip anything that is not a FASTQ by extension
        if not ".fastq" in fname and not ".fq" in fname:
            continue

        # sample ID is the file name up to the FASTQ extension with the
        # read-index designator removed and [., space, -] normalized to "_"
        sample_id = fname.partition(".fastq")[0]
        if ".fq" in sample_id:
            sample_id = fname.partition(".fq")[0]
        sample_id = sample_id.replace("_R1", "").replace("_r1", "").replace(
            "_R2", ""
        ).replace(
            "_r2", ""
        )
        sample_id = sample_id.replace(".", "_").replace(" ", "_").replace("-", "_")
        # sample ID after rename
        if sample_id in seen:
            # but FASTQ has yet to be added
            # if one sample has a dash and another an underscore, this
            # is a case where we should warn the user that this file
            # is being skipped
            if not fq_path in seen:
                logger.warning(
                    "Duplicate sample %s was found after renaming; skipping..." %
                    sample_id
                )
            continue

        # which read index do we have of the pair?
        if "_R1" in fname or "_r1" in fname:
            idx = "r1"
            # simple replace of right-most read index designator
            if fname.find("_R1") > fname.find("_r1"):
                other_idx = os.path.join(fastq_dir, "_R2".join(fname.rsplit("_R1", 1)))
            else:
                other_idx = os.path.join(fastq_dir, "_r2".join(fname.rsplit("_r1", 1)))
        elif "_R2" in fname or "_r2" in fname:
            idx = "r2"
            if fname.find("_R2") > fname.find("_r2"):
                other_idx = os.path.join(fastq_dir, "_R1".join(fname.rsplit("_R2", 1)))
            else:
                other_idx = os.path.join(fastq_dir, "_r1".join(fname.rsplit("_r2", 1)))
        else:
            logger.error(
                "Unable to determine read index for [%s] as it is missing '_R1' or '_R2' designation. Exiting." %
                fname
            )
            sys.exit(1)
        # naming convention not followed and nothing was renamed
        if other_idx == fq_path:
            logger.error("Unable to locate matching index for %s. Exiting." % fq_path)
            sys.exit(1)
        # not paired-end?
        if not os.path.exists(other_idx):
            logger.error(
                "File [%s] for %s was not found. Exiting." % (other_idx, sample_id)
            )
            sys.exit(1)
        seen.add(fq_path)
        seen.add(other_idx)
        seen.add(sample_id)
        # omissions: undersized pairs are counted and reported rather than
        # processed, as downstream rules would fail on empty inputs
        if os.path.getsize(fq_path) < prefilter_file_size or os.path.getsize(
            other_idx
        ) < prefilter_file_size:
            c = count_sequences(fq_path)
            logger.info(
                (
                    "%s is being omitted from analysis due to its "
                    "file size (%d paired sequences)."
                ) %
                (sample_id, c)
            )
            omitted["%s **OMITTED**" % sample_id] = c
            fastq_paths = None
        else:
            if idx == "r1":
                fastq_paths = {"r1": fq_path, "r2": other_idx}
            else:
                fastq_paths = {"r1": other_idx, "r2": fq_path}
        if fastq_paths:
            samples[sample_id] = fastq_paths
    # logger.info("\nFound %d samples for processing:\n" % len(samples))
    return samples, omitted


def get_cluster_sequences_input(wildcards):
    """
    Snakemake input function: cluster the de novo chimera-filtered sequences
    when de novo chimera filtering is enabled, otherwise the dereplicated
    sequences directly.

    Bug fix: the config key was misspelled "denova_chimera_filter", so the
    chimera-filtered branch could never be selected (the correct spelling
    "denovo_chimera_filter" is used elsewhere in this file).
    """
    if config.get("denovo_chimera_filter"):
        return "chimera-filtered_denovo.fasta"
    return "dereplicated-sequences.fasta"


def get_run_reference_chimera_filter_inputs(wildcards):
    """
    Snakemake input function for reference-based chimera filtering.

    The reference DB is the workflow's reference FASTA when
    ``reference_chimera_filter`` is literally True, otherwise whatever path
    the user supplied for that key. The query FASTA depends on whether de
    novo chimera filtering ran and on the preclustering identity threshold.
    """
    reference_setting = config.get("reference_chimera_filter")
    db = REFFASTA if reference_setting is True else reference_setting
    if config.get("denovo_chimera_filter", True):
        fasta = "chimera-filtered_denovo.fasta"
    elif float(config.get("read_identity_requirement", 0.97)) * 100 < 99:
        fasta = "preclustered.fasta"
    else:
        fasta = "dereplicated-sequences.fasta"
    return {"fasta": fasta, "db": db}


def filter_fastas(fasta1, uclust, fasta2, output):
    """
    Write to ``output`` the records of ``fasta1`` whose IDs were accepted.

    ``fasta1`` is the larger set of reads; an ID is accepted when it appears
    as a header in ``fasta2`` or when the uclust table maps it onto an
    already-accepted target. IDs are the header text before the first ";"
    (size annotations are dropped).
    """
    keep_ids = set()
    with open(fasta2) as handle:
        for record_line in handle:
            if not record_line.startswith(">"):
                continue
            keep_ids.add(record_line.strip().strip(">").partition(";")[0])
    with open(uclust) as handle:
        for record_line in handle:
            fields = record_line.strip().split("\t")
            # lines with no query/target mapping carry no information
            if fields[9] == "*" or fields[8] == "*":
                continue
            query = fields[8].partition(";")[0]
            target = fields[9].partition(";")[0]
            if target in keep_ids:
                keep_ids.add(query)
    with open(fasta1) as handle, open(output, "w") as out_handle:
        writing = False
        for record_line in handle:
            record_line = record_line.strip()
            if record_line.startswith(">"):
                # state toggles per record and persists over sequence lines
                writing = record_line.strip(">").partition(";")[0] in keep_ids
            if writing:
                print(record_line, file=out_handle)


def md5(fname):
    """Return the hex MD5 digest of *fname*, or None if it does not exist.

    https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file
    """
    if not os.path.exists(fname):
        return None

    digest = hashlib.md5()
    with open(fname, "rb") as handle:
        # hash in small chunks so large files never sit fully in memory
        while True:
            chunk = handle.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()


# Workflow-level globals; ``config`` and ``workflow`` are injected by
# Snakemake at parse time.
PROTOCOL_VERSION = subprocess.check_output("hundo --version", shell=True).decode("utf-8").strip()
if config.get("workflow") == "download":
    # placeholder sample so download-only runs do not require FASTQs;
    # bug fix: the original unpacked a single one-key dict into two names,
    # which raises "ValueError: not enough values to unpack"
    SAMPLES = {"placeholder": {"r1": "placeholder", "r2": "placeholder"}}
    OMITTED = {}
else:
    SAMPLES, OMITTED = get_sample_files(
        config.get("fastq_dir"), config.get("prefilter_file_size", 100000)
    )
CONDAENV = get_conda_env()
# expected MD5 checksums of the downloadable reference databases
REFERENCES = {
    "greengenes.fasta.nhr": "200f1b4356e59be524525e4ca83cefc1",
    "greengenes.fasta.nin": "f41470ebc95c21f13b54f7b127ebe2ad",
    "greengenes.fasta.nsq": "150b00866a87f44c8b33871be7cc6b98",
    "greengenes.map": "cb49d30e2a00476f8bdaaa3eaec693ef",
    "greengenes.tre": "b3a369edde911bf0aa848a842f736fce",
    "silvamod128.fasta.nhr": "99d9b6817a9c6a0249fbb32a60e725ae",
    "silvamod128.fasta.nin": "25d8d6587123aa7dbaabaa1299e481cc",
    "silvamod128.fasta.nsq": "7e12efc4ab9019c9d79809861922a001",
    "silvamod128.map": "00500c7f215a89a3240c907af4f75f33",
    "silvamod128.tre": "340864d9d32e8561f75ec0b0a9a6c3d1",
    "unite.fasta.nhr": "3ce60ad0d3eee81246c6915e64bf3122",
    "unite.fasta.nin": "962c9364b6cb3da2b9457546838bb6fb",
    "unite.fasta.nsq": "796e7338cdbb3d659951e380fd2b8104",
    "unite.map": "11984b9f45fbac120824de516a133e45",
    "unite.tre": "d50a5f427c22964906f2d6fc1b7cf220",
}
# files belonging to the selected reference database and their local paths
REF = [k for k in REFERENCES.keys() if k.startswith(config.get("reference_database", "silva"))]
BLASTREF = [os.path.join(config.get("database_dir", "."), i) for i in REF]
# reference FASTA path is the BLAST DB file name minus its ".n??" suffix
REFFASTA = [os.path.join(config.get("database_dir", "."), i) for i in REF if ".fasta" in i][0].rpartition(".")[0]
REFMAP = [os.path.join(config.get("database_dir", "."), i) for i in REF if i.endswith(".map")][0]
REFTRE = [os.path.join(config.get("database_dir", "."), i) for i in REF if i.endswith(".tre")][0]


# Terminal targets of the workflow.
localrules: all
rule all:
    input:
        "summary.html",
        "PROTOCOL-VERSION.txt"


# Record the hundo version used for this run.
localrules: print_protocol_version
rule print_protocol_version:
    output:
        "PROTOCOL-VERSION.txt"
    shell:
        "hundo --version > {output}"


# Write the sample -> FASTQ pair map plus any omitted samples to a text file.
localrules: print_samples
rule print_samples:
    output:
        "SAMPLES.txt"
    run:
        with open(output[0], "w") as fh:
            for k, v in SAMPLES.items():
                print("%s: %s; %s" % (k, v["r1"], v["r2"]), file=fh)
            print("\n", file=fh)
            for k, v in OMITTED.items():
                print("%s: %d total sequences" % (k, v), file=fh)


# Fetch one reference database file from Zenodo and verify its MD5 checksum.
rule download_reference_data:
    output:
        # bug fix: the output must contain the {filename} wildcard so the
        # run body's {wildcards.filename} can be resolved (the original had
        # a literal "(unknown)" placeholder here)
        "%s/{filename}" % config.get("database_dir", ".")
    run:
        shell("curl 'https://zenodo.org/record/1043977/files/{wildcards.filename}' -s > {output}")
        if not REFERENCES[os.path.basename(output[0])] == md5(output[0]):
            raise OSError(2, "Invalid checksum", output[0])


# Export the sequences of the downloaded BLAST database as a FASTA file.
rule create_fasta_from_reference:
    input:
        BLASTREF
    output:
        REFFASTA
    params:
        db = lambda wildcards, input: os.path.join(os.path.dirname(input[0]), "%s.fasta" % os.path.basename(input[0]).partition(".")[0])
    conda:
        CONDAENV
    shell:
        """blastdbcmd -db {params.db} -outfmt %f -entry all -out {output}"""


# Per-sample expected-error statistics for the raw R1 reads.
rule get_raw_r1_fastq_qualities:
    input:
        unpack(lambda wc: SAMPLES[wc.sample])
    output:
        r1 = "logs/{sample}_R1_eestats.txt",
    conda:
        CONDAENV
    threads:
        1
    group:
        "sample_group"
    shell:
        """
        vsearch --threads {threads} --fastq_eestats {input.r1} --output {output.r1}
        """


# Per-sample expected-error statistics for the raw R2 reads.
rule get_raw_r2_fastq_qualities:
    input:
        unpack(lambda wc: SAMPLES[wc.sample])
    output:
        r2 = "logs/{sample}_R2_eestats.txt",
    conda:
        CONDAENV
    threads:
        1
    group:
        "sample_group"
    shell:
        """
        vsearch --threads {threads} --fastq_eestats {input.r2} --output {output.r2}
        """


# Adapter/contaminant removal plus quality trimming with bbduk2; the k-mer
# matching parameters are only emitted when a reference file is configured.
rule trim_and_filter_reads:
    input:
        unpack(lambda wc: SAMPLES[wc.sample])
    output:
        r1 = temp("tmp/{sample}_R1.fastq"),
        r2 = temp("tmp/{sample}_R2.fastq"),
        stats = "logs/{sample}_quality_filtering_stats.txt"
    benchmark:
        "logs/benchmarks/quality_filter/{sample}.txt"
    params:
        lref = "lref=%s" % config.get("filter_adapters") if config.get("filter_adapters") else "",
        rref = "rref=%s" % config.get("filter_adapters") if config.get("filter_adapters") else "",
        fref = "fref=%s" % config.get("filter_contaminants") if config.get("filter_contaminants") else "",
        mink = "" if not config.get("filter_adapters") and not config.get("filter_contaminants") else "mink=%d" % config.get("reduced_kmer_min", 8),
        trimq = config.get("minimum_base_quality", 10),
        hdist = "" if not config.get("filter_adapters") and not config.get("filter_contaminants") else "hdist=%d" % config.get("allowable_kmer_mismatches", 1),
        k = "" if not config.get("filter_adapters") and not config.get("filter_contaminants") else "k=%d" % config.get("reference_kmer_match_length", 27),
        qtrim = config.get("qtrim", "rl"),
        minlength = config.get("minimum_passing_read_length", 100)
    conda:
        CONDAENV
    log:
        "logs/{sample}_quality_filter.log"
    threads:
        max(int(config.get("threads", 4) * 0.25), 1)
    group:
        "sample_group"
    shell:
        """
        bbduk2.sh in={input.r1} in2={input.r2} out={output.r1} \
            out2={output.r2} {params.rref} {params.lref} {params.fref} \
            {params.mink} qout=33 stats={output.stats} {params.hdist} \
            {params.k} trimq={params.trimq} qtrim={params.qtrim} \
            threads={threads} minlength={params.minlength} \
            overwrite=true 2> {log}
        """


# Parse the raw pair count out of the bbduk2 stats file
# (the "#Total" line counts both mates, hence the division by 2).
localrules: count_raw_reads
rule count_raw_reads:
    input:
        "logs/{sample}_quality_filtering_stats.txt"
    output:
        "logs/raw_counts/{sample}_R1.count"
    # unsure if this will work
    group:
        "sample_group"
    run:
        with open(input[0]) as filein, open(output[0], "w") as fileout:
            for line in filein:
                if line.startswith("#Total"):
                    total = int(line.strip().split("\t")[1]) / 2.
                    print(total, file=fileout)
                    break


# Count quality-filtered reads (FASTQ line count / 4).
rule count_filtered_reads:
    input:
        "tmp/{sample}_R1.fastq"
    output:
        "logs/filtered_counts/{sample}_R1.count"
    group:
        "sample_group"
    shell:
        "awk '{{n++}}END{{print int(n/4)}}' {input} > {output}"


# Merge paired reads; each merged read is labelled with its sample name so
# sequences can be pooled across samples later.
rule merge_reads:
    input:
        r1 = "tmp/{sample}_R1.fastq",
        r2 = "tmp/{sample}_R2.fastq"
    output:
        temp("tmp/{sample}_merged.fastq")
    params:
        minimum_merge_length = config.get("minimum_merge_length", 150),
        fastq_allowmergestagger = "--fastq_allowmergestagger" if config.get("fastq_allowmergestagger", False) else "",
        fastq_maxdiffs = config.get("fastq_maxdiffs", 5),
        fastq_minovlen = config.get("fastq_minovlen", 16)
    conda:
        CONDAENV
    log:
        "logs/{sample}_merge_reads.log"
    group:
        "sample_group"
    shell:
        """
        vsearch --fastq_mergepairs {input.r1} --reverse {input.r2} \
            --label_suffix \;sample={wildcards.sample}\; \
            --fastq_minmergelen {params.minimum_merge_length} \
            {params.fastq_allowmergestagger} \
            --fastq_maxdiffs {params.fastq_maxdiffs} \
            --fastq_minovlen {params.fastq_minovlen} \
            --fastqout {output} --log {log}
        """


# Count merged reads (FASTQ line count / 4).
rule count_merged_reads:
    input:
        "tmp/{sample}_merged.fastq"
    output:
        "logs/merged_counts/{sample}_R1.count"
    group:
        "sample_group"
    shell:
        "awk '{{n++}}END{{print int(n/4)}}' {input} > {output}"


# Expected-error statistics on merged reads, before quality filtering.
rule get_prefilter_fastq_stats:
    input:
        "tmp/{sample}_merged.fastq"
    output:
        "logs/{sample}_merged_eestats.txt"
    conda:
        CONDAENV
    threads:
        1
    group:
        "sample_group"
    shell:
        """
        vsearch --threads {threads} --fastq_eestats {input} \
            --output {output}
        """


# Drop merged reads exceeding the expected-error or ambiguous-base limits.
rule filter_merged_reads:
    input:
        "tmp/{sample}_merged.fastq"
    output:
        fa = temp("tmp/{sample}_merged_filtered.fasta"),
        fq = temp("tmp/{sample}_merged_filtered.fastq")
    params:
        maxee = config.get("maximum_expected_error", 1),
        maxns = config.get("vsearch_maxns", 0)
    conda:
        CONDAENV
    log:
        "logs/{sample}_fastq_filter.log"
    group:
        "sample_group"
    shell:
        """
        vsearch --fastq_filter {input} --fastq_maxee {params.maxee} \
            --fastq_maxns {params.maxns} --fastaout {output.fa} \
            --fastqout {output.fq}
        """


# Expected-error statistics on merged reads, after quality filtering.
rule get_postfilter_fastq_stats:
    input:
        "tmp/{sample}_merged_filtered.fastq"
    output:
        "logs/{sample}_merged_filtered_eestats.txt"
    conda:
        CONDAENV
    threads:
        1
    group:
        "sample_group"
    shell:
        """
        vsearch --threads {threads} --fastq_eestats {input} \
            --output {output}
        """


# Per-sample dereplication; reads are relabelled "<sample>.<n>" with size
# annotations so abundances survive pooling.
rule dereplicate_sequences_by_sample:
    input:
        "tmp/{sample}_merged_filtered.fasta"
    output:
        temp("tmp/{sample}_dereplicated.fasta")
    conda:
        CONDAENV
    log:
        "logs/{sample}_dereplication.log"
    group:
        "sample_group"
    shell:
        """
        vsearch --derep_fulllength {input} --output {output} \
            --strand plus --sizeout --relabel {wildcards.sample}.
        """


# Concatenate the per-sample dereplicated FASTAs into one pooled file.
localrules: combine_merged_reads
rule combine_merged_reads:
    input:
        fastas = expand("tmp/{sample}_dereplicated.fasta", sample=SAMPLES)
    output:
        fasta = "all-sequences.fasta"
    run:
        import shutil

        with open(output.fasta, 'wb') as ofh:
            for f in input.fastas:
                with open(f, 'rb') as ifh:
                    shutil.copyfileobj(ifh, ofh, 1024*1024*10)


# Dereplicate the pooled sequences across samples, discarding sequences
# below the minimum unique abundance; the uclust table preserves the
# mapping back to member reads.
rule dereplicate_sequences:
    input:
        "all-sequences.fasta"
    output:
        fa = temp("dereplicated-sequences.fasta"),
        uc = temp("dereplicated.uclust")
    params:
        minuniquesize = config.get("vsearch_minuniquesize", 2)
    conda:
        CONDAENV
    log:
        "logs/dereplicated-sequences.log"
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        # bug fix: vsearch takes GNU-style long options; "-log" -> "--log"
        # (every other vsearch invocation in this file uses "--log")
        """
        vsearch --derep_fulllength {input} --output {output.fa} \
            --sizein --sizeout --minuniquesize {params.minuniquesize} \
            --threads {threads} --log {log} --uc {output.uc}
        """


# Precluster dereplicated sequences at the configured identity to reduce
# the workload of chimera detection.
rule precluster_sequences:
    input:
        "dereplicated-sequences.fasta"
    output:
        fa = temp("preclustered.fasta"),
        uc = temp("preclustered.uclust")
    params:
        id = min((100 - float(config.get("percent_of_allowable_difference", 3))) / 100, 1.00)
    conda:
        CONDAENV
    log:
        "logs/precluster.log"
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        """
        vsearch --threads {threads} --cluster_size {input} \
            --id {params.id} --strand plus --sizein --sizeout \
            --centroids {output.fa} --uc {output.uc}
        """


# De novo (reference-free) chimera removal.
rule run_denovo_chimera_filter:
    input:
        "preclustered.fasta"
    output:
        temp("chimera-filtered_denovo.fasta")
    conda:
        CONDAENV
    log:
        "logs/chimera-filtered_denovo.log"
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        """
        vsearch --uchime_denovo {input} --nonchimeras {output} \
            --strand plus --sizein --sizeout --threads {threads} \
            --log {log}
        """


# Reference-based chimera removal; inputs depend on configuration
# (see get_run_reference_chimera_filter_inputs).
rule run_reference_chimera_filter:
    input:
        unpack(get_run_reference_chimera_filter_inputs)
    output:
        temp("chimera-filtered_reference.fasta")
    conda:
        CONDAENV
    log:
        "logs/chimera-filtered_reference.log"
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        """
        vsearch --uchime_ref {input.fasta} --nonchimeras {output} \
            --strand plus --sizein --sizeout --db {input.db} \
            --threads {threads} --log {log}
        """


rule extract_filtered_sequences:
    # Extract all non-chimeric, non-singleton sequences, dereplicated
    input:
        filter_fa = "dereplicated-sequences.fasta",
        uc = "preclustered.uclust",
        keep_fa = "chimera-filtered_reference.fasta"
    output:
        temp("nonchimeric-nonsingleton.fasta")
    group:
        "combined_group"
    run:
        filter_fastas(input.filter_fa, input.uc, input.keep_fa, output[0])


rule pull_seqs_from_samples_with_filtered_sequences:
    # Extract all non-chimeric, non-singleton sequences in each sample
    input:
        filter_fa = "all-sequences.fasta",
        uc = "dereplicated.uclust",
        keep_fa = "nonchimeric-nonsingleton.fasta"
    output:
        temp("all-sequences_filtered.fasta")
    group:
        "combined_group"
    run:
        filter_fastas(input.filter_fa, input.uc, input.keep_fa, output[0])


# Cluster the filtered sequences into OTUs at the configured identity.
rule cluster_sequences:
    input:
        # get_cluster_sequences_input
        "all-sequences_filtered.fasta"
    output:
        fa = "OTU.fasta"
    params:
        minsize = config.get("minimum_sequence_abundance", 2),
        otu_id_pct = (100 - float(config.get("percent_of_allowable_difference", 3))) / 100.
    conda:
        CONDAENV
    log:
        "logs/cluster_sequences.log"
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        """
        vsearch --cluster_size {input} --centroids {output.fa} \
            --relabel OTU_ --id {params.otu_id_pct} --log {log} \
            --sizein --strand plus --threads {threads}
        """


# Align OTU sequences against the reference database. The rule name is the
# same in both branches; configuration picks the aligner at parse time.
if config.get("aligner") == "blast":
    rule run_aligner:
        input:
            fasta = "OTU.fasta",
            # force download
            db = BLASTREF,
            db_fa = REFFASTA
        output:
            "local-hits.txt"
        params:
            # works due to file naming convention
            db = REFFASTA,
            num_alignments = config.get("blast_num_alignments", 25),
            id_req = config.get("min_id_req", 0.6),
            max_rej = config.get("max_reject", 100),
        conda:
            CONDAENV
        threads:
            config.get("threads", 1)
        shell:
            """
            blastn -query {input.fasta} -db {params.db} \
                -num_alignments {params.num_alignments} \
                -outfmt 6 -out {output} -num_threads {threads}
            """
# VSEARCH was chosen as the aligner
else:
    rule run_aligner:
        input:
            fasta = "OTU.fasta",
            db = BLASTREF,
            db_fa = REFFASTA
        output:
            "local-hits.txt"
        params:
            db = REFFASTA,
            num_alignments = config.get("blast_num_alignments", 25),
            id_req = config.get("min_id_req", 0.6),
            max_rej = config.get("max_reject", 100),
        conda:
            CONDAENV
        threads:
            config.get("threads", 1)
        shell:
            """
            vsearch --usearch_global {input.fasta} --db {params.db} \
                --maxaccepts {params.num_alignments} --threads {threads} \
                --maxrejects {params.max_rej} --strand plus \
                --id {params.id_req} --blast6out {output}
            """


# Assign taxonomy to OTUs via lowest-common-ancestor over alignment hits.
rule compute_lca:
    input:
        fasta = "OTU.fasta",
        hits = "local-hits.txt",
        map = REFMAP,
        tre = REFTRE
    output:
        fasta = "OTU_tax.fasta",
        tsv = temp("OTU_assignments.tsv")
    params:
        min_score = config.get("blast_minimum_bitscore", 125),
        min_pid = config.get("min_pid", .85),
        aligner = config.get("aligner"),
        top_fraction = 0.99 if not config.get("aligner") == "blast" else config.get("blast_top_fraction", 0.95),
    group:
        "combined_group"
    shell:
        """
        hundo lca --min-score {params.min_score} --min-pid {params.min_pid} \
            --top-fraction {params.top_fraction} --aligner {params.aligner} \
            {input.fasta} {input.hits} {input.map} {input.tre} {output.fasta} \
            {output.tsv}
        """


# Map all pooled reads back onto the annotated OTUs to build the OTU table.
rule compile_counts:
    input:
        seqs = "all-sequences.fasta",
        db = "OTU_tax.fasta"
    output:
        tsv = "OTU.txt"
    params:
        id_req = config.get("read_identity_requirement", 0.80),
        minqt = config.get("vsearch_minqt", 0.80)
    conda:
        CONDAENV
    threads:
        config.get("threads", 1)
    group:
        "combined_group"
    shell:
        """
        vsearch --usearch_global {input.seqs} \
            --db {input.db} --id {params.id_req} \
            --otutabout {output.tsv} --sizein --strand plus \
            --threads {threads}
        """


# Convert the OTU text table to JSON BIOM; quotes/commas are sanitized first
# so biom's semicolon-separated metadata parsing succeeds.
rule create_biom_from_txt:
    input:
        "OTU.txt"
    output:
        "OTU.biom"
    shadow:
        "shallow"
    conda:
        CONDAENV
    group:
        "combined_group"
    shell:
        '''
        sed 's|\"||g' {input} | sed 's|\,|\;|g' > OTU_converted.txt
        biom convert -i OTU_converted.txt -o {output} --to-json \
            --process-obs-metadata sc_separated --table-type "OTU table"
        '''


# Multiple sequence alignment of the OTU representatives with mafft.
rule multiple_align:
    input:
        "OTU.fasta"
    output:
        "OTU_aligned.fasta"
    threads:
        # mafft (via conda) is single-threaded on macOS
        config.get("threads", 1)
    conda:
        CONDAENV
    group:
        "combined_group"
    shell:
        """
        mafft --auto --thread {threads} --quiet --maxiterate 1000 \
            --globalpair {input} > {output}
        """


# Build a Newick phylogeny from the alignment with FastTree.
rule newick_tree:
    input:
        "OTU_aligned.fasta"
    output:
        "OTU.tree"
    threads:
        config.get("threads", 1)
    conda:
        CONDAENV
    log:
        "logs/fasttree.log"
    group:
        "combined_group"
    shell:
        """
        export OMP_NUM_THREADS={threads} && \
            FastTreeMP -nt -gamma -spr 4 -log {log} -quiet {input} > {output}
        """


# Bundle the main deliverables into one archive.
rule make_deliverables_archive:
    input:
        "OTU.biom", "OTU.fasta", "OTU.tree", "OTU.txt"
    output:
        "hundo_results.zip"
    conda:
        CONDAENV
    group:
        "combined_group"
    shell:
        "zip {output} {input}"


# Dump the effective run configuration to PARAMS.txt for the final report.
rule aggregate_run_parameters:
    output:
        txt = "PARAMS.txt"
    params:
        fastq_dir = config.get("fastq_dir"),
        filter_adapters = "None" if not config.get("filter_adapters") else config.get("filter_adapters"),
        filter_contaminants = "None" if not config.get("filter_contaminants") else config.get("filter_contaminants"),
        allowable_kmer_mismatches = config.get("allowable_kmer_mismatches"),
        reference_kmer_match_length = config.get("reference_kmer_match_length"),
        reduced_kmer_min = config.get("reduced_kmer_min"),
        minimum_passing_read_length = config.get("minimum_passing_read_length"),
        minimum_base_quality = config.get("minimum_base_quality"),
        minimum_merge_length = config.get("minimum_merge_length"),
        fastq_allowmergestagger = "True" if config.get("fastq_allowmergestagger", False) else "False",
        fastq_maxdiffs = config.get("fastq_maxdiffs", 5),
        fastq_minovlen = config.get("fastq_minovlen", 16),
        maximum_expected_error = config.get("maximum_expected_error"),
        reference_chimera_filter = config.get("reference_chimera_filter"),
        minimum_sequence_abundance = config.get("minimum_sequence_abundance"),
        percent_of_allowable_difference = config.get("percent_of_allowable_difference"),
        reference_database = config.get("reference_database"),
        blast_minimum_bitscore = config.get("blast_minimum_bitscore"),
        blast_top_fraction = config.get("blast_top_fraction"),
        read_identity_requirement = config.get("read_identity_requirement"),
        aligner = config.get("aligner")
    run:
        # one "key: value" line per parameter
        with open(output.txt, "w") as fh:
            print("fastq_dir:", params.fastq_dir, sep=" ", file=fh)
            print("filter_adapters:", params.filter_adapters, sep=" ", file=fh)
            print("filter_contaminants:", params.filter_contaminants, sep=" ", file=fh)
            print("allowable_kmer_mismatches:", params.allowable_kmer_mismatches, sep=" ", file=fh)
            print("reference_kmer_match_length:", params.reference_kmer_match_length, sep=" ", file=fh)
            print("reduced_kmer_min:", params.reduced_kmer_min, sep=" ", file=fh)
            print("minimum_passing_read_length:", params.minimum_passing_read_length, sep=" ", file=fh)
            print("minimum_base_quality:", params.minimum_base_quality, sep=" ", file=fh)
            print("minimum_merge_length:", params.minimum_merge_length, sep=" ", file=fh)
            print("fastq_allowmergestagger:", params.fastq_allowmergestagger, sep=" ", file=fh)
            print("fastq_maxdiffs:", params.fastq_maxdiffs, sep=" ", file=fh)
            print("fastq_minovlen:", params.fastq_minovlen, sep=" ", file=fh)
            print("maximum_expected_error:", params.maximum_expected_error, sep=" ", file=fh)
            print("reference_chimera_filter:", params.reference_chimera_filter, sep=" ", file=fh)
            print("minimum_sequence_abundance:", params.minimum_sequence_abundance, sep=" ", file=fh)
            print("percent_of_allowable_difference:", params.percent_of_allowable_difference, sep=" ", file=fh)
            print("reference_database:", params.reference_database, sep=" ", file=fh)
            print("aligner:", params.aligner, sep=" ", file=fh)
            print("blast_minimum_bitscore:", params.blast_minimum_bitscore, sep=" ", file=fh)
            print("blast_top_fraction:", params.blast_top_fraction, sep=" ", file=fh)
            print("read_identity_requirement:", params.read_identity_requirement, sep=" ", file=fh)


# Build the final HTML summary from all per-sample statistics, the OTU
# results, and the run parameters.
rule build_report:
    input:
        report_script = os.path.join(
            os.path.dirname(os.path.abspath(workflow.snakefile)),
            "scripts",
            "build_report.py"
        ),
        biom = "OTU.biom",
        txt = "OTU.txt",
        archive = "hundo_results.zip",
        params = "PARAMS.txt",
        env = CONDAENV,
        fastq_stats_post = expand("logs/{sample}_merged_filtered_eestats.txt",
            sample=SAMPLES.keys()
        ),
        raw_r1_quals = expand("logs/{sample}_R1_eestats.txt",
            sample=SAMPLES.keys()
        ),
        raw_r2_quals = expand("logs/{sample}_R2_eestats.txt",
            sample=SAMPLES.keys()
        ),
        raw_counts = expand("logs/raw_counts/{sample}_R1.count",
            sample=SAMPLES.keys()
        ),
        filtered_counts = expand("logs/filtered_counts/{sample}_R1.count",
            sample=SAMPLES.keys()
        ),
        merged_counts = expand("logs/merged_counts/{sample}_R1.count",
            sample=SAMPLES.keys()
        ),
        samples = "SAMPLES.txt"
    shadow:
        "shallow"
    params:
        # omitted samples are passed as a quoted "sample:count,..." string
        omitted = "'%s'" % ",".join(["{sample}:{count}".format(sample=k, count=v) for k, v in OMITTED.items()]),
        author = "'%s'" % config.get("author", "hundo"),
        version = "'%s'" % PROTOCOL_VERSION
    output:
        html = "summary.html"
    conda:
        CONDAENV
    group:
        "combined_group"
    shell:
        """
        python {input.report_script} --biom {input.biom} --txt {input.txt} \
            --zip {input.archive} --env {input.env} --params {input.params} \
            --eeplot 'logs/*_merged_filtered_eestats.txt' \
            --r1quals 'logs/*_R1_eestats.txt' \
            --raw 'logs/raw_counts/*_R1.count' --omitted {params.omitted} \
            --html {output.html} --author {params.author} \
            --version {params.version} --samples {input.samples}
        """


onsuccess:
    # remove intermediate files unless the user asked to keep them
    if not config.get("no_temp_declared"):
        shutil.rmtree("tmp", ignore_errors=True)
    print("Protocol [%s] completed successfully." % PROTOCOL_VERSION)
