#!python

"""Aligns and counts mutations in barcoded subamplicons.

Written by Jesse Bloom."""


import sys
import os
import time
import logging
import gzip
import random
import Bio.SeqIO
import dms_tools
import dms_tools.parsearguments
import dms_tools.file_io
import dms_tools.utils



def _formatreads(readpairs):
    """Format ``(r1, r2)`` read pairs as lines for the barcode info file.

    Returns a single string with one ``R1 = ...; R2 = ...`` entry per pair,
    joined by a newline followed by a tab.
    """
    return '\n\t'.join(['R1 = %s; R2 = %s' % tup for tup in readpairs])


def main():
    """Main body of script.

    Steps, in order:

    1. Parse the command line, seed ``random`` and set up logging to the
       console and to ``<outprefix>.log``.
    2. Delete any pre-existing output files and open new ones (the gzipped
       barcode info file only when ``barcodeinfo`` is set).
    3. Read the single reference sequence and validate the alignspecs and
       the read file lists.
    4. Iterate over the paired FASTQ reads, quality-check them and group the
       read pairs by barcode (the first ``barcodelength`` nucleotides of R1
       followed by those of R2).
    5. Optionally purge barcodes at random with probability ``purgefrac``.
    6. For each barcode with enough reads, build a consensus and try each
       alignspec in turn; ``AlignSubamplicon`` is handed *counts* to update.
    7. Write the counts and the summary statistics.

    Any exception raised in steps 3-7 is logged, the output files are
    closed and deleted, and the function returns normally.
    """

    # Parse command line arguments
    parser = dms_tools.parsearguments.BarcodedSubampliconsParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    random.seed(args['seed'])

    # set up logging
    logging.shutdown()
    logfile = "%s.log" % args['outprefix']
    if os.path.isfile(logfile):
        os.remove(logfile)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
    logger = logging.getLogger(prog)
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)
    logger.info("Beginning execution of %s in directory %s\n" % (prog, os.getcwd()))
    logger.info("Progress will be logged to %s" % logfile)

    # define file names and delete existing files
    countsfilename = "%scounts.txt" % args['outprefix']
    summarystatsfilename = '%ssummarystats.txt' % args['outprefix']
    barcodeinfofilename = '%sbarcodeinfo.txt.gz' % args['outprefix']
    outfilenames = [countsfilename, summarystatsfilename, barcodeinfofilename]
    for f in outfilenames:
        if os.path.isfile(f):
            os.remove(f)
            logger.info("Removing existing file %s" % f)
    countsfile = open(countsfilename, 'w')
    summarystatsfile = open(summarystatsfilename, 'w')
    outfiles = [countsfile, summarystatsfile]
    if args['barcodeinfo']:
        barcodeinfofile = gzip.open(barcodeinfofilename, 'w')
        outfiles.append(barcodeinfofile)

    # Run the analysis inside try / except so that failures are logged and
    # partially written output files are cleaned up.
    try:
        versionstring = dms_tools.file_io.Versions()
        logger.info("%s\n" % versionstring)
        logger.info('Parsed the following arguments:\n%s\n' % '\n'.join(['\t%s = %s' % tup for tup in args.items()]))

        # read refseq, closing the file as soon as it has been parsed
        with open(args['refseq']) as refseqhandle:
            refseq = list(Bio.SeqIO.parse(refseqhandle, 'fasta'))
        if len(refseq) != 1:
            raise ValueError("refseq of %s does not specify exactly one sequence" % args['refseq'])
        refseq = str(refseq[0].seq).upper()
        if args['chartype'] == 'codon':
            if len(refseq) % 3 != 0:
                raise ValueError("refseq does not specify a sequence of codons since its length of %d nucleotides is not a multiple of 3" % (len(refseq)))
            logger.info('Read a reference sequence of %d codons from %s\n' % (len(refseq) // 3, args['refseq']))
        else:
            raise ValueError("Invalid chartype")

        # do some checking on validity of alignspecs
        for (refseqstart, refseqend, r1start, r2start) in args['alignspecs']:
            if refseqend <= refseqstart:
                raise ValueError("REFSEQEND <= REFSEQSTART")
            if args['barcodelength'] > r1start:
                raise ValueError("One of the alignspecs specifies R1START of %d, which doesn't fully trim a barcode of barcodelength %d" % (r1start, args['barcodelength']))
            if args['barcodelength'] > r2start:
                raise ValueError("One of the alignspecs specifies R2START of %d, which doesn't fully trim a barcode of barcodelength %d" % (r2start, args['barcodelength']))
            if r1start < 1:
                raise ValueError("R1START must be >= 1; you specified %d" % r1start)
            if r2start < 1:
                raise ValueError("R2START must be >= 1; you specified %d" % r2start)
            # NOTE(review): the slicing below treats R1START / R2START as
            # 1-based; if REFSEQSTART is 1-based too, this lower bound should
            # probably be "< 1" -- confirm against AlignSubamplicon.
            if refseqstart < 0 or refseqend > len(refseq):
                raise ValueError("One of the alignspecs specifies REFSEQSTART of %d and REFSEQEND of %d, which are not valid values for a refseq of length %d nucleotides." % (refseqstart, refseqend, len(refseq)))

        # check on read files
        if len(args['r1files']) != len(args['r2files']):
            raise ValueError("r1files and r2files do not specify the same number of files.")
        isgz = [os.path.splitext(f)[1].lower() == '.gz' for f in args['r1files'] + args['r2files']]
        if all(isgz):
            gzipped = True
        elif any(isgz):
            raise ValueError("Some but not all of the r1files and r2files are gzipped. Either all or no files can be gzipped.")
        else:
            gzipped = False

        # collect reads by barcode while iterating over reads
        readcategories = ['total read pairs', 'read pairs that fail Illumina filter', 'low quality read pairs']
        n = dict([(category, 0) for category in readcategories])
        barcodes = {}
        logger.info('Now parsing reads from the following:\n\tr1files: %s\n\tr2files: %s' % (' '.join(args['r1files']), ' '.join(args['r2files'])))
        for read_tup in dms_tools.file_io.IteratePairedFASTQ(args['r1files'], args['r2files'], gzipped, applyfilter=True):
            n['total read pairs'] += 1
            if read_tup:
                (name, r1, r2, q1, q2) = read_tup
                qcheckedreads = dms_tools.utils.CheckReadQuality(r1, r2, q1, q2, args['minq'], args['maxlowqfrac'], args['barcodelength'])
                if qcheckedreads:
                    (r1, r2) = qcheckedreads
                    barcode = r1[ : args['barcodelength']] + r2[ : args['barcodelength']]
                    if barcode in barcodes:
                        barcodes[barcode].append((r1, r2))
                    else:
                        barcodes[barcode] = [(r1, r2)]
                else:
                    n['low quality read pairs'] += 1
            else:
                # a false value from the iterator is counted as a filter failure
                n['read pairs that fail Illumina filter'] += 1
            if n['total read pairs'] % 100000 == 0:
                logger.info('Reads parsed so far: %d' % n['total read pairs'])

        logger.info('Finished parsing all %d reads; ended up with %d unique barcodes.\n' % (n['total read pairs'], len(barcodes)))

        readcategories.append('barcodes randomly purged')
        if args['purgefrac']:
            logger.info('Randomly purging barcodes with probability %.3f to subsample the data.' % args['purgefrac'])
            npurged = 0
            # iterate over a copy of the keys since entries are deleted in the loop
            for barcode in list(barcodes):
                if random.random() < args['purgefrac']:
                    npurged += 1
                    del barcodes[barcode]
            n['barcodes randomly purged'] = npurged
            ntotal = npurged + len(barcodes)
            # guard against division by zero when no barcodes were collected
            if ntotal:
                percentpurged = 100.0 * npurged / ntotal
            else:
                percentpurged = 0.0
            logger.info('Randomly purged %d of the %d barcodes (%.1f%%), leaving %d barcodes.\n' % (npurged, ntotal, percentpurged, len(barcodes)))
        else:
            n['barcodes randomly purged'] = 0

        # Examine barcodes to see if they pass criteria, if so align them and record mutation counts
        # Set up *counts* as dict that will hold mutation counts in format for dms_tools.file_io.WriteDMSCounts
        counts = {}
        if args['chartype'] == 'codon':
            for r in range(1, len(refseq) // 3 + 1):
                counts[str(r)] = dict([('WT', refseq[3 * r - 3 : 3 * r])] + [(codon, 0) for codon in dms_tools.codons])
        else:
            raise ValueError("Invalid chartype %s" % args['chartype'])
        # Now start looping over barcodes
        logger.info('Now examining the %d barcodes to see if reads meet criteria for retention' % len(barcodes))
        newreadcategories = []
        unalignedkeystring = 'un-alignable barcodes with %d reads'
        alignedkeystring = 'aligned barcodes with %d reads'
        discardedkeystring = 'discarded barcodes with %d reads'
        ibarcode = 0
        # iterate over a copy of the keys since each barcode is deleted once processed
        for barcode in list(barcodes):
            ibarcode += 1
            if ibarcode % 100000 == 0:
                logger.info('Barcodes examined so far: %s' % ibarcode)
            reads = barcodes[barcode]
            nreads = len(reads)
            if nreads < args['minreadsperbarcode']:
                if args['barcodeinfo']:
                    barcodeinfofile.write('BARCODE: %s\nRETAINED: no\nDESCRIPTION: too few reads\nCONSENSUS: none\nREADS:\n\t%s\n\n' % (barcode, _formatreads(reads)))
                keystring = discardedkeystring % nreads
            else:
                consensus = dms_tools.utils.BuildReadConsensus(reads, args['minreadidentity'], args['minreadconcurrence'], args['maxreadtrim'], use_cutils=True)
                if consensus:
                    # try to align subamplicon using each alignspec in turn
                    (r1, r2) = consensus
                    for (refseqstart, refseqend, r1start, r2start) in args['alignspecs']:
                        maxN = int(args['maxlowqfrac'] * min(refseqend - refseqstart + 1, len(r1) + len(r2) - r1start - r2start + 2))
                        if r1start >= len(r1) or r2start >= len(r2):
                            aligned = False # reads too short to align
                        elif not args['R1_is_antisense']:
                            aligned = dms_tools.utils.AlignSubamplicon(refseq, r1[r1start - 1 : ], r2[r2start - 1 : ], refseqstart, refseqend, args['maxmuts'], maxN, args['chartype'], counts, use_cutils=True)
                        else:
                            # R1 is antisense, so the two trimmed reads are passed in swapped order
                            aligned = dms_tools.utils.AlignSubamplicon(refseq, r2[r2start - 1 : ], r1[r1start - 1 : ], refseqstart, refseqend, args['maxmuts'], maxN, args['chartype'], counts, use_cutils=True)
                        if aligned:
                            keystring = alignedkeystring % nreads
                            if args['barcodeinfo']:
                                barcodeinfofile.write('BARCODE: %s\nRETAINED: yes\nDESCRIPTION: aligned from %d to %d with %d %s mutations\nCONSENSUS: R1 = %s; R2 = %s\nREADS:\n\t%s\n\n' % (barcode, refseqstart, refseqend, aligned[1], args['chartype'], r1, r2, _formatreads(reads)))
                            break
                    else:
                        # loop finished without a break: no alignspec gave an alignment
                        keystring = unalignedkeystring % nreads
                        if args['barcodeinfo']:
                            barcodeinfofile.write('BARCODE: %s\nRETAINED: no\nDESCRIPTION: subamplicon could not be aligned to refseq\nCONSENSUS: R1 = %s; R2 = %s\nREADS:\n\t%s\n\n' % (barcode, r1, r2, _formatreads(reads)))
                else:
                    if args['barcodeinfo']:
                        barcodeinfofile.write('BARCODE: %s\nRETAINED: no\nDESCRIPTION: reads not sufficiently identical\nCONSENSUS: none\nREADS:\n\t%s\n\n' % (barcode, _formatreads(reads)))
                    keystring = discardedkeystring % nreads
            if keystring in n:
                n[keystring] += 1
            else:
                n[keystring] = 1
                # remember nreads so the new categories can be sorted numerically
                newreadcategories.append((nreads, keystring))
            del reads
            del barcodes[barcode] # free up memory as we finish with barcodes
        newreadcategories.sort()
        readcategories = readcategories + [tup[1] for tup in newreadcategories]
        logger.info('Finished examining all %d barcodes\n' % ibarcode)

        # Write the counts
        logger.info('Writing the %s counts to %s' % (args['chartype'], countsfilename))
        dms_tools.file_io.WriteDMSCounts(countsfile, counts)

        # Summarize statistics
        logger.info('Here are the summary statistics (these are also being written to %s):\n%s' % (summarystatsfilename, '\n'.join(['\t%s = %s' % (category, n[category]) for category in readcategories])))
        summarystatsfile.write('\n'.join(['%s = %s' % (category, n[category]) for category in readcategories]))

    except:
        # Deliberately a bare except: cleanup should also happen on
        # KeyboardInterrupt / SystemExit.
        # NOTE(review): the exception is not re-raised, so the script exits
        # with status 0 even after a failure -- confirm this is intended.
        logger.exception('Terminating %s at %s with ERROR' % (prog, time.asctime()))
        # close the files before deleting them
        for f in outfiles:
            try:
                f.close()
            except Exception:
                pass
        for f in outfilenames:
            if os.path.isfile(f):
                os.remove(f)
                # plain info message: logger.exception would repeat the traceback
                logger.info('Deleting file %s' % f)
    else:
        logger.info('Successful completion of %s at %s' % (prog, time.asctime()))
    finally:
        # best-effort close; files may already have been closed in the except branch
        for f in outfiles:
            try:
                f.close()
            except Exception:
                pass
        logging.shutdown()


if __name__ == "__main__":
    # Run the analysis only when this file is executed as a script.
    main()
