#!/usr/bin/env python
# 
# Copyright 2009 Mark Fiers, Plant & Food Research
# 
# This file is part of Moa - http://github.com/mfiers/Moa
# 
# Moa is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your
# option) any later version.
# 
# Moa is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
# License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with Moa.  If not, see <http://www.gnu.org/licenses/>.
# 
"""
Extract information from blast output

"""

import os
import sys
import logging
import optparse

from Bio.Blast import NCBIStandalone
from Bio.Blast import NCBIXML

LOGLEVEL = logging.INFO

################################################################################
# Commandline option definitions.
#
# Besides the options defined here, the script expects positional
# arguments; those are split into output fields and input files further
# down in this script.
parser = optparse.OptionParser()
parser.add_option('-v', dest='verbose', action='store_true',
                  help='verbose mode - also show debug messages')

# the output fields are joined with this string (see the end of the script)
parser.add_option('-s', dest='separator', default="\t",
                  help='string used to separate the output fields '
                  '(default: a tab)')

parser.add_option('-k', dest='printKeys', action='store_true',
                  help='Print all usable keys - note, these '
                  'are generated dynamically - so you must define at least '
                  'one blast report')

parser.add_option('-e', dest='e', action='store', type='float',
                  help='print only records, alignments or hsps  with an '
                  'expect value better than this')

parser.add_option('-R', dest='recs', action='store_true',
                  help='record mode - produce one line per record, output '
                  'the defined fields, separated by the separator (see -s)')

parser.add_option('-A', dest='alignments', action='store_true',
                  help='alignment mode - produce one line per alignment, '
                  'output the defined fields, separated by the separator '
                  '(see -s)')

# hsp mode is what the script does when neither -R nor -A is given
parser.add_option('-H', dest='hsps', action='store_true',
                  help='hsp mode - produce one line per hsp, output the '
                  'defined fields, separated by the separator (see -s). '
                  'This is the default mode')


def _terminalSize(default=(24, 80)):
    """
    Return the terminal size as a (height, width) tuple of integers

    The size is requested from `stty size`. When that does not yield two
    positive integers (for example when the script is not attached to a
    terminal) `default` is returned instead, so the script remains
    usable in pipes and batch jobs.
    """
    # stty complains on stderr when stdin is not a terminal; hide that
    pipe = os.popen('stty size 2>/dev/null', 'r')
    try:
        reply = pipe.read().split()
    finally:
        pipe.close()

    try:
        height, width = [int(x) for x in reply]
    except ValueError:
        # no output, or something other than two integers
        return default

    # a size of zero would lead to divisions by zero further on
    if height < 1 or width < 1:
        return default
    return height, width


TERMHEIGHT, TERMWIDTH = _terminalSize()

# Parse the commandline. The positional arguments are expected to be:
#
#    FIELD [FIELD ...] - INPUTFILE [INPUTFILE ...]
#
# i.e. the output fields and the input files separated by a lone '-'
options, args = parser.parse_args()

if options.verbose:
    LOGLEVEL = logging.DEBUG

logging.basicConfig(level=LOGLEVEL, format = "%(levelname)s - %(message)s")
l = logging #shortcut

if not args:
    l.error("need to specify something on the commandline")
    # this is an error - signal it to the caller with a non-zero status
    sys.exit(1)

if '-' not in args:
    # without this check args.index('-') dies with a bare ValueError
    l.error("need a '-' on the commandline to separate the fields "
            "from the input files")
    sys.exit(1)

separatorIndex = args.index('-')
inFiles = args[separatorIndex+1:]
fields = args[:separatorIndex]

l.debug("processing %d input files" % len(inFiles))
if len(inFiles) == 0:
    # reported, but not fatal: the script continues & outputs only a header
    l.critical("no input files defined")
elif len(inFiles) <4:
    l.debug("Processing: %s" % ",".join(inFiles))
else:
    l.debug("Processing %s, %s .. %s" % (inFiles[0], inFiles[1], inFiles[-1]))
l.debug("processing fields %s" % " ".join(fields))
l.debug("Console width: %s" % TERMWIDTH)

# the output lines (lists of strings) are collected here by output()
OUTPUT = []

################################################################################
#blast_parser = NCBIStandalone.BlastParser()

def cleanup_name(name):
    """
    Return name with the forbidden characters replaced

    Each of the characters ; = % and , is replaced by an underscore.
    After that, double underscores are replaced by a single one in one
    pass (so a run of three underscores ends up as two).
    """
    for forbidden in ';=%,':
        name = name.replace(forbidden, '_')
    return name.replace('__', '_')

class DUMMY:
    """
    Empty placeholder class without any behaviour of its own
    """

def preprocessRecord(rec):
    """
    Add summary attributes to a blast record and to its alignments

    Set on the record:
      expect         - lowest expect value of all alignments (1e33 if
                       there are none)
      best_alignment - first alignment with that expect value (or None)
      query_id2      - first whitespace separated word of rec.query

    Set on each alignment:
      query_id, query - copied from the record
      expect          - lowest expect value of its hsps (1e33 if none)
      best_hsp        - first hsp with that expect value (or None)
      query_start, sbjct_start - lowest start of its hsps (1e33 if none)
      query_end, sbjct_end     - highest end of its hsps (0 if none)
    """
    # larger than any value it is compared against
    unset = 1e33

    rec.expect = unset
    rec.best_alignment = None
    rec.query_id2 = rec.query.split()[0]

    for aln in rec.alignments:
        aln.query_id = rec.query_id
        aln.query = rec.query

        aln.expect, aln.best_hsp = unset, None
        aln.query_start = aln.sbjct_start = unset
        aln.query_end = aln.sbjct_end = 0

        for hsp in aln.hsps:
            # strictly better only - on a tie the earlier hsp is kept
            if hsp.expect < aln.expect:
                aln.expect, aln.best_hsp = hsp.expect, hsp

            # grow the covered range to include this hsp
            aln.query_start = min(aln.query_start, hsp.query_start)
            aln.sbjct_start = min(aln.sbjct_start, hsp.sbjct_start)
            aln.query_end = max(aln.query_end, hsp.query_end)
            aln.sbjct_end = max(aln.sbjct_end, hsp.sbjct_end)

        # strictly better only - on a tie the earlier alignment is kept
        if aln.expect < rec.expect:
            rec.expect, rec.best_alignment = aln.expect, aln
            
def preprocessAlignment(alignment, record):
    """
    Add some additional data to an alignment

    Sets alignment.hit_id2 to the first whitespace separated word of
    alignment.hit_def. The record is not used.
    """
    firstWord = alignment.hit_def.split(None, 1)[0]
    alignment.hit_id2 = firstWord


def preprocessHsp(hsp, alignment, record):
    """
    Add a character graphic of the query coverage to an hsp

    Sets hsp.cfq to a string that represents the query sequence, scaled
    down to (at most) half the terminal width: a '#' for the part of the
    query between hsp.query_start and hsp.query_end, a '.' for the parts
    before & after. Each piece is rounded down separately.

    hsp.cfq is set to an empty string if the graphic cannot be scaled,
    i.e. if record.query_letters is empty or zero, or if the terminal is
    less than two characters wide. The alignment is not used.
    """
    queryLen = record.query_letters
    # explicit integer division - half the terminal, in whole characters
    width = TERMWIDTH // 2

    if not queryLen or width < 1:
        # prevent a division by zero below
        hsp.cfq = ''
        return

    # npc: number of query positions per character
    npc = float(queryLen) / width
    hsp.cfq = ('.' * int(hsp.query_start / npc)
               + '#' * int((hsp.query_end - hsp.query_start) / npc)
               + '.' * int((queryLen - hsp.query_end) / npc))

def output_keys(r):
    """
    Print the public attribute names of r in columns

    All names in dir(r) that do not start with an underscore are
    printed, sorted, in at most eight left aligned columns per line; the
    number of columns depends on the terminal width (TERMWIDTH). The
    list is followed by an empty line.
    """
    ks = sorted(x for x in dir(r) if not x.startswith('_'))

    if ks:
        # column width: the longest name plus two spaces
        ml = max(map(len, ks)) + 2
        # at least one column - with zero or less columns per line the
        # list would never be used up
        nc = max(1, min(8, TERMWIDTH // ml - 1))
        for start in range(0, len(ks), nc):
            print("   " + "".join([k.ljust(ml) for k in ks[start:start + nc]]))

    print("")

def printKeys(record):
    """
    Print the keys that can be used as output fields & exit

    The keys are the public attributes of the record, of its first
    alignment and of the first hsp of that alignment. The alignment and
    the hsp are preprocessed first, so the keys added by the
    preprocessing are listed as well.

    If the record has no alignments (or the alignment no hsps) there is
    nothing to derive those keys from; a warning is logged and that part
    is left out. This function always ends the script.
    """
    print("Record:")
    output_keys(record)

    if record.alignments:
        alignment = record.alignments[0]
        print("Alignment:")
        preprocessAlignment(alignment, record)
        output_keys(alignment)

        if alignment.hsps:
            hsp = alignment.hsps[0]
            preprocessHsp(hsp, alignment, record)
            print("Hsp")
            output_keys(hsp)
        else:
            l.warning("the alignment has no hsps - cannot show the hsp keys")
    else:
        l.warning("the record has no alignments - "
                  "cannot show the alignment & hsp keys")

    sys.exit()

        
def output(*args):
    """
    Format one line of output and append it to OUTPUT

    The first argument is the list of fields, all other arguments are
    the objects (hsp, alignment, record) to get the values from.

    A field is either KEY or KEY:FORMAT. The value of a field is the
    instance attribute KEY of the first object that has a value other
    than None for it. If no object has such an attribute, the KEY
    itself is output literally.

    The value is formatted using:
      - FORMAT (a %-style format string), if defined
      - "%.5g" for floats
      - the number of items for lists & tuples
      - "%s" for anything else

    Nothing is returned; the line (a list of strings) is appended to
    the module level OUTPUT list.
    """
    fields = args[0]
    obs = args[1:]
    rv = []
    for k in fields:

        fmt = None
        if ':' in k:
            # split on the first colon only - the format may contain colons
            k, fmt = k.split(':', 1)

        # fall back to the literal key if none of the objects has a value
        v = k
        for o in obs:
            found = o.__dict__.get(k, None)
            if found is not None:
                v = found
                break

        if fmt:
            rv.append(fmt % v)
        elif isinstance(v, float):
            rv.append("%.5g" % v)
        elif isinstance(v, (list, tuple)):
            rv.append("%s" % len(v))
        else:
            rv.append("%s" % v)

    OUTPUT.append(rv)

def process_hsp(fields, hsp, alignment, record):
    """
    Process & output information on a hsp

    The hsp is skipped if an expect cutoff is defined (-e) and the
    expect value of the hsp is higher than that cutoff. Otherwise one
    line with the requested fields is output, looking up the values in
    the hsp, the alignment and the record (in that order).
    """
    # compare against None - a cutoff of 0 is a valid (strict) cutoff
    if options.e is not None and hsp.expect > options.e:
        return

    output(fields, hsp, alignment, record)
    
def process_alignment(fields, alignment, record):
    """
    Process & output information on an alignment

    The alignment is skipped if an expect cutoff is defined (-e) and the
    expect value of the alignment is higher than that cutoff. In
    alignment mode (-A) one line is output for the alignment, otherwise
    all hsps of the alignment are preprocessed & processed.
    """
    # compare against None - a cutoff of 0 is a valid (strict) cutoff
    if options.e is not None and alignment.expect > options.e:
        return

    if options.alignments:
        output(fields, alignment, record)
        return

    for hsp in alignment.hsps:
        preprocessHsp(hsp, alignment, record)
        process_hsp(fields, hsp, alignment, record)


def process_record(fields, record):
    """
    Process & output information on a record

    The record is skipped if an expect cutoff is defined (-e) and the
    expect value of the record is higher than that cutoff. In record
    mode (-R) one line is output for the record, otherwise all
    alignments of the record are preprocessed & processed.
    """
    #print or proceed only if the evalue is better than
    #the requested value (compare against None - a cutoff of 0 is
    #a valid, strict, cutoff)
    if options.e is not None and record.expect > options.e:
        return

    #output - if requested
    if options.recs:
        output(fields, record)
        return

    #use the record that was handed over, not a module level variable
    for alignment in record.alignments:
        preprocessAlignment(alignment, record)
        process_alignment(fields, alignment, record)

    
            
import re


def header_label(field):
    """
    Return the header text for an output field

    A field without a format (no colon) is its own header. For a
    KEY:FORMAT field the header is the KEY, padded like the values: the
    optional '-' flag and the width at the start of FORMAT are applied
    to the KEY as a string. Formats without a width (e.g. %.2e) or that
    cannot be applied to the KEY result in the plain KEY.
    """
    if ':' not in field:
        return field

    name, fmt = field.split(':', 1)
    # keep only the alignment flag & the (optional) width of the format
    headerFormat = re.sub(r'^%(-?)(\d*).*$', '%\\1\\2s', fmt)
    try:
        return headerFormat % name
    except (TypeError, ValueError):
        return name


if __name__ == '__main__':
    for infile in inFiles:
        l.debug("processing %s" % infile)
        # make sure the file is closed again, also when printKeys exits
        with open(infile, 'r') as I:
            # note: blast_record is a module level name here; it is not
            # renamed since other code in this file may refer to it
            for blast_record in NCBIXML.parse(I):

                preprocessRecord(blast_record)

                if options.printKeys:
                    printKeys(blast_record)
                    sys.exit()

                process_record(fields, blast_record)

    #print header
    print(options.separator.join([header_label(h) for h in fields]))

    for o in OUTPUT:
        # an item that is a literal backslash-n (two characters) is
        # output as a real newline
        o = ["\n" if item == r'\n' else item for item in o]
        print(options.separator.join(o))

        
