#!/usr/bin/env python3

import argparse
import os
import sys
import multiprocessing
import pyfastaq
import iva


# Top-level command-line parser. The usage line is spelled out by hand so that
# the two alternative ways of supplying reads are shown together.
usage_text = '%(prog)s [options] {-f reads_fwd -r reads_rev | --fr reads} <output directory>'
parser = argparse.ArgumentParser(usage=usage_text)

# The only positional argument: where all results are written.
parser.add_argument(
    'outdir',
    metavar='Output directory',
    help='Name of output directory (must not already exist)',
)


# Options naming the input reads, optional starting sequences, and how much
# output to keep. File-valued options go through iva.common.abspathAction
# (presumably stores the path as an absolute path -- confirm in iva.common).
io_opts = parser.add_argument_group('Input and output')
io_opts.add_argument(
    '-f', '--reads_fwd',
    action=iva.common.abspathAction,
    metavar='filename[.gz]',
    help='Name of forward reads fasta/q file. Must be used in conjuction with --reads_rev',
)
io_opts.add_argument(
    '-r', '--reads_rev',
    action=iva.common.abspathAction,
    metavar='filename[.gz]',
    help='Name of reverse reads fasta/q file. Must be used in conjuction with --reads_fwd',
)
# Stored as options.reads rather than options.fr.
io_opts.add_argument(
    '--fr',
    dest='reads',
    action=iva.common.abspathAction,
    metavar='filename[.gz]',
    help='Name of interleaved fasta/q file',
)
io_opts.add_argument(
    '--keep_files',
    action='store_true',
    help='Keep intermediate files (could be many!). Default is to delete all unecessary files',
)
io_opts.add_argument(
    '--contigs',
    action=iva.common.abspathAction,
    metavar='filename[.gz]',
    help='Fasta file of contigs to be extended. Incompatible with --reference',
)
io_opts.add_argument(
    '--reference',
    action=iva.common.abspathAction,
    metavar='filename[.gz]',
    help='Fasta file of reference genome, or parts thereof. IVA will try to assemble one contig per sequence in this file. Incompatible with --contigs',
)
# Counted flag: options.verbose is the number of times -v was given.
io_opts.add_argument(
    '-v', '--verbose',
    action='count',
    default=0,
    help='Be verbose by printing messages to stdout. Use up to three times for increasing verbosity.',
)


# Parameters handed on to SMALT indexing (-k, -s) and mapping (-y).
smalt_opts = parser.add_argument_group('SMALT mapping options')
smalt_opts.add_argument(
    '-k', '--smalt_k',
    type=int,
    default=19,
    metavar='INT',
    help='kmer hash length in SMALT (the -k option in smalt index) [%(default)s]',
)
smalt_opts.add_argument(
    '-s', '--smalt_s',
    type=int,
    default=11,
    metavar='INT',
    help='kmer hash step size in SMALT (the -s option in smalt index) [%(default)s]',
)
smalt_opts.add_argument(
    '-y', '--smalt_id',
    type=float,
    default=0.5,
    metavar='FLOAT',
    help='Minimum identity threshold for mapping to be reported (the -y option in smalt map) [%(default)s]',
)


# Options controlling contig trimming and extension. Every option here has the
# same shape, so they are registered from a table, in help-output order:
# (flag, type converter, default, metavar, help text).
contig_opts = parser.add_argument_group('Contig options')
contig_option_table = (
    ('--ctg_first_trim', int, 25, 'INT',
     'Number of bases to trim off the end of every contig before extending for the first time [%(default)s]'),
    ('--ctg_iter_trim', int, 10, 'INT',
     'During iterative extension, number of bases to trim off the end of a contig when extension fails (then try extending again) [%(default)s]'),
    ('--ext_min_cov', int, 10, 'INT',
     'Minimum kmer depth needed to use that kmer to extend a contig [%(default)s]'),
    ('--ext_min_ratio', float, 4, 'FLOAT',
     'Sets N, where kmer for extension must be at least N times more abundant than next most common kmer [%(default)s]'),
    ('--ext_max_bases', int, 100, 'INT',
     'Maximum number of bases to try to extend on each iteration [%(default)s]'),
    ('--ext_min_clip', int, 3, 'INT',
     'Set minimum number of bases soft clipped off a read for those bases to be used for extension [%(default)s]'),
    ('--max_contigs', int, 50, 'INT',
     'Maximum number of contigs allowed in the assembly. No more seeds generated if the cutoff is reached [%(default)s]'),
)
for flag, converter, default_value, placeholder, description in contig_option_table:
    contig_opts.add_argument(
        flag,
        type=converter,
        default=default_value,
        metavar=placeholder,
        help=description,
    )


# Options controlling how seed sequences are chosen and grown.
seed_opts = parser.add_argument_group('Seed generation options')
seed_opts.add_argument(
    '--make_new_seeds',
    action='store_true',
    help='When no more contigs can be extended, generate a new seed. This is forced to be true when --contigs is not used',
)
# No numeric default: None is passed on and the length is chosen downstream.
seed_opts.add_argument(
    '--seed_start_length',
    type=int,
    default=None,
    metavar='INT',
    help='When making a seed sequence, use the most common kmer of this length. Default is to use the minimum of (median read length, 95). Warning: it is not recommended to set this higher than 95',
)
# 0 is a sentinel meaning "derive from --max_insert" (resolved after parsing).
seed_opts.add_argument(
    '--seed_stop_length',
    type=int,
    default=0,
    metavar='INT',
    help='Stop extending seed using perfect matches from reads when this length is reached. Future extensions are then made by treating the seed as a contig [0.9*max_insert]',
)
seed_opts.add_argument(
    '--seed_min_kmer_cov',
    type=int,
    default=25,
    metavar='INT',
    help='Minimum kmer coverage of initial seed [%(default)s]',
)
seed_opts.add_argument(
    '--seed_max_kmer_cov',
    type=int,
    default=1000000,
    metavar='INT',
    help='Maximum kmer coverage of initial seed [%(default)s]',
)
seed_opts.add_argument(
    '--seed_ext_max_bases',
    type=int,
    default=50,
    metavar='INT',
    help='Maximum number of bases to try to extend on each iteration [%(default)s]',
)
# No default given, so this is None unless the user sets it.
seed_opts.add_argument(
    '--seed_overlap_length',
    type=int,
    metavar='INT',
    help='Number of overlapping bases needed between read and seed to use that read to extend [seed_start_length]',
)
seed_opts.add_argument(
    '--seed_ext_min_cov',
    type=int,
    default=10,
    metavar='INT',
    help='Minimum kmer depth needed to use that kmer to extend a contig [%(default)s]',
)
seed_opts.add_argument(
    '--seed_ext_min_ratio',
    type=float,
    default=4,
    metavar='FLOAT',
    help='Sets N, where kmer for extension must be at least N times more abundant than next most common kmer [%(default)s]',
)


# Options for trimming reads before assembly (Trimmomatic and/or PCR primers).
# File-valued options use the custom iva.common.abspathAction argparse action.
path_action = iva.common.abspathAction
trim_opts = parser.add_argument_group('Read trimming options')
trim_opts.add_argument(
    '--trimmomatic',
    action=path_action,
    metavar='FILENAME',
    help='Provide location of trimmomatic.jar file to enable read trimming. Required if --adapters used',
)
trim_opts.add_argument(
    '--trimmo_qual',
    default='LEADING:10 TRAILING:10 SLIDINGWINDOW:4:20',
    metavar='STRING',
    help='Trimmomatic options used to quality trim reads [%(default)s]',
)
trim_opts.add_argument(
    '--adapters',
    action=path_action,
    metavar='FILENAME',
    help='Fasta file of adapter sequences to be trimmed off reads. If used, must also use --trimmomatic. Default is file of adapters supplied with IVA',
)
trim_opts.add_argument(
    '--min_trimmed_length',
    type=int,
    default=50,
    metavar='INT',
    help='Minimum length of read after trimming [%(default)s]',
)
trim_opts.add_argument(
    '--pcr_primers',
    action=path_action,
    metavar='FILENAME',
    help='FASTA file of primers. The first perfect match found to a sequence in the primers file will be trimmed off the start of each read. This is run after trimmomatic (if --trimmomatic used)',
)


# Miscellaneous options that do not belong to any of the groups above.
misc_opts = parser.add_argument_group('Other options')
misc_opts.add_argument(
    '-i', '--max_insert',
    type=int,
    default=500,
    metavar='INT',
    help='Maximum insert size (includes read length). Reads with inferred insert size more than the maximum will not be used to extend contigs [%(default)s]',
)
misc_opts.add_argument(
    '-t', '--threads',
    type=int,
    default=1,
    metavar='INT',
    help='Number of threads to use [%(default)s]',
)
# The permitted range is enforced after parsing, not by argparse.
misc_opts.add_argument(
    '--strand_bias',
    type=float,
    default=0.1,
    metavar='FLOAT in [0,0.5]',
    help='Set strand bias cutoff of mapped reads when trimming contig ends, in the interval [0,0.5]. A value of x means that a base needs min(fwd_depth, rev_depth) / total_depth <= x. [%(default)s]',
)
# Prints iva.common.version and exits.
misc_opts.add_argument(
    '--version',
    action='version',
    version=iva.common.version,
)

# Parse the command line, fill in derived defaults, then reject inconsistent
# option combinations. Each failed check prints a message to stderr and exits
# with status 1 before anything is written to disk.
options = parser.parse_args()

# 0 is the "not set" sentinel for --seed_stop_length: derive it from the
# maximum insert size, as advertised in the option's help text.
if options.seed_stop_length == 0:
    options.seed_stop_length = int(0.9 * options.max_insert)

if not (0 <= options.strand_bias <= 0.5):
    print('Error! strand bias must be in the interval [0, 0.5]. Cannot continue because it\'s', options.strand_bias, file=sys.stderr)
    sys.exit(1)


if options.adapters and not options.trimmomatic:
    print('Error! --adapters used, but not --trimmomatic. I need the trimmomatic jar file. Cannot continue', file=sys.stderr)
    sys.exit(1)

# Reads must come from exactly one source: either the interleaved file (--fr)
# or the forward/reverse pair. The first clause rejects "neither" and "both";
# the second rejects giving only one of -f / -r (even alongside --fr).
if not (bool(options.reads) ^ bool(options.reads_fwd and options.reads_rev)) or (bool(options.reads_fwd) != bool(options.reads_rev)):
    print('Error! Must use options: -f/--reads_fwd and -r/--reads_rev together, or just use --fr on its own. Cannot continue', file=sys.stderr)
    sys.exit(1)

if options.contigs and options.reference:
    print('Error! Cannot use both of --contigs and --reference. Cannot continue', file=sys.stderr)
    sys.exit(1)

# Refuse to reuse an existing path so earlier results are never overwritten.
if os.path.exists(options.outdir):
    print('Error! Output directory', options.outdir, 'already exists. Cannot continue', file=sys.stderr)
    sys.exit(1)


# Look up the external assembly programs before touching the filesystem.
# NOTE(review): presumably this aborts if a required program is missing --
# confirm in iva.external_progs.
iva.external_progs.get_all_versions(iva.external_progs.assembly_progs)

# Create the output directory. Only OS-level failures (permissions, missing
# parent, directory created by someone else since the existence check) are
# treated as a clean error exit; anything else, e.g. Ctrl-C, propagates.
try:
    os.mkdir(options.outdir)
except OSError as mkdir_error:
    print('Error making output directory', options.outdir, file=sys.stderr)
    print(mkdir_error, file=sys.stderr)
    sys.exit(1)

# Everything from here on uses file names relative to the output directory.
os.chdir(options.outdir)

# Record program information for this run in the log file.
log_file = 'info.txt'
iva.external_progs.write_prog_info('iva', log_file)

# Prepare the reads: whatever the input format, end up with two FASTA files,
# reads_1.fa and reads_2.fa, in the output directory.
reads_prefix = 'reads'
reads_1 = reads_prefix + '_1.fa'
reads_2 = reads_prefix + '_2.fa'

# Set pyfastaq's class-level FASTA line length to 0 while the reads files are
# written (presumably this means unwrapped sequence lines -- confirm in
# pyfastaq). The saved value is restored once read preparation is finished.
original_line_length = pyfastaq.sequences.Fasta.line_length
pyfastaq.sequences.Fasta.line_length = 0

if options.reads and not options.trimmomatic:
    # Simplest case: interleaved input and no trimming, so split straight to FASTA.
    pyfastaq.tasks.deinterleave(options.reads, reads_1, reads_2, fasta_out=True)
else:
    # Intermediate files created here. User-supplied input files are never added.
    to_delete = []

    if options.reads:
        # Interleaved input that needs trimming: split it into two files first.
        reads_for_trimming_1 = 'reads.untrimmed_1.fq'
        reads_for_trimming_2 = 'reads.untrimmed_2.fq'
        pyfastaq.tasks.deinterleave(options.reads, reads_for_trimming_1, reads_for_trimming_2)
        to_delete.append(reads_for_trimming_1)
        to_delete.append(reads_for_trimming_2)
    else:
        reads_for_trimming_1 = options.reads_fwd
        reads_for_trimming_2 = options.reads_rev

    if options.trimmomatic:
        trimmed_reads_prefix = 'reads.trimmed'
        if options.adapters is None:
            # No adapters file given: copy the one shipped with the iva package
            # into the output directory and use that.
            extractor = iva.egg_extract.Extractor(os.path.abspath(os.path.join(os.path.dirname(iva.__file__), os.pardir)))
            egg_adapters = os.path.join('iva', 'read_trim', 'adapters.fasta')
            options.adapters = 'adapters.fasta'
            extractor.copy_file(egg_adapters, options.adapters)

        # Explicit check rather than assert, so it still runs under "python -O".
        if not os.path.exists(options.adapters):
            print('Error! Adapters file', options.adapters, 'not found. Cannot continue', file=sys.stderr)
            sys.exit(1)

        iva.read_trim.run_trimmomatic(reads_for_trimming_1, reads_for_trimming_2, trimmed_reads_prefix, options.trimmomatic, options.adapters, minlen=options.min_trimmed_length, verbose=options.verbose, threads=options.threads, qual_trim=options.trimmo_qual)
        fq_to_convert_to_fa_1 = trimmed_reads_prefix + '_1.fq'
        fq_to_convert_to_fa_2 = trimmed_reads_prefix + '_2.fq'
        to_delete.append(fq_to_convert_to_fa_1)
        to_delete.append(fq_to_convert_to_fa_2)
    else:
        fq_to_convert_to_fa_1 = reads_for_trimming_1
        fq_to_convert_to_fa_2 = reads_for_trimming_2

    # Convert both mates to FASTA in child processes.
    p1 = multiprocessing.Process(target=pyfastaq.tasks.to_fasta, args=(fq_to_convert_to_fa_1, reads_1), kwargs={'line_length':0})
    p2 = multiprocessing.Process(target=pyfastaq.tasks.to_fasta, args=(fq_to_convert_to_fa_2, reads_2), kwargs={'line_length':0})
    p1.start()

    # With a single thread (or a non-positive --threads value) run the two
    # conversions one after the other; otherwise let them overlap.
    if options.threads <= 1:
        p1.join()

    p2.start()
    p2.join()
    # Always wait for the first conversion before its inputs are deleted and
    # its output is used. Joining an already-finished process is a no-op.
    p1.join()

    # A non-zero exit code means the child raised or was killed, so the FASTA
    # output cannot be trusted.
    if p1.exitcode != 0 or p2.exitcode != 0:
        print('Error converting reads to FASTA format. Cannot continue', file=sys.stderr)
        sys.exit(1)

    for fname in to_delete:
        os.unlink(fname)

# Optional PCR primer removal: trim into temporary files, then move those over
# the original reads files so later steps see only the trimmed reads.
if options.pcr_primers:
    primer_trimmed_1, primer_trimmed_2 = 'reads_1.pcr_trim.fa', 'reads_2.pcr_trim.fa'
    pyfastaq.tasks.sequence_trim(
        reads_1,
        reads_2,
        primer_trimmed_1,
        primer_trimmed_2,
        options.pcr_primers,
        min_length=options.min_trimmed_length,
        check_revcomp=True,
    )
    for trimmed_name, final_name in ((primer_trimmed_1, reads_1), (primer_trimmed_2, reads_2)):
        os.rename(trimmed_name, final_name)

# Read preparation is finished: put back the FASTA line length that was saved
# before the reads files were written.
pyfastaq.sequences.Fasta.line_length = original_line_length

# Decide what the assembly starts from. Afterwards `contigs` is the name of a
# FASTA file in the output directory, or None when starting from scratch.
if options.contigs:
    # User-supplied contigs: copy them in with 60-column lines and names cut
    # at the first whitespace.
    contigs = 'contigs_to_extend.fasta'
    pyfastaq.tasks.to_fasta(options.contigs, contigs, line_length=60, strip_after_first_whitespace=True)
elif options.reference:
    # Reference supplied: copy it in the same way, then let the seed processor
    # write the starting seeds to seeds.fasta.
    reference = 'reference_in.fasta'
    pyfastaq.tasks.to_fasta(options.reference, reference, line_length=60, strip_after_first_whitespace=True)
    seed_processor = iva.seed_processor.SeedProcessor(
        reference,
        reads_1,
        reads_2,
        'seeds.fasta',
        index_k = options.smalt_k,
        index_s = options.smalt_s,
        threads = options.threads,
        max_insert = options.max_insert,
        # Fixed identity threshold here; --smalt_id is not used for this step.
        minid = 0.9,
        seed_stop_length = options.seed_stop_length,
        extend_length = options.seed_ext_max_bases,
        overlap_length = options.seed_overlap_length,
        ext_min_cov = options.seed_ext_min_cov,
        ext_min_ratio = options.seed_ext_min_ratio,
        verbose = options.verbose,
        seed_length = options.seed_start_length,
        seed_min_count = options.seed_min_kmer_cov,
        seed_max_count = options.seed_max_kmer_cov
    )
    seed_processor.process()
    contigs = 'seeds.fasta'
else:
    # Nothing to start from, so new seeds must be generated during assembly.
    # The attribute name has to match the --make_new_seeds option, because
    # that is the attribute read when the assembly is constructed.
    contigs = None
    options.make_new_seeds = True

# Collect the command-line settings for the assembly, grouped by what they
# control, then build the Assembly from the starting contigs (may be None).
assembly_settings = dict(
    # General behaviour
    verbose=options.verbose,
    clean=not options.keep_files,
    threads=options.threads,
    max_insert=options.max_insert,
    strand_bias=options.strand_bias,
    # Read mapping (SMALT)
    map_index_k=options.smalt_k,
    map_index_s=options.smalt_s,
    map_minid=options.smalt_id,
    # Contig extension
    contig_iter_trim=options.ctg_iter_trim,
    ext_min_cov=options.ext_min_cov,
    ext_min_ratio=options.ext_min_ratio,
    ext_bases=options.ext_max_bases,
    min_clip=options.ext_min_clip,
    max_contigs=options.max_contigs,
    # Seed generation
    make_new_seeds=options.make_new_seeds,
    seed_start_length=options.seed_start_length,
    seed_stop_length=options.seed_stop_length,
    seed_min_kmer_count=options.seed_min_kmer_cov,
    seed_max_kmer_count=options.seed_max_kmer_cov,
    seed_ext_max_bases=options.seed_ext_max_bases,
    seed_overlap_length=options.seed_overlap_length,
    seed_min_cov=options.seed_ext_min_cov,
    seed_min_ratio=options.seed_ext_min_ratio,
)
assembly = iva.assembly.Assembly(contigs, **assembly_settings)

# Prepare the starting point, then run the iterative extension.
if options.contigs:
    # User-supplied contigs get their ends trimmed before the first extension.
    assembly.trim_contigs(options.ctg_first_trim)
elif not options.reference:
    # No contigs and no reference: the first seed must come from the reads.
    seed_name = assembly.add_new_seed_contig(reads_1, reads_2)
    if seed_name is None:
        error_message = 'Failed to make first seed. Cannot continue'
        print(error_message, file=sys.stderr)
        # Also record the failure in the run log; the with-block closes the file.
        with open(log_file, 'a') as f:
            print(error_message, file=f)
        sys.exit(1)

assembly.read_pair_extend(reads_prefix, 'iteration')

# Write the final contigs, then remove intermediate files unless the user
# asked to keep them.
final_contigs = 'contigs.fasta'
contig_write_settings = {'min_length': 100, 'order_by_orfs': True, 'prefix': 'contig'}
stale_files = []

if not (options.trimmomatic or options.pcr_primers):
    # Reads were not trimmed, so the contigs can be written out directly.
    assembly.write_contigs_to_file(final_contigs, **contig_write_settings)
else:
    # Reads were trimmed: write the contigs to a temporary file first, then
    # trim primers/adapters from them to produce the final file.
    untrimmed_contigs = 'contigs.pre_trim.fasta'
    assembly.write_contigs_to_file(untrimmed_contigs, **contig_write_settings)
    iva.contig_trim.trim_primers_and_adapters(untrimmed_contigs, final_contigs, options.adapters, options.pcr_primers, min_length=100)
    # The '.fai' file is presumably an index created alongside the pre-trim
    # contigs during trimming -- confirm in iva.contig_trim.
    stale_files.extend([untrimmed_contigs, untrimmed_contigs + '.fai'])

# The prepared reads and any starting contigs/seeds file are intermediates too.
stale_files.extend([reads_1, reads_2])
if contigs is not None:
    stale_files.append(contigs)

if not options.keep_files:
    for stale_name in stale_files:
        os.unlink(stale_name)

