from glob import glob
import collections
import os
from singlecellmultiomics.utils import get_contig_list_from_fasta
"""
This workflow:
    - Runs bcftools on the supplied bam file
    - Runs variant extraction and phasing
    - Merges results
"""
################## configuration ##################
configfile: "config.json"
# Config keys read in this file: 'reference_file', 'bam',
# 'germline_variants_vcf' and 'germline_variants_vcf_sample'.

# Obtain contigs:
# Width of one genomic bin; get_target_vcf_list / get_target_pickle_list emit
# one target file per bin of this many bases.
job_size = 500_000 # Bases per job
# Optional cap on the number of targets returned by get_target_vcf_list and
# get_target_pickle_list; None disables the cap.
max_N = None

# Contig names and their lengths, taken from the reference FASTA.
# NOTE(review): the target-list functions skip contigs whose size is None, so
# presumably a size can be None here -- confirm against
# get_contig_list_from_fasta.
contigs, sizes = get_contig_list_from_fasta(config['reference_file'], with_length=True)
# Mapping of contig name -> contig size.
contig_sizes = dict(zip(contigs,sizes))


def get_target_vcf_list(wildcards):
    """
    Build the list of per-bin VCF targets covering every contig.

    Every contig in the module-level ``contig_sizes`` mapping is tiled with
    consecutive bins of ``job_size`` bases. Bin coordinates are 1-based and
    inclusive (the convention of the bcftools -r argument), and the last bin
    of a contig is clipped to the contig size. Contigs whose size is None are
    skipped.

    Args:
        wildcards: Passed in by Snakemake when this is used as an input
            function; it is not used here.

    Returns:
        list of str: Paths of the form
        'variant_calls/TEMP/<contig>_<start>_<end>.vcf.gz', ordered by contig
        (in ``contig_sizes`` order) and then by ascending bin start. When the
        module-level ``max_N`` is not None only the first ``max_N`` paths are
        returned.
    """
    path_template = 'variant_calls/TEMP/%s_%s_%s.vcf.gz'

    paths = []
    for contig, length in contig_sizes.items():
        if length is None:
            continue

        for first_base in range(1, length + 1, job_size):
            # Clip the final bin so it never runs past the end of the contig.
            last_base = min(first_base + job_size - 1, length)
            paths.append(path_template % (contig, first_base, last_base))

    return paths if max_N is None else paths[:max_N]


def get_target_pickle_list(wildcards):
    """
    Build the list of per-bin single-cell pickle targets covering every contig.

    Every contig in the module-level ``contig_sizes`` mapping is tiled with
    consecutive bins of ``job_size`` bases. Bin coordinates are 1-based and
    inclusive (the convention of the bcftools -r argument), and the last bin
    of a contig is clipped to the contig size. Contigs whose size is None are
    skipped.

    Args:
        wildcards: Passed in by Snakemake when this is used as an input
            function; it is not used here.

    Returns:
        list of str: Paths of the form
        'variant_calls/TEMP/<contig>_<start>_<end>.sc.pickle.gz', ordered by
        contig (in ``contig_sizes`` order) and then by ascending bin start.
        When the module-level ``max_N`` is not None only the first ``max_N``
        paths are returned.
    """
    path_template = 'variant_calls/TEMP/%s_%s_%s.sc.pickle.gz'

    paths = []
    for contig, length in contig_sizes.items():
        if length is None:
            continue

        for first_base in range(1, length + 1, job_size):
            # Clip the final bin so it never runs past the end of the contig.
            last_base = min(first_base + job_size - 1, length)
            paths.append(path_template % (contig, first_base, last_base))

    return paths if max_N is None else paths[:max_N]


#[contig for contig in get_contig_list_from_fasta(config['reference_file']) if contig!='MISC_ALT_CONTIGS_SCMO']

# First rule in the file, so Snakemake treats it as the default target:
# request the merged VCF plus one single-cell pickle per genomic bin.
rule all:
    input:
        'variants/unfiltered_variants.vcf.gz',
        get_target_pickle_list


# Call variants for one genomic bin with bcftools and index the resulting VCF.
#
# Wildcards contig, start and end are combined into the bcftools region
# "contig:start-end" (1-based, inclusive; see get_target_vcf_list).
rule bcftools_call:
    input:
        scbam=config['bam'],

    output:
        vcf = "variant_calls/TEMP/{contig}_{start}_{end}.vcf.gz",
        vcfindex = "variant_calls/TEMP/{contig}_{start}_{end}.vcf.gz.csi"

    log:
        stdout="log/variant_calls/{contig}_{start}_{end}.stdout",
        stderr="log/variant_calls/{contig}_{start}_{end}.stderr"

    threads: 1
    params:
        # NOTE(review): runtime is not referenced by the shell command below;
        # presumably it is read by the cluster submission profile -- confirm.
        runtime="60h",
        reference=config['reference_file'],

    resources:
        # Both requests grow linearly with the retry attempt number.
        # NOTE(review): the unit of runtime (24 per attempt) is not evident
        # here and differs from the "60h" param above -- confirm.
        mem_mb = lambda wildcards, attempt, input: attempt * 10000,
        runtime = lambda wildcards, attempt, input: attempt * 24

    # The pipeline and the indexing step are grouped in one subshell so that
    # the redirection captures stdout/stderr of every bcftools process, not
    # only of the last pipeline stage. Indexing only runs when the pipeline
    # succeeded.
    shell:
        """
        (
            bcftools mpileup -r {wildcards.contig}:{wildcards.start}-{wildcards.end} -Ou {input.scbam} -f {params.reference} -d 1000000 |
            bcftools call -mv - |
            bcftools view -i '%QUAL>500' -Oz -o {output.vcf} &&
            bcftools index {output.vcf}
        ) > {log.stdout} 2> {log.stderr}
        """

# Run bamExtractVariants.py for one genomic bin: it is given the bin's VCF
# (the output of bcftools_call), the BAM file and the germline VCF / sample
# name from the config, and writes one gzipped pickle per bin.
# NOTE(review): the contig/start/end wildcards are not passed to the command,
# so the restriction to the bin presumably comes from the variants in the
# per-bin VCF -- confirm against bamExtractVariants.py.
rule extract_sc:
    input:
        scbam=config['bam'],
        germline_variants_vcf=config['germline_variants_vcf'],
        vcf="variant_calls/TEMP/{contig}_{start}_{end}.vcf.gz"
    output:
        pickle='variant_calls/TEMP/{contig}_{start}_{end}.sc.pickle.gz'

    threads: 1
    params:
        # NOTE(review): runtime is not referenced by the shell command below;
        # presumably it is read by the cluster submission profile -- confirm.
        runtime="60h",
        germline_variants_sample=config['germline_variants_vcf_sample'],

    resources:
        # Both requests grow linearly with the retry attempt number.
        mem_mb = lambda wildcards, attempt, input: attempt * 10000,
        runtime = lambda wildcards, attempt, input: attempt * 24

    shell:
        "bamExtractVariants.py -extract {input.vcf} -germline_sample {params.germline_variants_sample} -germline {input.germline_variants_vcf} -o {output.pickle} {input.scbam} -t {threads}"


# Concatenate all per-bin VCF files (in the order produced by
# get_target_vcf_list) into a single compressed VCF and index it.
rule bcftools_gather:
    input:
        chr_vcfs = get_target_vcf_list

    output:
        vcf = 'variants/unfiltered_variants.vcf.gz',
        # The index written by "bcftools index" is declared as well (as in
        # bcftools_call) so Snakemake checks that it was produced and removes
        # a stale or partial index when the job fails or is rerun.
        vcfindex = 'variants/unfiltered_variants.vcf.gz.csi'

    threads: 1
    resources:
        # Runtime request grows linearly with the retry attempt number.
        runtime=lambda wildcards, attempt, input: ( attempt * 24),
        mem_mb = 2000,
    message:
        'Merging contig VCF files'

    # Index only when the concatenation succeeded.
    shell:
        "bcftools concat -Oz -o {output.vcf} {input.chr_vcfs} && bcftools index {output.vcf}"
