#!/usr/bin/env python

import os
import sys
import site
import math
import optparse
from xml.dom.minidom import getDOMImplementation
from subprocess import Popen, PIPE

# Base directory read from the MOABASE environment variable; a missing
# variable raises KeyError right here.
MOABASE = os.environ["MOABASE"]
# Register MOABASE/lib/python as a site directory. This has to happen before
# the "import moa.logger" below.
site.addsitedir(os.path.join(MOABASE, 'lib', 'python'))

## Init the logger
import moa.logger
# module wide shorthand for the logger object of moa.logger
l = moa.logger.l

## Command line interface
# The parsed option values end up in `options`, any remaining positional
# arguments in `args`.
parser = optparse.OptionParser()
parser.add_option('-i', dest='inputfile',
                  help='multifasta file of sequences to scaffold')
parser.add_option('-r', dest='reference',
                  help='multifasta file of the reference sequences')
parser.set_defaults(prefix='scaffold')
parser.add_option('-p', dest='prefix',
                  help='prefix for the output files')
parser.add_option('-v', dest='verbose', action='store_true',
                  help='verbose output')
# The default used to be registered as "maxLinks" only, which does not match
# the dest of --mh. Both names are set, so that options.maxLinks keeps
# resolving and options.maxHits gets a default as well.
parser.set_defaults(maxLinks=10, maxHits=10)
# The help text must be a single string (adjacent literals are concatenated);
# type='int' makes a value from the command line the same type as the default.
parser.add_option('--mh', dest='maxHits', type='int', help=(
        'When finding links, based on BLAST vs the reference genome, limit '
        'the number of links'))

(options, args) = parser.parse_args()


################################################################################
## Prepare & initialize
##

#switch the logger to verbose mode when -v was given
if options.verbose:
    moa.logger.setVerbose()

#module level copies of the options, so that they can be used in
#"%(NAME)s" string substitutions against globals()
INPUTFILE = options.inputfile
PREFIX = options.prefix

#set to True to log extra debugging information
DEBUG = False

#global counter that tracks the number of links found,
#used to create unique link names
LINKCOUNT = 0

#ECATS collects the "ecats" that were encountered. An ecat is a
#categorization of the quality of a link. It allows distinguishing
#between e-100 hits and e-10 hits without discarding the latter
#completely.
#
#ecats range from 0 (low quality link) to 6 (high quality link),
#corresponding to blast e-values of 0 and e-60
ECATS = set()

#ecat -> value written on the "priority" line of the config file
PRIORITY = {0: 3, 1: 3, 2: 3, 3: 2, 4: 2, 5: 1, 6: 1}

#ecat -> value written on the "redundancy" line of the config file
REDUNDANCE = {0: 4, 1: 3, 2: 2, 3: 1, 4: 1, 5: 1, 6: 1}

#names of the whitespace separated columns of a blast result line,
#in column order
BLASTFIELDS = [
    'query_id', 'subject_id', 'perc_identity', 'alignment_length',
    'mismatches', 'gap_openings', 'q_start', 'q_end', 's_start', 's_end',
    'e_value', 'bit_score',
    ]



#the evidence document: EVIDENCE is its root element, to which the
#CONTIG and LINK elements are appended
domImpl = getDOMImplementation()
EVIDOC = domImpl.createDocument(None, "EVIDENCE", None)
EVIDENCE = EVIDOC.documentElement


################################################################################
## Helper functions
##

def error(message, recA, recB):
    """
    Log a fatal problem with a pair of blast records and abort the script.

    message    -- description of the problem
    recA, recB -- the two blast records involved: dicts holding a value
                  for every name in BLASTFIELDS plus the unparsed line
                  under 'raw'

    Does not return: the script exits with status 1.
    """
    l.error(message)
    #print both records side by side, one field per line
    for f in BLASTFIELDS:
        l.error("  %20s : %40s %40s" % (f, recA[f], recB[f]))
    l.error(recA['raw'])
    l.error(recB['raw'])
    l.error('-' * 80)
    #a bare sys.exit() would report success (status 0) to the caller
    sys.exit(1)

def run(cl):
    """
    Execute a commandline and wait until it has finished.

    The commandline is split on whitespace and started without a shell.
    The exit status of the process is not looked at.
    """
    l.debug("Executing")
    l.debug("  %s" % cl)
    arguments = cl.split()
    process = Popen(arguments)
    process.communicate()

def run2(cl):
    """
    Start a commandline and return the pipe connected to its stdout.

    The commandline is split on whitespace and started without a shell.
    The process is not waited for; the caller reads its output from the
    returned file object.
    """
    l.debug("Executing")
    l.debug("  %s" % cl)
    process = Popen(cl.split(), stdout=PIPE)
    return process.stdout

def blastReader():
    """
    Execute blast & yield the hits, grouped per query sequence

    A blast database is built from INPUTFILE with formatdb, after which
    options.reference is blasted (blastn) against that database.

    Yields (query_id, hits) tuples. hits maps each subject_id to the list
    of records found for that query/subject combination, in output order.
    A record is a dict holding the columns of one result line under the
    names in BLASTFIELDS (all values are strings) and the stripped line
    itself under 'raw'.

    Lines of the same query are expected to be adjacent in the blast
    output: a change of query_id closes the current group.
    """

    l.info("generating blast db for %s" % INPUTFILE)
    cl = """ formatdb -i %s -n %s.query -p F """ % (
        INPUTFILE, PREFIX)
    run(cl)
    l.info("Finished generating blast db, starting blast")
    cl = """ blastall -i %s -d %s.query -p blastn -e 1e-30 -m 9 -W 10 """ % (
        options.reference, PREFIX)

    current = None
    hits = {}
    with run2(cl) as blast:
        while True:
            line = blast.readline()
            if not line: break
            line = line.strip()
            #skip blank lines (these used to raise an IndexError on
            #line[0]) and comment lines
            if not line or line[0] == '#':
                continue
            data = dict(zip(BLASTFIELDS, line.split()))
            data['raw'] = line
            query_id = data['query_id']
            if current is None:
                current = query_id
            elif query_id != current:
                #new query contig: yield current results
                yield current, hits
                current = query_id
                hits = {}
            hits.setdefault(data['subject_id'], []).append(data)

    #only now has the complete blast output been read
    l.info("Finished blasting")

    #yield the last set
    if hits: yield current, hits



def prepareContigs():
    """
    Add a CONTIG element to the evidence document for every sequence
    reported by "fastaInfo -i INPUTFILE"

    Every non-blank output line is expected to consist of exactly two
    whitespace separated fields, an id and a length; any other number of
    fields raises a ValueError. The id is stored in both the ID and the
    NAME attribute, the length in LEN.
    """
    #find all contigs in the input set
    l.debug("creating contigs for %s" % INPUTFILE)

    #'with' closes the pipe once all output is read, as blastReader does
    with run2("fastaInfo -i %s" % INPUTFILE) as contiginfo:
        while True:
            line = contiginfo.readline()
            if not line: break
            line = line.strip()
            if not line: continue
            contig_id, length = line.split()
            CONTIG = EVIDOC.createElement('CONTIG')
            CONTIG.setAttribute('ID', contig_id)
            CONTIG.setAttribute('NAME', contig_id)
            CONTIG.setAttribute('LEN', length)
            EVIDENCE.appendChild(CONTIG)

def analyzeBlastSet(recA, recB):
    """
    Analyze a combination of two blast hits and record the link between
    their subject sequences

    recA, recB -- blast records (dicts with the BLASTFIELDS values as
                  strings). analyzeBlast passes two hits of the same
                  query on different subjects.

    A link is only recorded if hit B starts behind the end of hit A on
    the query; otherwise the function returns without doing anything.
    For a link, a LINK element (ID link_<n>, SIZE = distance on the query,
    TYPE BLAST_<ecat>) with a CONTIG child per subject is appended to
    EVIDENCE and the ecat is added to ECATS.

    A query hit of which the end lies before the start ends the script
    through error().
    """
    global LINKCOUNT

    if DEBUG:
        for f in BLASTFIELDS:
            l.debug("  %20s : %40s %40s" % (f, recA[f], recB[f]))

    #important coordinates
    q1s = int(recA['q_start'])
    q1e = int(recA['q_end'])
    q2s = int(recB['q_start'])
    q2e = int(recB['q_end'])

    s1s = int(recA['s_start'])
    s1e = int(recA['s_end'])
    s2s = int(recB['s_start'])
    s2e = int(recB['s_end'])

    #only B behind A is handled here, the caller offers every pair in
    #both orders
    if q2s <= q1e:
        return
    distance = q2s - q1e

    if q1e < q1s: error("Orientation 1??", recA, recB)
    if q2e < q2s: error("Orientation 2??", recA, recB)

    #a subject hit running from high to low coordinates is marked 'EB'
    oriA = 'EB' if s1e < s1s else 'BE'
    oriB = 'EB' if s2e < s2s else 'BE'

    #exponent of the e-values, the 1e-100 prevents a log10(0)
    # e1: 1e-100 -> -100
    # e2: 1e-5 -> -5
    e1 = int(math.log10(1e-100+float(recA['e_value'])))
    e2 = int(math.log10(1e-100+float(recB['e_value'])))

    #take the worst (highest) exponent of the two hits, negate it, divide
    #by 10.0 and truncate, clamped to the range 0..6
    #so, -100, -5 -> 0 and -100, -35 -> 3
    #negating (instead of taking the absolute value) makes sure that a
    #positive exponent ends up in the lowest category
    ecat = min(6, max(0, int(-max(e1, e2) / 10.0)))
    ECATS.add(ecat)

    if DEBUG:
        l.debug("Distance %d" % distance)
        l.debug("Link %s %s (%d) %s %s (eval %s %s) ecat: %s (based on: %s)" % (
                recA['subject_id'], oriA,
                distance,
                recB['subject_id'], oriB,
                recA['e_value'], recB['e_value'],
                ecat,
                recA['query_id']
                ))

    #count links, not candidate pairs: increment only when a link is made
    LINKCOUNT += 1

    LINK = EVIDOC.createElement('LINK')
    LINK.setAttribute('ID', "link_%s" % LINKCOUNT)
    LINK.setAttribute('SIZE', "%s" % distance)
    LINK.setAttribute('TYPE', 'BLAST_%s' % ecat)

    CONTIGA = EVIDOC.createElement('CONTIG')
    CONTIGA.setAttribute('ID', recA['subject_id'])
    CONTIGA.setAttribute('ORI', oriA)

    CONTIGB = EVIDOC.createElement('CONTIG')
    CONTIGB.setAttribute('ID', recB['subject_id'])
    CONTIGB.setAttribute('ORI', oriB)

    LINK.appendChild(CONTIGA)
    LINK.appendChild(CONTIGB)

    EVIDENCE.appendChild(LINK)

    if DEBUG: l.debug('-' * 80)

def analyzeBlast():
    """
    Analyze the blast data

    For every query returned by blastReader that has hits on at least two
    different subjects, each hit on one subject is combined with each hit
    on another subject and handed to analyzeBlastSet. Every pair of
    subjects is visited in both orders (A, B) and (B, A).
    """
    for query_id, hitset in blastReader():
        #single argument & parentheses: valid as python 2 print statement
        #and as python 3 function call
        print('considering %s' % query_id)
        if len(hitset) < 2:
            #this does not really add anything - continue
            continue
        for sidA in hitset:
            setA = hitset[sidA]
            #NOTE(review): an unfinished "if len(setA) >" stood here and
            #kept the file from compiling. Presumably it was meant to
            #limit the number of hits considered (see the --mh option);
            #no limit is applied - confirm what was intended.
            for sidB in hitset:
                if sidA == sidB: continue
                setB = hitset[sidB]
                l.info("Considering blast %s (%d) vs %s (%d)" % (
                        sidA, len(setA), sidB, len(setB)))
                for recA in setA:
                    for recB in setB:
                        analyzeBlastSet(recA, recB)

def createBambusConfig():
    """
    Write the evidence file (PREFIX.xml) and the config file (PREFIX.conf)

    The config file gets a "priority" and a "redundancy" line for every
    ecat in ECATS, highest ecat first, with the values taken from PRIORITY
    and REDUNDANCE, followed by a "mingroupsize 0" line.
    """
    ecatsort = sorted(ECATS, reverse=True)
    l.info("Ecats encountered: %s" % ecatsort)
    #write the evidence xml file
    with open('%s.xml' % PREFIX, 'w') as F:
        F.write(EVIDENCE.toprettyxml())

    #write a config file
    with open("%s.conf" % PREFIX, 'w') as F:
        for ecat in ecatsort:
            #a stray "i += 1" on a never initialised i used to raise an
            #UnboundLocalError here on the first ecat
            F.write("priority BLAST_%s %s\n" % (ecat, PRIORITY[ecat]))
            F.write("redundancy BLAST_%s %s\n" % (ecat, REDUNDANCE[ecat]))
        F.write("mingroupsize 0\n")

def runBambus():
    """
    Run goBambus on the evidence and config files of PREFIX, then
    untangle and printScaff; the dot output of both rounds is converted
    to a png, which is rotated by 90 degrees.

    All commands are executed through run(), one after the other.
    """
    #source of the %(PREFIX)s and %(INPUTFILE)s substitutions
    names = globals()

    #an empty mates file, handed to goBambus with -m
    run('touch %s.mates' % PREFIX)

    #bambus itself, followed by the picture of its dot file
    bambus = ('goBambus -x %(PREFIX)s.xml -m %(PREFIX)s.mates '+
          '-c %(INPUTFILE)s -o %(PREFIX)s -C %(PREFIX)s.conf')
    for template in (
            bambus,
            'dot -Tpng -o %(PREFIX)s.png %(PREFIX)s.dot',
            'mogrify -rotate 90 %(PREFIX)s.png'):
        run(template % names)

    #try an untangle
    run('untangle -e %s.evidence.xml -s %s.out.xml -o %s.untangled.out.xml' % (
        PREFIX,PREFIX,PREFIX))

    #the layout of the literal is irrelevant, all whitespace is
    #collapsed to single spaces
    printscaff = """
       printScaff -e %(PREFIX)s.evidence.xml
                  -s %(PREFIX)s.untangled.out.xml
                  -l %(PREFIX)s.lib
                  -o %(PREFIX)s_U
                  -dot
                  -f %(INPUTFILE)s """ % names
    run(" ".join(printscaff.split()))

    #picture of the untangled dot file
    for template in (
            'dot -Tpng -o %(PREFIX)s.untangled.png %(PREFIX)s.untangled.dot',
            'mogrify -rotate 90 %(PREFIX)s.untangled.png'):
        run(template % names)


################################################################################
##

if __name__ == "__main__":
    #the steps of the pipeline, in the order in which they have to run
    for step in (prepareContigs, analyzeBlast, createBambusConfig, runBambus):
        step()
    
