#!/usr/bin/env python
import re
import os
import sys
import copy
import shlex
import logging
import optparse
import textwrap


class DUMMY:
    """Empty placeholder class: it defines no attributes and no methods."""

class GFFRECORD:
    """
    One line of a GFF file.

    A line that splits into exactly nine tab-separated columns is parsed as a
    feature (_type == "record"): the first eight columns are stored as the
    string attributes ref, source, feature, start, stop, score, strand and
    frame, and every key=value entry of the ninth column becomes an instance
    attribute named after its key (keys are listed, in order of first
    appearance, in attriblist). Any other line is kept verbatim
    (_type == "string") and only carries the `line` attribute.
    """
    IDFINDER = re.compile(r'ID=([^;]*)')

    def __init__(self, line, clearAttribs=None):
        """
        line         : raw line of the file; surrounding whitespace is stripped
        clearAttribs : attribute keys to drop while parsing (default: none)

        Raises ValueError when a ninth-column entry has no '='.
        """
        if clearAttribs is None:
            clearAttribs = ()

        line = line.strip()
        self.line = line

        ls = line.split("\t")
        if len(ls) != 9:
            self._type = "string"
            return

        self._type = "record"
        (self.ref, self.source, self.feature, self.start, self.stop,
         self.score, self.strand, self.frame) = ls[:8]
        self.attriblist = []

        # Split the attribute column on ';' only.
        lexer = shlex.shlex(ls[8])
        lexer.whitespace = ';'
        lexer.whitespace_split = True
        # shlex treats '#' as a comment start by default, which would
        # silently discard everything after the first '#' in the column.
        lexer.commenters = ''
        for token in lexer:
            _key, _sep, _val = token.partition('=')
            if not _sep:
                raise ValueError(
                    "malformed attribute %r (expected key=value)" % token)
            if _key.strip() not in clearAttribs:
                self.addAttribute(_key, _val)

    def addAttribute(self, _key, _val):
        """
        Add an attribute to the record

        Key and value are stripped of surrounding whitespace, one leading and
        one trailing quote character are removed from the value, and any
        remaining double quote is turned into a single quote. The first value
        seen for a key is stored as a string; further values for the same key
        turn the attribute into a list of strings.
        """
        _key = _key.strip()
        _val = _val.strip()
        # Slices (rather than indexing) keep empty values such as 'Note='
        # from raising IndexError.
        if _val[:1] in ('"', "'"):
            _val = _val[1:]
        if _val[-1:] in ('"', "'"):
            _val = _val[:-1]
        _val = _val.replace('"', "'")

        # NOTE: attributes share the instance namespace with the fixed
        # columns, so a key such as 'score' is handled as a repeat of it.
        if _key not in self.__dict__:
            self.__dict__[_key] = _val
            self.attriblist.append(_key)
        elif isinstance(self.__dict__[_key], list):
            self.__dict__[_key].append(_val)
        else:
            self.__dict__[_key] = [self.__dict__[_key], _val]

    def __str__(self):
        """
        Return the record as a GFF line: the eight fixed columns followed by
        the attributes as key=value pairs joined with ';' (ID first, repeated
        values joined with ','). Non-feature lines are returned unchanged.
        """
        if self._type != "record":
            return self.line

        a = []
        if 'ID' in self.attriblist:
            a.append('ID=%s' % self.__dict__['ID'])
        for k in self.attriblist:
            if k == 'ID':
                continue
            v = self.__dict__[k]
            if isinstance(v, list):
                v = ",".join(v)
            a.append("%s=%s" % (k, v))
        return "\t".join([
            self.ref, self.source, self.feature,
            str(self.start), str(self.stop),
            self.score, self.strand, self.frame,
            ";".join(a)])

class GFFINDEX:
    """
    Lookup table from the value of one record attribute to the records
    that carry that value.
    """
    def __init__(self, key):
        """
        key : name of the record attribute to index on
        """
        self.keyName = key
        self.d = {}

    def add(self, rec):
        """
        File `rec` under its value of the indexed attribute.

        The value is read from the instance dictionary, so a record without
        that attribute raises KeyError.
        """
        k = rec.__dict__[self.keyName]
        # setdefault instead of dict.has_key, which Python 3 no longer has
        self.d.setdefault(k, []).append(rec)

    def get(self, k):
        """
        Return the set of records filed under `k` (empty set when unknown).
        """
        return set(self.d.get(k, []))

    def __str__(self):
        return "index '%s' len: %d" % (self.keyName, len(self.d))
        
class GFFSET:
    """
    The lines of one GFF file as GFFRECORD objects, with a GFFINDEX per
    requested attribute for looking up feature records by value.
    """
    # The module defines no logger of its own, so the class carries one;
    # without it report() and search() fail with NameError.
    _log = logging.getLogger(__name__)

    def __init__(self, fileName, preload=True, indici=None, clearAttribs=None):
        """
        fileName     : path of the GFF file to read
        preload      : read the whole file immediately when true; otherwise
                       nothing is read until readRecords() is iterated
        indici       : record attributes to build an index on
                       (default: ['ref'])
        clearAttribs : remove these attributes directly after loading
        """
        # None sentinels avoid sharing one default list between instances.
        if indici is None:
            indici = ['ref']
        if clearAttribs is None:
            clearAttribs = []

        self.records = []
        self.fileName = fileName
        self.indici = {}
        self.clearAttribs = clearAttribs

        for i in indici:
            self.indici[i] = GFFINDEX(i)

        if preload:
            for rec in self._gffReader():
                self.addRecord(rec)

    def getRecords(self):
        """
        Yield the stored feature records, skipping non-feature lines.
        """
        for r in self.records:
            if r._type == 'record':
                yield r

    def readRecords(self):
        """
        Read the file, storing and yielding every line as a GFFRECORD
        (feature records and non-feature lines alike).
        """
        for rec in self._gffReader():
            self.addRecord(rec)
            yield rec

    def addRecord(self, rec):
        """
        Store `rec`; feature records are also added to every index.
        """
        self.records.append(rec)
        if rec._type == 'record':
            for i in self.indici.values():
                i.add(rec)

    def report(self):
        """
        Log the file name and the number of stored records at INFO level.
        """
        self._log.info("Read %s" % self.fileName)
        self._log.info("Discovered %d records" % len(self.records))

    def search(self, **kwargs):
        """
        Return the set of records matching every attribute=value criterion.

        Each keyword must name an indexed attribute (KeyError otherwise).
        Returns None when called without criteria.
        """
        res = None
        for k, wanted in kwargs.items():
            found = self.indici[k].get(wanted)
            # Test against None, not truthiness: an empty intermediate
            # result must stay empty rather than be replaced by the next
            # criterion's matches.
            if res is None:
                res = found
            else:
                res &= found
            self._log.debug("res %s -  %s" % (k, len(res)))
        return res

    def _gffReader(self):
        """
        Yield a GFFRECORD for every line of the file, in file order.

        Everything is yielded, including whatever follows a '##FASTA'
        directive. The file is closed when the generator finishes or is
        discarded.
        """
        with open(self.fileName, 'r') as F:
            for line in F:
                yield GFFRECORD(line, clearAttribs=self.clearAttribs)

# Script body: print every feature record of the GFF file named by the first
# command-line argument as wrapped "name : value" lines, with a dashed rule
# after each record.
gffSet = GFFSET(sys.argv[1], preload=False)
fields = "ref source feature start stop score strand".split()
for rec in gffSet.readRecords():
    if rec._type == 'string': continue
    for f in fields:
        wraf = "\n".join(textwrap.wrap(rec.__dict__[f], 80,
                             initial_indent='', subsequent_indent=' ' * 18))
        # single-argument print(...) behaves the same on Python 2 and 3
        print("%10s : %s" % (f, wraf))
    for a in rec.attriblist:
        v = rec.__dict__[a]
        # A repeated attribute is stored as a list of strings; normalise so
        # both shapes are handled (a list has no .split()).
        if not isinstance(v, list):
            v = [v]
        for item in v:
            for vv in item.split(','):
                wraf = ("\n" + ' ' * 13).join(textwrap.wrap(vv, 70))
                # only the last 10 characters of the key fit the label column
                print("%10s : %s" % (a[-10:], wraf))

    print('-' * 90)

