#!/usr/bin/env python

# Free-form usage notes for this script.
# NOTE(review): this string is only assigned here; nothing in the visible
# source hands it to the OptionParser, and the name shadows the builtin
# `help`. The example command also uses -R and regex-style -X/-Y arguments,
# whereas the parser defined below has -R commented out and treats -X as an
# attribute key and -Y as a %-format value -- the example looks outdated,
# confirm before relying on it.
help = """
This is a fairly complex, but rather powerfull script. It's can be used to combine 
two gff files into one. Based on selection criteria data from the second gff file
is added to the first.

imagine the following command:
   ./gffcombiner.py 
     -r ../10.getGenes/genes.gff
     -a ../20.getGeneBlast/blastgenes.gff
     -f CDS 
     -c 'note="blastx Swissprot Bacteria"'
     -R '^(?P<ref>\w*)_(?P<start>\d*)_(?P<stop>\d*)\s'
     -X '^.*Hsp_expect=([0-9\.e-]*);.*;Note="(.*)"'
     -Y ';note="Blastx Trembl Bacteria (Eval \1) \2"'
     -o $$base;
     
"""
import re
import os
import sys
import copy
import logging
import textwrap
import optparse

# ---------------------------------------------------------------------------
# Command line definition
# ---------------------------------------------------------------------------
parser = optparse.OptionParser()
parser.add_option('-r', '--reference', dest='refgff',
                  help = 'reference gff, to map the other gff file on')
parser.add_option('-a', '--add', dest='addgff',
                  help = 'the gff to add to the reference gff file')
parser.set_defaults(outfile= '-')
parser.add_option('-o', '--output', dest='outfile',
                  help = 'output file, defaults to stdout')

# --pprs/--pprr are paired by position: the n-th search term is replaced by
# the n-th replace term on every line of the reference set.
parser.set_defaults(prepRefSearch=[], prepRefReplace=[])
parser.add_option('--pprs', dest='prepRefSearch', metavar='REGEX',
                  action='append',
                  help='Allows search & replace on the reference set ' +
                  '(search term)')

parser.add_option('--pprr', dest='prepRefReplace', metavar='REGEX',
                  action='append',
                  help='Allows search & replace on the reference set ' +
                  '(replace term)')

# --ppas/--ppar: the same positional pairing, applied to the set to add.
parser.set_defaults(prepAppendSearch=[], prepAppendReplace=[])
parser.add_option('--ppas', dest='prepAppendSearch',  metavar='REGEX',
                  action='append',
                  help='Allows search & replace regular expressions to help ' +
                  'deal with invalidly formatted append sets')

parser.add_option('--ppar', dest='prepAppendReplace', metavar='REGEX',
                  action='append',
                  help='Allows search & replace regular expressions to help ' +
                  'deal with invalidly formatted append sets')

parser.set_defaults(referenceKey='ID')
parser.add_option('-Q', '--referenceKey', dest='referenceKey',
                  help='attribute key to look for in the reference set - based ' +
                  'this key the reference records to append to are selected (' +
                  'default: ID)')

parser.add_option('-f', '--featureRef', dest='featureRef',
                  help="apply to features of the reference set of this type")
parser.add_option('-s', '--sourceRef', dest='sourceRef',
                  help="apply to features of this source")

parser.add_option('-g', '--featureAdd', dest='featureAdd',
                  help="add only features of type to the reference set")

parser.set_defaults(clearAttribs = [])

parser.add_option('-c', '--clearAttribs', dest='clearAttribs', action="append",
                  metavar="ATTRIB",
                  help="remove all these attributes from the reference set before adding")

parser.set_defaults(clearAttribRegex = [])
parser.add_option('--cr',  dest='clearAttribRegex', metavar='REGEX',
                  action='append',
                  help="remove only thos attributes that match this regex, " +
                  "(to be used with -c")

#parser.add_option('-R', '--regex', dest='regex',
#                  help='regex to apply to each line in the gff to be added' +
#                  'to determine how to match the reference gff')

parser.set_defaults(appendKey='Note')
parser.add_option('-X', dest='appendKey', metavar='KEY',
                  help='Keyname of the newly added attribute')

parser.set_defaults(appendVal='"blast hit: %(Note)s"')
parser.add_option('-Y', dest='appendVal', metavar='VALUE',
                  help='Value of the newly added attribute. Note that you can ' +
                  'use python %()s string substitutions attribute')

parser.add_option('-v', dest='verbose', action='store_true',
                  help="verbose output")

(options, args) = parser.parse_args()

# ---------------------------------------------------------------------------
# Validation: fail with a usage message instead of an obscure traceback later
# ---------------------------------------------------------------------------
# Both input files are opened unconditionally further down.
if not options.refgff:
    parser.error("a reference gff file is required (-r/--reference)")
if not options.addgff:
    parser.error("a gff file to add is required (-a/--add)")

# Search terms are matched to replace terms by index, so a search term
# without a partner would raise IndexError on the first line that is read.
for _sflag, _rflag, _sterms, _rterms in (
        ('--pprs', '--pprr',
         options.prepRefSearch, options.prepRefReplace),
        ('--ppas', '--ppar',
         options.prepAppendSearch, options.prepAppendReplace)):
    if len(_sterms) > len(_rterms):
        parser.error("every %s needs a matching %s (got %d and %d)" % (
            _sflag, _rflag, len(_sterms), len(_rterms)))

# ---------------------------------------------------------------------------
# Logging: DEBUG with -v, INFO otherwise
# ---------------------------------------------------------------------------
if options.verbose:
    _loglevel = logging.DEBUG
else:
    _loglevel = logging.INFO
logging.basicConfig(level=_loglevel,
                    format = "%(levelname)s - %(message)s")
# short alias used throughout the rest of the file
l = logging


def preprocess_attrib(s):
    """
    Return `s` with every double quote replaced by a single quote.

    Input that contains no double quote is handed back untouched.
    """
    # values holding a ';' are deliberately not wrapped in quotes here
    return s.replace('"', "'") if '"' in s else s

class DUMMY:
    """Empty placeholder class; it defines no attributes or methods."""

class GFFRECORD:
    """
    One line of a GFF file.

    A line with exactly nine tab separated columns becomes a "record"
    (`_type == "record"`): the first eight columns are stored as string
    attributes named after FIRST8COLS and the ninth column is parsed into
    `attributes`, a dict mapping each key to a list of string values.
    Every other line (comments, directives, sequence data, ...) is kept
    verbatim as a "string" (`_type == "string"`) and has no `attributes`.
    """
    IDFINDER = re.compile(r'ID=([^;]*)')
    # splits one attribute value into its comma/semicolon separated parts,
    # honouring an optional pair of surrounding quotes
    VALSPLIT = re.compile(r'(?P<fq>[\'\"]?)(?P<val>.*?)(?P=fq)[,;]')
    # finds the key=value; pairs in the ninth column
    ATTRIBSPLIT = re.compile(
        r'(?P<key>[a-zA-Z_][a-zA-Z0-9_]*)=(?P<fq>"?)(?P<val>.*?)(?P=fq);')
    FIRST8COLS = "ref source feature start stop score strand frame".split()

    def __init__(self, line,
                 clearAttribs = None,
                 clearAttribRegex = None,
                 prepSearch = None,
                 prepReplace = None):
        """
        Parse a single line.

        line             : the raw line, surrounding whitespace is stripped
        clearAttribs     : attribute keys whose values are filtered while
                           parsing (see addAttribute)
        clearAttribRegex : regular expressions selecting which values of
                           the clearAttribs keys are dropped
        prepSearch       : regular expressions applied to the line before
                           it is split
        prepReplace      : replacement for the search term at the same
                           position in prepSearch
        """
        if clearAttribs is None: clearAttribs = []
        if clearAttribRegex is None: clearAttribRegex = []
        if prepSearch is None: prepSearch = []
        if prepReplace is None: prepReplace = []

        line = line.strip()

        # search & replace on the raw line, terms are paired by position
        for i, sterm in enumerate(prepSearch):
            line = re.sub(sterm, prepReplace[i], line)

        self.line = line
        self.clearAttribs = clearAttribs
        self.clearAttribCompiledRe = [re.compile(car)
                                      for car in clearAttribRegex]

        ls = line.split("\t")
        if len(ls) != 9:
            self._type = "string"
            return

        self._type = "record"
        (self.ref, self.source, self.feature, self.start,
         self.stop, self.score, self.strand, self.frame) = ls[:8]
        self.attributes = {}

        # a ';' is appended so the last pair is terminated as well
        for token in self.ATTRIBSPLIT.finditer(ls[8] + ';'):
            # '=' inside a value is replaced by ':'
            _val = token.group('val').replace('=', ':')
            self.addAttribute(token.group('key'), _val)

    def getKeys(self):
        """
        Get all keys (attributes + first 8 columns) of this record
        """
        return list(self.attributes.keys()) + self.FIRST8COLS

    def _getFirst8ColsDict(self):
        """
        Get a new dictionary representing the first 8 columns of the GFF
        record
        """
        return dict((k, self.__dict__[k]) for k in self.FIRST8COLS)

    def getVal(self, name, default = []):
        """
        Get the value of a key: a string for one of the first eight
        columns, the list of values for an attribute, `default` otherwise.

        The default list is shared between calls, do not modify it.
        """
        if name in self.FIRST8COLS:
            return self.__dict__[name]
        if name in self.attributes:
            return self.attributes[name]
        return default

    def getFormattedData(self):
        """
        Return a dictionary with all data in this gff record, usable for
        %(key)s string substitution: an attribute with a single value maps
        to that value, any other attribute to its values joined by ','.
        """
        rv = self._getFirst8ColsDict()
        for k, v in self.attributes.items():
            if len(v) == 1:
                rv[k] = v[0]
            else:
                rv[k] = ','.join(map(preprocess_attrib, v))
        return rv

    def __getattr__(self, name):
        """
        Expose attributes of the ninth column as object attributes: a
        single value is returned as a string, several values as a list.
        Raises AttributeError for unknown names.
        """
        # Read the instance dict directly. Going through self.attributes
        # would call __getattr__ again, and recurse without end, on
        # objects that have no attribute table (non-record lines).
        attributes = self.__dict__.get('attributes')
        if attributes is not None and name in attributes:
            val = attributes[name]
            if len(val) == 1:
                return val[0]
            return val
        raise AttributeError(name)

    def addAttribute(self, _key, _val, ignoreFilter = False):
        """
        Add an attribute to the record

        _val is split on ',' and ';' into separate values, each value is
        stripped and its double quotes are replaced by single quotes. The
        values are appended to those already stored under _key.

        Unless ignoreFilter is set, a key listed in clearAttribs is
        filtered: without any clear regex all its values are dropped,
        otherwise only the values matching at least one clear regex.
        """
        _valList = [_tv[1].strip().replace('"', "'")
                    for _tv in self.VALSPLIT.findall(_val + ';')]
        _key = _key.strip()

        # the key is registered even when all its values are filtered out
        stored = self.attributes.setdefault(_key, [])

        # should we filter the incoming keys? if not, just add all
        # values and return
        if ignoreFilter or _key not in self.clearAttribs:
            stored.extend(_valList)
            return

        # no regex to narrow the selection down: every value is cleared
        if not self.clearAttribCompiledRe:
            return

        # keep a value exactly once, and only if no regex matches it
        for _v in _valList:
            if not any(care.match(_v) for care in self.clearAttribCompiledRe):
                stored.append(_v)

    def __str__(self):
        """
        Return the record as a GFF line (ID first, attributes without
        values left out); a non-record is returned as it was read.
        """
        if self._type != "record":
            return self.line

        a = []
        # the ID list may be empty when its values were filtered out
        ids = self.attributes.get('ID')
        if ids:
            a.append('ID=%s' % ids[0])
        for k, v in self.attributes.items():
            if k == 'ID': continue
            if isinstance(v, list):
                v = ','.join(map(preprocess_attrib, v))
            elif v:
                v = preprocess_attrib(v)
            if v:
                a.append('%s=%s' % (k, v))

        return "\t".join([
            self.ref, self.source, self.feature,
            str(self.start), str(self.stop),
            self.score, self.strand, self.frame,
            ";".join(a)])

    def abbstr(self, s):
        """
        Abbreviate a value longer than 120 characters to its first and
        last 58 characters; shorter values are returned unchanged.
        """
        text = str(s)
        if len(text) > 120:
            return "%s ... %s" % (text[:58], text[-58:])
        return s

    def pretty(self):
        """
        Return a multi line, human readable version of the record: the
        first eight columns on one line, then one line per attribute value.
        """
        if self._type != 'record':
            return self.line
        rv = [" ".join([self.ref, self.source, self.feature,
                        str(self.start), str(self.stop),
                        self.score, self.strand, self.frame])]
        for at, va in self.attributes.items():
            if not isinstance(va, (list, tuple)):
                va = [va]
            for va1 in va:
                rv.append("      %-20s : %s" % (at, self.abbstr(va1)))
        return "\n".join(rv)
                

class GFFINDEX:
    """
    Lookup table from the values of one key to the records having them.

    `keyName` is the indexed key, `d` maps each value seen for that key
    to the list of records it was seen on.
    """
    def __init__(self, key):
        """ Create an empty index for `key` """
        l.info("Creating an index for key %s" % key)
        self.keyName = key
        self.d = {}

    def add(self, rec):
        """
        Register `rec` under every value it has for the indexed key.

        The value is obtained with rec.getVal(keyName, []); a value that
        is not a list is treated as a single value. A record without the
        key is not indexed.
        """
        klist = rec.getVal(self.keyName, [])
        if not isinstance(klist, list):
            klist = [klist]
        for k in klist:
            self.d.setdefault(k, []).append(rec)

    def get(self, k):
        """
        Return the set of records registered under value `k` (an empty
        set if there are none).
        """
        return set(self.d.get(k, []))

    def __str__(self):
        return "index '%s' len: %d" % (self.keyName, len(self.d))
        
class GFFSET:
    """
    All lines of one GFF file, with an index per requested key.

    `records` holds one GFFRECORD per line of the file, in file order;
    `indici` maps a key name to the GFFINDEX built for it.
    """
    def __init__(self, fileName, indici=None,
                 clearAttribs=None,
                 clearAttribRegex=None,
                 prepSearch=None, prepReplace=None):
        """
        Read `fileName` and index its records.

        indici           : keys to build an index for (default: ['ref'])
        clearAttribs     : remove these attributes directly after loading
        clearAttribRegex : restricts the removal to values matching one of
                           these regular expressions
        prepSearch,
        prepReplace      : search & replace terms applied to each line
                           up to and including the ##FASTA directive
        """
        if indici is None: indici = ['ref']
        if clearAttribs is None: clearAttribs = []
        if clearAttribRegex is None: clearAttribRegex = []
        if prepSearch is None: prepSearch = []
        if prepReplace is None: prepReplace = []

        self.records = []
        self.fileName = fileName
        self.indici = {}
        self.prepSearch = prepSearch
        self.prepReplace = prepReplace
        self.clearAttribs = clearAttribs
        self.clearAttribRegex = clearAttribRegex

        for i in indici:
            # a key that is listed twice is indexed only once
            if i not in self.indici:
                self.indici[i] = GFFINDEX(i)

        for rec in self._gffReader():
            self.addRecord(rec)

    def getRecords(self):
        """ Iterate over the real records, skipping all other lines """
        for r in self.records:
            if r._type == 'record':
                yield r

    def addRecord(self, rec):
        """ Append `rec` to the set; real records are indexed as well """
        self.records.append(rec)
        if rec._type == 'record':
            for i in self.indici.values():
                i.add(rec)

    def report(self):
        """ Log the file name and the number of lines that were read """
        l.info("Read %s" % self.fileName)
        l.info("Discovered %d records" % len(self.records))

    def search(self, **kwargs):
        """
        Return the set of records matching all key=value criteria.

        Every key must have an index, a key without one raises KeyError.
        Without any criteria None is returned.
        """
        # TRACE is a module level debugging switch that is assigned after
        # this class has been defined; look it up defensively so search()
        # also works when the switch has not been set yet.
        trace = globals().get('TRACE', False)
        res = None
        for k, wanted in kwargs.items():
            hits = self.indici[k].get(wanted)
            if res is None:
                res = hits
            else:
                res &= hits
            if trace:
                # stderr: stdout may be carrying the resulting gff
                sys.stderr.write("filtering %s (%d results)\n" % (k, len(res)))

            l.debug("res %s -  %s" % (k, len(res)))
        return res

    def _gffReader(self):
        """
        Yield a GFFRECORD for every line of the file.

        Search & replace is applied up to and including the ##FASTA
        directive; the lines after it are passed on without it.
        """
        inFasta = False
        # the with block closes the file even if the caller stops early
        with open(self.fileName, 'r') as F:
            for line in F:
                if inFasta:
                    yield GFFRECORD(line,
                                    clearAttribs=self.clearAttribs,
                                    clearAttribRegex=self.clearAttribRegex)
                    continue

                yield GFFRECORD(line,
                                clearAttribs=self.clearAttribs,
                                clearAttribRegex=self.clearAttribRegex,
                                prepSearch=self.prepSearch,
                                prepReplace=self.prepReplace)
                # the line still carries its line ending, so compare the
                # stripped text
                if line.strip() == "##FASTA":
                    inFasta = True


# ---------------------------------------------------------------------------
# Load both sets
# ---------------------------------------------------------------------------
l.debug("start loading reference set")
# 'source' has to be indexed as well: -s/--sourceRef adds a 'source'
# criterion to every query, and a search on a key without an index fails.
refset = GFFSET(options.refgff,
                indici=['ID', 'feature', 'source', options.referenceKey],
                clearAttribs=options.clearAttribs,
                clearAttribRegex=options.clearAttribRegex,
                prepSearch=options.prepRefSearch,
                prepReplace=options.prepRefReplace)
refset.report()

#read the add-to-set
l.debug("start loading the add-set")
addset = GFFSET(options.addgff, indici=['ref'],
                prepSearch=options.prepAppendSearch,
                prepReplace=options.prepAppendReplace)
addset.report()

# criteria every query on the reference set starts from
querybase = {}
if options.featureRef:
    querybase['feature'] =  options.featureRef
if options.sourceRef:
    querybase['source'] =  options.sourceRef

l.debug("querybase %s" % querybase)

# ---------------------------------------------------------------------------
# Append the new attribute to the matching reference records
# ---------------------------------------------------------------------------
i = 0
appK = options.appendKey
# formatting records for debug output is costly, only do so when it is shown
verbose = bool(options.verbose)

# NOTE(review): TRACE is a debugging hook, switched on for values containing
# the hard-coded term below; it is also read by GFFSET.search. Trace output
# goes to stderr so it cannot end up in a gff written to stdout.
TRACE=False
for rec in addset.getRecords():
    if options.featureAdd and rec.feature != options.featureAdd:
        if verbose:
            l.debug("ignoring %s" % (str(rec)[:80]))
        continue

    appV = options.appendVal % rec.getFormattedData()
    l.debug("appending %s -  %s" % (appK, appV))
    TRACE = 'oxidored' in appV
    if TRACE:
        sys.stderr.write("appending  %s %s\n" % (appK, appV))

    # show the first 100 records in detail
    if verbose and i < 100:
        l.debug("%s" % rec)
        for a in rec.attributes.keys():
            l.debug("  a - %s = %s" % (a, rec.attributes[a]))

    query = copy.copy(querybase)
    i += 1
    if i % 180 == 0:
        l.info("processed %d records" % i)

    # the reference column of the added record selects the reference records
    query[options.referenceKey] = rec.ref
    if TRACE:
        sys.stderr.write("query:  %s\n" % (query,))

    l.debug("query %s" % (query))

    results = refset.search(**query)
    if TRACE:
        sys.stderr.write("no results %d\n" % len(results))
    for r in results:
        r.addAttribute(appK, appV, ignoreFilter=True)
        if verbose:
            l.debug("post append")
            l.debug("     %s" % (r.pretty()))

# ---------------------------------------------------------------------------
# Write the (modified) reference set
# ---------------------------------------------------------------------------
if options.outfile == '-':
    F = sys.stdout
else:
    F = open(options.outfile, 'w')

try:
    for r in refset.records:
        F.write("%s\n"%r)
finally:
    # close the output file even when writing fails; leave stdout alone
    if F is not sys.stdout:
        F.close()
    

