#!/usr/bin/env python


import os
import re
import sys
import copy
import optparse
import logging
import subprocess

from Bio.Seq import Seq
from Bio.Alphabet import generic_dna
from Bio.Seq import MutableSeq

# Log everything down to DEBUG level; `l` is the short alias used by the
# rest of this script for emitting log messages.
logging.basicConfig(level=logging.DEBUG)
l = logging


# Command line interface. The parsed values end up in the module level
# `options` object that the script body reads.
parser = optparse.OptionParser()

parser.add_option('-d', dest='hcdiffs',
                  help='input HCDiffs file to parse')
parser.add_option('-r', dest='refs',
                  help='FASTA file with the reference sequences')
parser.add_option('-o', dest='orfMinSize', type='int', default=40,
                  help='minimal ORF size handed to getorf '
                  '(default: %default)')
parser.add_option('-R', dest='useReverseORFs', action='store_true',
                  help='Check reverse orientation ORFs?')
parser.add_option('-s', dest='sff', action='append',
                  help='input SFF file, to track read origin')
parser.add_option('-g', dest='gffOut', help='GFF output file')
parser.add_option('-t', dest='tsvOut',
                  help='tab separated values output file')
parser.add_option('-l', dest='limit', type='int',
                  help='limit the no of SNPS to test (for testing purposes)')

## Parse the options
(options, args) = parser.parse_args()


class HCDiff:
    """
    One polymorphism record parsed from the text lines of an HCDiffs file.

    The constructor consumes the lines of a single record, collects the
    reads that carry the difference and the reads that do not, and checks
    which effect the variant has on every predicted ORF that spans it.
    `str(obj)` renders the record as one tab separated, nine column line
    (source "gsMapper", type "Polymorphism"); `toTsv()` renders it as a
    pair of header / value lists.
    """

    def __init__(self, refs, sff, predictions, data, count):
        """
        Parse one record.

        refs        -- dict: reference id -> nucleotide sequence
        sff         -- dict: library name -> container of read names; may
                       be empty, in which case no per library origin is
                       tracked
        predictions -- dict: reference id -> list of ORF dicts having the
                       keys 'start', 'end' and 'seq' (a Bio.Seq.Seq)
        data        -- list with the lines of the record, starting at the
                       '>' header line. NOTE: lines are popped from this
                       list while parsing.
        count       -- running number, becomes part of the record ID

        Raises an Exception when the reference is unknown or when SFF data
        was given and a read is found in none of the libraries. Raises an
        IndexError when the record ends before all expected sections have
        been seen.
        """
        self.references = refs
        self.sff = sff
        self.predictions = predictions
        self.count = count
        self._data = data

        # header: >ref <tab> start <tab> end <tab> refNuc <tab> varNuc ...
        _ls = self._data.pop(0).split("\t")
        self.ref = _ls[0][1:]
        if self.ref not in self.references:
            raise Exception("Unknown reference sequence: '%s'" % self.ref)
        self.refSeq = self.references[self.ref]
        self.start = int(_ls[1])
        self.end = int(_ls[2])
        # strip, so a header that ends right after these columns does not
        # leak its newline into the nucleotide
        self.refNuc = _ls[3].strip()
        self.varNuc = _ls[4].strip()

        self.diffAlign = []
        self.sameAlign = []
        self.diffReads = []
        self.sameReads = []

        # per library lists of read names, only filled when SFF data is
        # available
        self.origin = {}
        if self.sff:
            for k in self.sff:
                self.origin[k] = {'diff': [], 'same': []}

        # The reads are parsed regardless of SFF data being present, so
        # the read counts are always reported.
        self._skipUntil('Reads with Difference:')
        self.refAlign = self._data.pop(0).strip()
        self._data.pop(0)  # marker line below the reference alignment
        self._readAlignments('diff', '*', self.diffAlign, self.diffReads)

        self._skipUntil('Other Reads:')
        self._data.pop(0)  # marker line opening the read list
        self._readAlignments('same', '*-', self.sameAlign, self.sameReads)

        self._findOrfEffects()

        l.debug("Found diff in %s at %d (diff %d, same %d)" % (
                self.ref, self.start,
                len(self.diffReads), len(self.sameReads)))

    def _skipUntil(self, marker):
        """Drop lines up to and including the line that equals `marker`."""
        while self._data.pop(0).strip() != marker:
            pass

    def _readAlignments(self, kind, stopChars, alignments, names):
        """
        Consume alignment lines until a line starts with one of `stopChars`.

        Every non empty line is appended to `alignments`, its first
        whitespace separated field (the read name) to `names`. `kind` is
        'diff' or 'same' and selects the origin list to update.
        """
        while True:
            d = self._data.pop(0).strip()
            if not d:
                continue
            if d[0] in stopChars:
                break
            readName = d.split()[0]
            names.append(readName)
            alignments.append(d)
            if self.sff:
                self._trackOrigin(kind, readName)

    def _trackOrigin(self, kind, readName):
        """Register `readName` with the first library that contains it."""
        for k in self.sff:
            # self.sff, not a module level global of the same name
            if readName in self.sff[k]:
                self.origin[k][kind].append(readName)
                return
        raise Exception("cannot find %s read %s in sff files" % (
                kind, readName))

    def _findOrfEffects(self):
        """
        See if the SNP is potentially inside an ORF and classify its effect.

        Sets `self.inOrf` and fills the set `self.snpEffect` with any of
        'amino_acid_change', 'insert_delete', 'early_stop', 'frameshift'.
        """
        self.snpEffect = set()
        self.inOrf = False
        # a reference without any predicted ORF has no entry
        for orf in self.predictions.get(self.ref, []):
            if not (orf['start'] <= self.start <= orf['end']):
                continue
            self.inOrf = True
            seq = orf['seq']
            vn = self.varNuc
            if vn == '-':
                vn = ''
            # replace the positions start..end (relative to the ORF start)
            # by the variant nucleotide(s)
            raw = str(seq)
            aseq = Seq(raw[:self.start - orf['start']]
                       + vn + raw[self.end - orf['start'] + 1:])

            seqt = str(seq.translate())
            aseqt = str(aseq.translate())
            if aseqt == seqt:
                continue  # silent change

            hasStop = '*' in aseqt
            if not hasStop and len(aseqt) == len(seqt):
                self.snpEffect.add('amino_acid_change')
                continue

            ldif = len(aseq) - len(seq)
            if not hasStop and ldif % 3 == 0:
                self.snpEffect.add('insert_delete')
                continue

            if hasStop:
                self.snpEffect.add('early_stop')

            self.snpEffect.add('frameshift')

    def _effectString(self):
        """Comma separated SNP effects in a stable order, or 'None'."""
        if self.snpEffect:
            return ",".join(sorted(self.snpEffect))
        return "None"

    def toTsv(self):
        """
        Convert this to a data structure to be outputted as a tab delimited file

        Returns a tuple (k, v): the list of column names and the list of
        the matching values, all values being strings.
        """
        columns = [
            ('id', 'PolyM_%s_%s_%s' % (self.ref, self.start, self.count)),
            ('sequence', self.ref),
            ('start', str(self.start)),
            ('stop', str(self.end)),
            ('reference', self.refNuc),
            ('variant', self.varNuc),
            ('SameCount', str(len(self.sameReads))),
            ('VariantCount', str(len(self.diffReads))),
            ('snpEffect', self._effectString()),
            ('inOrf', "Yes" if self.inOrf else "No"),
            ]

        if self.sff:
            # sorted library names keep the columns identical for all rows
            for s in sorted(self.sff):
                columns.append(
                    ("same:%s" % s, str(len(self.origin[s]['same']))))
                columns.append(
                    ("diff:%s" % s, str(len(self.origin[s]['diff']))))

        k = [name for name, value in columns]
        v = [value for name, value in columns]
        return k, v

    def __str__(self):
        """Render the record as one tab separated, nine column line."""
        attrs = [
            'ID=PolyM_%s_%s_%s' % (
                self.ref, self.start, self.count),
            'Reference=%s' % self.refNuc,
            'Variant=%s' % self.varNuc,
            'SameCount=%d' % len(self.sameReads),
            'VariantCount=%d' % len(self.diffReads),
            'SnpEffect=%s' % self._effectString(),
            ]

        if self.inOrf:
            attrs.append('InOrf=Yes')
        else:
            attrs.append('InOrf=No')

        if self.sff:
            sffFiles = sorted(self.sff)
            sameLC = ",".join(
                "%s:%d" % (s, len(self.origin[s]['same'])) for s in sffFiles)
            diffLC = ",".join(
                "%s:%d" % (s, len(self.origin[s]['diff'])) for s in sffFiles)
            attrs.append('SameCountPerLibrary=%s' % sameLC)
            attrs.append('VariantCountPerLibrary=%s' % diffLC)

        return "\t".join([
            self.ref, "gsMapper", "Polymorphism",
            str(self.start),
            str(self.end),
            '.', '.', '.',
            ";".join(attrs)
            ])

def HCDiffReader(references, sff, predictions, filename, limit):
    """
    Generator yielding one HCDiff object per record found in `filename`.

    The first four lines of the file are skipped. A record starts at a
    line beginning with '>' and runs up to the next such line or the end
    of the file. `references`, `sff` and `predictions` are handed to the
    HCDiff constructor unchanged, together with the lines of the record
    and a running number starting at 1.

    When `limit` is set (not None / 0) at most `limit` records are
    yielded; otherwise every record is yielded, including the last one
    in the file.
    """
    with open(filename) as F:
        cs = None  # lines of the record being collected, None before the first
        i = 0

        for lineno, line in enumerate(F):
            if lineno < 4:
                continue  # file header
            if line[0] == '>':
                if cs:
                    i += 1
                    yield HCDiff(references, sff, predictions, cs, i)
                    if limit and i >= limit:
                        return
                cs = [line]
            elif cs is not None:
                cs.append(line)

        # the last record is not followed by another '>' line, so it has
        # to be emitted once the file is exhausted
        if cs:
            i += 1
            yield HCDiff(references, sff, predictions, cs, i)

def fastaReader(filename):
    """
    Read a FASTA file into a dict: sequence id -> sequence string.

    The sequence id is the first whitespace separated word of the header
    line; all whitespace is removed from the sequence. A record that only
    has a header maps to an empty string, a record with an empty header
    is skipped, and an empty file results in an empty dict.
    """
    with open(filename) as handle:
        content = handle.read().strip()

    rv = {}
    # whatever precedes the first '>' (nothing, in a well formed file) is
    # not a record
    for record in content.split('>')[1:]:
        head, _, seq = record.partition('\n')
        fields = head.split()
        if not fields:
            continue
        rv[fields[0]] = "".join(seq.split())
    return rv
    
            
        
if __name__ == '__main__':

    # All four files are opened further down; complain with a usage
    # message instead of failing halfway with a traceback.
    for optDest, optFlag in (('hcdiffs', '-d'), ('refs', '-r'),
                             ('gffOut', '-g'), ('tsvOut', '-t')):
        if not getattr(options, optDest):
            parser.error("option %s is required" % optFlag)

    l.debug("read fasta from %s" % options.refs)
    refs = fastaReader(options.refs)

    #run getorf for a preliminary 'annotation'
    # (argument list built directly, so a path with spaces stays intact)
    cl = ['getorf', '-sequence', options.refs, '-outseq', 'stdout',
          '-minsize', '%d' % options.orfMinSize]
    # universal_newlines makes the output text on python 2 and 3 alike
    data = subprocess.Popen(
        cl, stdout=subprocess.PIPE,
        universal_newlines=True).communicate()[0].split("\n")
    orfHeader = re.compile(
        r'>(.*)_(\d+) \[(\d+) - (\d+)\]( \(REVERSE SENSE\))?')

    # every reference gets an entry, also those without a single ORF
    predictions = dict((name, []) for name in refs)
    pcount = 0
    for line in data:
        if not line.startswith('>'):
            continue  # sequence line or empty line
        m = orfHeader.match(line)
        if m is None:
            raise Exception("Cannot parse getorf header: '%s'" % line)
        ref, orfNo, orfStart, orfEnd, isReverse = m.groups()
        orfs = predictions.setdefault(ref, [])
        pp = {
            'no' : int(orfNo),
            'start' : int(orfStart),
            'end' : int(orfEnd)
            }

        if isReverse:
            if not options.useReverseORFs: continue
            raise(Exception("Not implemented!"))

        pp['orientation'] = '+'
        # getorf coordinates are 1-based and inclusive
        rseq = refs[ref][pp['start']-1:pp['end']]
        pp['seq'] = Seq(rseq, generic_dna)

        pcount += 1
        orfs.append(pp)

    l.debug("discovered %d tentative ORFs" % pcount)

    # library name -> names of the reads it contains
    sff = {}
    if options.sff:
        for sffFile in options.sff:
            sffBase = os.path.basename(sffFile).replace('.sff', '')
            l.debug("parsing sff file: %s" % sffBase)
            data = subprocess.Popen(
                ['sffinfo', '-a', sffFile],
                stdout=subprocess.PIPE,
                universal_newlines=True).communicate()[0].split()
            # a set, as each read of each polymorphism is looked up here
            sff[sffBase] = frozenset(data)
            l.debug("found %d reads in %s" % (len(data), sffFile))
            l.debug("first 5 reads: %s" % (", ".join(data[:5])))

    with open(options.gffOut, 'w') as F:
        with open(options.tsvOut, 'w') as G:
            headerWritten = False
            for hcd in HCDiffReader(refs, sff, predictions,
                                    options.hcdiffs, options.limit):
                F.write("%s\n" % hcd)
                k,v = hcd.toTsv()
                if not headerWritten:
                    # column names go on top of the first data row
                    G.write("%s\n" % "\t".join(k))
                    headerWritten = True
                G.write("%s\n" % "\t".join(v))

