#!python

"""Summarizes alignments with plots.

Written by Jesse Bloom."""


import sys
import os
import time
import re
import dms_tools.utils
import dms_tools.parsearguments
import dms_tools.file_io
import dms_tools.plot



def WriteMutFreqs(codon_counts, names, mutfreqstext):
    """Writes in text format the data plotted by *dms_tools.plot.PlotPairedMutFracs*.

    The first two arguments are the same as to *dms_tools.plot.PlotPairedMutFracs*:
    *codon_counts* is a list of codon counts objects and *names* is the
    list of corresponding sample names (the two must be the same length).

    *mutfreqstext* is the name of the created text file containing the data.
    It has a header line beginning with ``#sample`` followed by one
    space-delimited line per sample, in the order given by *names*. Each
    value is the corresponding total from
    *dms_tools.utils.ClassifyCodonCounts* divided by ``TOTAL_COUNTS``.

    Raises *ValueError* if a sample has zero total counts; in that case
    no file is written.
    """
    assert len(codon_counts) == len(names)
    # (column label, key in the classified counts), in output column order
    columns = [
        ('synonymous', 'TOTAL_SYN'),
        ('nonsynonymous', 'TOTAL_NS'),
        ('stop codon', 'TOTAL_STOP'),
        ('1 nucleotide mutation', 'TOTAL_N_1MUT'),
        ('2 nucleotide mutations', 'TOTAL_N_2MUT'),
        ('3 nucleotide mutations', 'TOTAL_N_3MUT'),
    ]
    # Compute every row before opening the output so that a sample without
    # counts raises before any file is created.
    rows = []
    for (name, counts) in zip(names, codon_counts):
        classified = dms_tools.utils.ClassifyCodonCounts(counts)
        denom = float(classified['TOTAL_COUNTS'])
        if not denom:
            raise ValueError("no counts for a sample")
        rows.append((name, [classified[countkey] / denom for (label, countkey) in columns]))
    # The with-statement closes the file; no explicit close() is needed.
    with open(mutfreqstext, 'w') as f:
        f.write('#sample %s\n' % ' '.join([label.replace(' ', '_') for (label, countkey) in columns]))
        for (name, fracs) in rows:
            f.write('%s %s\n' % (name, ' '.join(['%g' % frac for frac in fracs])))



def ParseBarcodeData(names, stats, maxperbarcode):
    """Parses data for ``barcodes.pdf`` plot from summary stats.

    *names* is list of sample names.

    *stats* is list of corresponding summary stats. Each entry is a
    dictionary keyed by a description string with a count as the value.
    The read-level keys (``total read pairs``, ``read pairs that fail
    Illumina filter``, ``low quality read pairs``, ``barcodes randomly
    purged``) are ignored here; every other key must have the form
    ``<fate> barcodes with <N> reads`` where ``<fate>`` is ``discarded``,
    ``aligned``, or ``un-alignable``.

    *maxperbarcode* : group barcodes with >= this many reads.

    The return value is *(categories, data)*, which are appropriate
    for use as the same arguments to *dms_tools.plot.PlotSampleBargraphs*.
    *categories* is a list of category labels ordered from the most reads
    per barcode to the fewest, and within the same number of reads as
    aligned, un-alignable, discarded. *data[name][category]* is the summed
    count for that sample and category.

    Raises *ValueError* if a key is neither a read-level key nor a
    recognized barcode key.
    """
    assert len(names) == len(stats)
    nonbarcodekeys = ('total read pairs', 'read pairs that fail Illumina filter', 'low quality read pairs', 'barcodes randomly purged')
    # Raw strings so the regex / LaTeX backslashes are not invalid escapes.
    barcodematch = re.compile(r'^(?P<fate>discarded|aligned|un-alignable) barcodes with (?P<nreads>\d+) reads$')
    # Fractional offsets break ties between fates with equal read counts.
    fateoffset = {'aligned':0.3, 'un-alignable':0.2, 'discarded':0.1}
    data = {}
    categories = set()
    for (name, s) in zip(names, stats):
        namedata = data[name] = {}
        # items() rather than iteritems() so this runs on Python 2 and 3
        for (key, n) in s.items():
            if key in nonbarcodekeys:
                continue
            m = barcodematch.search(key)
            if not m:
                raise ValueError("Failed to match key:\n%s" % key)
            fate = m.group('fate')
            nreads = int(m.group('nreads'))
            if nreads >= maxperbarcode:
                category = r'$\ge$%d reads, %s' % (maxperbarcode, fate)
                nreads = maxperbarcode # pool everything at or above the cap
            elif nreads > 1:
                category = '%d reads, %s' % (nreads, fate)
            else:
                category = '%d read, %s' % (nreads, fate)
            categories.add((nreads + fateoffset[fate], category))
            namedata[category] = namedata.get(category, 0) + n
    return ([label for (sortkey, label) in sorted(categories, reverse=True)], data)


def ParseReadData(names, stats):
    """Parses data for ``read.pdf`` plot from summary stats.

    *names* is list of sample names.

    *stats* is list of corresponding summary stats. Each entry is a
    dictionary that must contain the keys ``read pairs that fail Illumina
    filter`` and ``low quality read pairs``. The keys ``total read pairs``
    and ``barcodes randomly purged`` are ignored; every other key must
    have the form ``<fate> barcodes with <N> reads`` where ``<fate>`` is
    ``discarded``, ``aligned``, or ``un-alignable``.

    The return value is *(categories, data)*, which are appropriate
    for use as the same arguments to *dms_tools.plot.PlotSampleBargraphs*.
    *categories* is ``['failed filter', 'low quality', 'barcoded']`` and
    *data[name][category]* is the count for that sample, where
    ``barcoded`` is the sum over barcode keys of the count times ``<N>``.

    Raises *ValueError* if a key is neither a read-level key nor a
    recognized barcode key.
    """
    assert len(names) == len(stats)
    categories = ['failed filter', 'low quality', 'barcoded']
    nonbarcodekeys = ('total read pairs', 'read pairs that fail Illumina filter', 'low quality read pairs', 'barcodes randomly purged')
    # Raw string so the regex backslash is not an invalid string escape.
    barcodematch = re.compile(r'^(discarded|aligned|un-alignable) barcodes with (?P<nreads>\d+) reads$')
    data = {}
    for (name, s) in zip(names, stats):
        nbarcoded = 0
        # items() rather than iteritems() so this runs on Python 2 and 3
        for (key, n) in s.items():
            if key in nonbarcodekeys:
                continue
            m = barcodematch.search(key)
            if not m:
                raise ValueError("Failed to match key:\n%s" % key)
            # n barcodes that each have nreads reads
            nbarcoded += n * int(m.group('nreads'))
        data[name] = {
            'failed filter': s['read pairs that fail Illumina filter'],
            'low quality': s['low quality read pairs'],
            'barcoded': nbarcoded,
        }
    return (categories, data)


def main():
    """Main body of script.

    Parses the command line arguments, removes any pre-existing output
    files with the names that will be created, reads the counts and
    summary stats for each sample, and then creates the summary plots
    (and, if requested by the ``writemutfreqs`` argument, the mutation
    frequencies text file). All output file names begin with the
    ``outprefix`` argument.

    Raises *ValueError* if the alignment type is not
    ``barcodedsubamplicons`` or if a sample name is duplicated, and
    *IOError* if an expected counts or summary stats file is missing.
    """

    # Parse command line arguments
    parser = dms_tools.parsearguments.SummarizeAlignmentsParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # print initial information
    # Single-argument print(...) gives the same output on Python 2 and 3.
    versionstring = dms_tools.file_io.Versions()
    print('Beginning execution of %s in directory %s\n' % (prog, os.getcwd()))
    print(versionstring)
    print('Parsed the following arguments:\n%s\n' % '\n'.join(['\t%s = %s' % tup for tup in args.items()]))

    # define output files and remove if they already exist
    mutfreqs = '%smutfreqs.pdf' % args['outprefix']
    mutfreqstext = '%smutfreqs.txt' % args['outprefix']
    mutcountsall = '%smutcounts_all.pdf' % args['outprefix']
    mutcountsmulti = '%smutcounts_multi_nt.pdf' % args['outprefix']
    depth = '%sdepth.pdf' % args['outprefix']
    mutdepth = '%smutdepth.pdf' % args['outprefix']
    reads = '%sreads.pdf' % args['outprefix']
    barcodes = '%sbarcodes.pdf' % args['outprefix']
    for f in [mutfreqs, mutfreqstext, mutcountsall, mutcountsmulti, depth, mutdepth, reads, barcodes]:
        if os.path.isfile(f):
            # Announce first so the file is named even if removal fails.
            print("Removing existing file %s" % f)
            os.remove(f)

    # make summaries
    if args['alignment_type'] != 'barcodedsubamplicons':
        raise ValueError("Invalid alignment_type of %s" % args['alignment_type'])

    names = []
    data = {}
    print("\nNow reading data from %s alignments..." % (args['alignment_type']))
    for (prefix, name) in args['alignments']:
        print("Processing data for %s from files with prefix %s" % (name, prefix))
        # Explicit raise rather than assert so the check survives python -O.
        if name in names:
            raise ValueError("Duplicate name of %s" % name)
        names.append(name)
        data[name] = {}
        countsfile = '%scounts.txt' % prefix
        if not os.path.isfile(countsfile):
            raise IOError("Failed to find counts file %s expected from the specification of the prefix %s" % (countsfile, prefix))
        data[name]['counts'] = dms_tools.file_io.ReadDMSCounts(countsfile, args['chartype'])
        statsfile = '%ssummarystats.txt' % prefix
        if not os.path.isfile(statsfile):
            raise IOError("Failed to find summary stats file %s expected from the specification of the prefix %s" % (statsfile, prefix))
        data[name]['stats'] = dms_tools.file_io.ReadSummaryStats(statsfile)
        (
            data[name]['all_cumulfracs'],
            data[name]['all_counts'],
            data[name]['syn_cumulfracs'],
            data[name]['syn_counts'],
            data[name]['multi_nt_all_cumulfracs'],
            data[name]['multi_nt_all_counts'],
            data[name]['multi_nt_syn_cumulfracs'],
            data[name]['multi_nt_syn_counts'],
        ) = dms_tools.utils.CodonMutsCumulFracs(data[name]['counts'])

    print("\nNow making plots...")
    print("Making %s" % mutfreqs)
    dms_tools.plot.PlotPairedMutFracs([data[name]['counts'] for name in names], names, mutfreqs)
    if args['writemutfreqs']:
        print("Making %s" % mutfreqstext)
        WriteMutFreqs([data[name]['counts'] for name in names], names, mutfreqstext)

    print("Making %s and %s" % (mutcountsall, mutcountsmulti))
    # Only the first maxmutcounts + 1 entries of each cumulative
    # fraction list are plotted.
    nplot = args['maxmutcounts'] + 1
    # The count totals passed are those of the first sample.
    first = data[names[0]]
    dms_tools.plot.PlotMutCountFracs(
        mutcountsall, 'Single and multi-nucleotide codon mutations', names,
        [data[name]['all_cumulfracs'][ : nplot] for name in names],
        [data[name]['syn_cumulfracs'][ : nplot] for name in names],
        first['all_counts'], first['syn_counts'], legendloc='right', writecounts=False)
    dms_tools.plot.PlotMutCountFracs(
        mutcountsmulti, 'Multi-nucleotide codon mutations', names,
        [data[name]['multi_nt_all_cumulfracs'][ : nplot] for name in names],
        [data[name]['multi_nt_syn_cumulfracs'][ : nplot] for name in names],
        first['multi_nt_all_counts'], first['multi_nt_syn_counts'], legendloc='right', writecounts=False)

    print("Making %s" % depth)
    dms_tools.plot.PlotDepth([data[name]['counts'] for name in names], names, depth)
    print("Making %s" % mutdepth)
    dms_tools.plot.PlotDepth([data[name]['counts'] for name in names], names, mutdepth, mutdepth=True)

    print("Making %s" % reads)
    # Element 1 of each ReadSummaryStats result is what the parsers consume.
    (categories, plotdata) = ParseReadData(names, [data[name]['stats'][1] for name in names])
    dms_tools.plot.PlotSampleBargraphs(names, categories, plotdata, reads, 'number of read pairs')
    print("Making %s" % barcodes)
    (categories, plotdata) = ParseBarcodeData(names, [data[name]['stats'][1] for name in names], args['maxperbarcode'])
    dms_tools.plot.PlotSampleBargraphs(names, categories, plotdata, barcodes, 'number of barcodes')

    print('\nSuccessfully completed %s at %s' % (prog, time.asctime()))



if __name__ == "__main__":
    # Execute only when run as a script, not when imported as a module.
    main()
