#!/usr/bin/env python
# 
# Copyright 2009 Mark Fiers, Plant & Food Research
# 
# This file is part of Moa - http://github.com/mfiers/Moa
# 
# Moa is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your
# option) any later version.
# 
# Moa is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
# License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with Moa.  If not, see <http://www.gnu.org/licenses/>.
# 

import re
import os
import sys
import math
import site
import bz2
import math
import tempfile
import optparse
import numpy as np

#moa specific libs - first set the path
#Use the `in` operator: dict.has_key() is deprecated and no longer
#exists in Python 3, while `in` works on every Python version.
if 'MOABASE' not in os.environ:
    raise Exception("MOABASE is undefined")

#make $MOABASE/lib/python importable; addsitedir (as opposed to a plain
#sys.path.append) also processes any .pth files found in that folder
site.addsitedir(os.path.join(os.environ['MOABASE'], 'lib', 'python'))
import moa
import moa.logger as l
from moa.logger import setVerbose

#Usage text shown by -h/--help. optparse replaces "%prog" with the name
#of the running script; a plain "%s" is never substituted and would be
#printed literally.
usage = \
"""%prog [OPTIONS] [STATS]

STATS can be any of:
 - length
 - fracn
 - id
 - scaf10
 - scaf100

 - gcfrac

"""
parser = optparse.OptionParser(usage=usage)

#Command line options. Defaults are given with the option they belong to.
parser.add_option(
    '-i', dest='infile',
    help='input fasta file. If undefined or "-", input is read from stdin')
parser.add_option(
    '-d', dest='delim', default="\t",
    help='output field delimiter')
parser.add_option(
    '-x', dest='maxno', type='int',
    help='max no of sequences to process')
parser.add_option(
    '-s', dest='stats', action='store_true',
    help='stats mode - generate averages')
parser.add_option(
    '-O', dest='onlystats', action='store_true',
    help='only print stats to screen - do not write anything to disk')
parser.add_option(
    '-S', dest='statsSimple', action='store_true',
    help='Simple stat generation')
parser.add_option(
    '-g', dest='graphengine', default="R",
    help='either "R" or "gnuplot" is used to generated the graphs')
parser.add_option(
    '-b', dest='histogramsteps', type="int", default=50,
    help='no of bins to use in the histograms')
parser.add_option(
    '-e', dest='echo', action="store_true",
    help='echo unknown field ids (as opposed to raise an error)')

#An input file is optional: without -i the sequences are read from stdin.
options, args = parser.parse_args()

#No fields requested on the command line: fall back to a default set
#that depends on the mode.
if not args:
    args = ['length'] if options.stats else ['id', 'length']

def fastareader(f):
    """
    Generate (name, sequence) tuples from FASTA formatted input.

    f is either a filename (a name ending in '.bz2' is opened with
    bz2.BZ2File, anything else with open) or an already opened file-like
    object such as sys.stdin.

    For each record, name is the header line without the leading '>'
    and sequence is the lowercased concatenation of all sequence lines
    with every whitespace character removed. Blank lines are ignored.
    Records that have a header but no sequence (or sequence lines that
    are not preceded by a header) are not reported.

    The input is closed once the generator is exhausted, and also when
    it is closed or garbage collected before reaching the end of the
    input.
    """
    if isinstance(f, str):
        #it is probably a filename
        if f.endswith('.bz2'):
            F = bz2.BZ2File(f)
        else:
            F = open(f, 'r')
    else:
        #it is most likely stdin
        F = f

    #try/finally makes sure the handle is released even if the consumer
    #stops iterating half way (GeneratorExit is raised at the yield)
    try:
        name, seq = "", []
        while True:
            line = F.readline()
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            if line[0] == '>':
                if name and seq:
                    yield name, "".join(seq)
                seq = []
                name = line[1:]
            else:
                seq.append("".join(line.split()).lower())

        if name and seq:
            yield name, "".join(seq)
    finally:
        F.close()

#number of sequences seen so far
i = 0

#in stats mode: one list of per-sequence values for each requested field
dd = dict((field, []) for field in args) if options.stats else {}

#fields that may be requested when printing one line per sequence
funcs = 'length fracn id scaf10 scaf100 gcfrac'.split()

def f_id(name, seq):
    """
    Return the sequence identifier: the first whitespace delimited word
    of the header. The sequence itself is not used.
    """
    return name.split(None, 1)[0]

def f_length(name,seq):
    """
    Return the number of characters in the sequence. The header is not
    used.
    """
    nochars = len(seq)
    return nochars

def f_loglength(name,seq):
    """
    Return the base 10 logarithm of the sequence length. The header is
    not used. An empty sequence makes math.log10 raise a ValueError.
    """
    length = len(seq)
    return math.log10(length)

def f_fracn(name, seq):
    """
    Return the fraction (0.0 - 1.0) of the sequence that consists of
    the character 'n'. The count is case sensitive, so the sequence is
    expected in lower case. The header is not used.

    An empty sequence yields 0.0 instead of a ZeroDivisionError.
    """
    if not seq:
        return 0.0
    return float(seq.count('n')) / len(seq)

def f_gcfrac(name, seq):
    """
    Return the GC fraction (0.0 - 1.0) of the sequence.

    Only the characters 'a', 'c', 'g' and 't' are counted (case
    sensitive, so the sequence is expected in lower case); every other
    character, such as 'n', is left out of both the numerator and the
    denominator. The header is not used.

    A sequence without a single a/c/g/t (for example a run of n's)
    yields 0.0 instead of a ZeroDivisionError.
    """
    gc = seq.count('g') + seq.count('c')
    at = seq.count('a') + seq.count('t')
    if gc + at == 0:
        return 0.0
    return float(gc) / (gc + at)

#a gap is an uninterrupted run of at least 10 (or 100) n's
scaf10rex = re.compile(r'n{10,}')
scaf100rex = re.compile(r'n{100,}')

def f_scaf10(name,seq):
    """
    Return the number of pieces the sequence falls apart in when it is
    cut at every run of 10 or more n's: the number of such runs plus
    one. The header is not used.
    """
    gaps = scaf10rex.findall(seq)
    return len(gaps) + 1

def f_scaf100(name,seq):
    """
    Return the number of pieces the sequence falls apart in when it is
    cut at every run of 100 or more n's: the number of such runs plus
    one. The header is not used.
    """
    gaps = scaf100rex.findall(seq)
    return len(gaps) + 1

#Read and interpret fasta input
if not options.infile or options.infile == '-':
    inp = sys.stdin
else:
    inp = options.infile

#Resolve the field functions (f_<field>) once, by name, instead of
#calling eval() on text taken from the command line for every sequence.
if options.stats:
    #stats mode accepts every field for which a f_<field> function
    #exists; unknown fields are refused before any input is read
    statFuncs = []
    for what in args:
        func = globals().get('f_%s' % what)
        if not callable(func):
            raise Exception("Unknown field %s" % what)
        statFuncs.append((what, func))
else:
    #line mode accepts the fields listed in funcs, plus the special
    #field "name" that is handled below
    fieldFuncs = dict((what, globals()['f_%s' % what]) for what in funcs)

for name, seq in fastareader(inp):
    i += 1
    #-x: stop after the requested number of sequences
    if options.maxno and i > options.maxno:
        break

    if options.stats:
        #collect the values, they are summarized after the loop
        for what, func in statFuncs:
            dd[what].append(func(name, seq))
        continue

    #one output line per sequence, every field followed by the delimiter
    for what in args:
        if what == "name":
            sys.stdout.write("%s" % name)
        elif what in fieldFuncs:
            sys.stdout.write("%s" % fieldFuncs[what](name, seq))
        elif options.echo:
            #-e: unknown fields are copied to the output as they are
            sys.stdout.write("%s" % what)
        else:
            raise Exception("Unknown field %s" % what)
        sys.stdout.write(options.delim)
    sys.stdout.write("\n")

#everything below is only relevant in stats mode
if not options.stats:
    sys.exit(0)

#dd has one key per requested field, so it is never empty by itself;
#what tells whether there is anything to summarize is if values were
#collected. Without this check an empty input fails later on with an
#IndexError.
if not dd or not any(dd.values()):
    l.error("no results")
    sys.exit(0)

#start gathering of overall statistics
stats = ['no', 'min', 'max', 'sum', 'mean', 'median',
        'stdev', 'n10', 'n50', 'n90', 'n10i', 'n50i', 
         'n90i', 'over10k', 'over100k', 'over1m']
#per statistic: the formatted values, one entry per processed field
data = dict([(x, []) for x in stats])

def statAdd(d,w,v,what,store):
    """
    Register the value v of statistic w for the field what.

    The raw value is stored in store[what][w]. A formatted version is
    appended to the list d[w]:

     - absolute value below 1e-8: "0"
     - absolute value over 1e6, or non-integers below 1e-3: scientific
       notation with two decimals
     - other integer values: as an integer
     - everything else: four decimals

    If the formatting raises a ValueError, the problem is reported
    through the logger and nothing is appended to d[w].
    """
    store[what][w] = v
    try:
        magnitude = abs(v)
        if magnitude < 1e-8:
            text = "0       "
        else:
            exponent = math.log10(magnitude)
            if exponent > 6:
                text = "%.2e" % v
            elif v == int(v):
                text = "%d       " % v
            elif exponent < -3:
                text = "%.2e" % v
            else:
                text = "%.4f  " % v
    except ValueError:
        l.error("ERROR")
        l.error("probably with a math.log10 of something close to 0")
        l.error("%s %s %s" % (d,w,v))
    else:
        d[w].append(text)

def getNX(d, N):
    """
    Return the N-statistic (N50 for N=0.5, N90 for N=0.9, ...) of d.

    d is a sequence of values sorted from large to small, N a fraction
    between 0 and 1. The values are added up, starting with the
    largest, until the running sum reaches the fraction N of the total
    sum of d.

    Returns a tuple (value, index): the value at which the cutoff was
    reached and its 0-based position in d. If the cutoff is never
    reached the last value and position are returned.

    The cutoff counts as reached when the running sum is equal to it
    ("at least N of the total"): for 10, 9, 8, ..., 2 (sum 54) the N50
    is 8, because 10 + 9 + 8 = 27 is half of the total.

    Raises a ValueError when d is empty.
    """
    if len(d) == 0:
        raise ValueError("cannot calculate an N-statistic without values")

    cutoff = np.sum(d) * N
    summ = 0
    for i, x in enumerate(d):
        summ += x
        if summ >= cutoff:
            break
    return x, i

binsteps = options.histogramsteps

#per field: the raw values, the histogram counts, the histogram bin
#edges and the unformatted statistics
exportData = {}
exportHist = {}
exportHistBin = {}
storeStats = {}

#print 'looking at args %s' % args
numericArgs = []
for what in args:
    rawd = dd[what]

    #Fields without values, or with text values (such as 'id'), cannot
    #be summarized
    if len(rawd) == 0 or isinstance(rawd[0], str):
        l.debug("no statistics for %s: no numeric values" % what)
        continue

    numericArgs.append(what)
    storeStats[what] = {}

    d = np.array(rawd)
    #sorted large to small, as getNX expects
    dsort = np.sort(d)[::-1]
    exportData[what] = d

    _n50, _n50no = getNX(dsort, 0.5)
    _n10, _n10no = getNX(dsort, 0.1)
    _n90, _n90no = getNX(dsort, 0.9)
        
    statAdd(data, 'no', len(d), what, storeStats)
    statAdd(data, 'max', dsort[0], what, storeStats)
    statAdd(data, 'min', dsort[-1], what, storeStats)
    statAdd(data, 'sum', np.sum(d), what, storeStats)
    statAdd(data, 'mean', np.mean(d), what, storeStats)
    statAdd(data, 'median', np.median(d), what, storeStats)
    statAdd(data, 'stdev', np.std(d), what, storeStats)
    statAdd(data, 'n10', _n10, what, storeStats)
    statAdd(data, 'n10i', _n10no, what, storeStats)
    statAdd(data, 'n50', _n50, what, storeStats)
    statAdd(data, 'n50i', _n50no, what, storeStats)
    statAdd(data, 'n90', _n90, what, storeStats)
    statAdd(data, 'n90i', _n90no, what, storeStats)
    #number of values of at least 10k, 100k and 1m
    statAdd(data, 'over10k', len(np.where(d>=1e4)[0]), what, storeStats)
    statAdd(data, 'over100k', len(np.where(d>=1e5)[0]), what, storeStats)
    statAdd(data, 'over1m', len(np.where(d>=1e6)[0]), what, storeStats)

    hist, edges = np.histogram(d, binsteps)
    exportHist[what] = hist
    exportHistBin[what] = edges

if not numericArgs:
    l.error("no results")
    sys.exit(0)

#From here on only the summarized fields are reported. Keeping the
#skipped ones in args shifts the columns of the printed table and
#raises KeyErrors when the data files and plots are written.
args = numericArgs

#print stats to screen
if options.statsSimple:
    #simple layout: one tab delimited line per field, the field name
    #followed by all of its statistics
    for column, field in enumerate(args):
        cells = [field] + [data[stat][column] for stat in stats]
        sys.stdout.write("\t".join(cells) + "\n")
else:
    #table layout: one 20 character wide column per field, one row per
    #statistic
    header = [""] + args
    sys.stdout.write("".join("%20s" % cell for cell in header) + "\n")
    for stat in stats:
        row = [stat] + data[stat]
        sys.stdout.write("".join("%20s" % cell for cell in row) + "\n")

#-O: nothing is written to disk
if options.onlystats:
    sys.exit(0)
    
#Only fields for which numeric data was gathered can be exported. The
#same list is used for the header and for the rows, so that both always
#have the same number of columns.
exportCols = [a for a in args if a in exportData and a in exportHist]

#write the raw data to file: a header line with the field names,
#followed by one tab delimited line per sequence
H,dataFile = tempfile.mkstemp()
F = os.fdopen(H, 'w')
try:
    for a in exportCols:
        F.write("%s\t" % a)
    F.write("\n")

    norows = 0
    if exportCols:
        norows = len(exportData[exportCols[0]])
    for i in range(norows):
        for a in exportCols:
            F.write("%g\t" % exportData[a][i])
        F.write("\n")
finally:
    F.close()

#write the histograms to a file: per field two columns, the left edge
#of the bin and the number of values in that bin
H,histFile = tempfile.mkstemp()
F = os.fdopen(H, 'w')
try:
    nobins = 0
    if exportCols:
        nobins = len(exportHist[exportCols[0]])
    for i in range(nobins):
        for a in exportCols:
            F.write("%f\t" % exportHistBin[a][i])
            F.write("%f\t" % exportHist[a][i])
        F.write("\n")
finally:
    F.close()

l.debug("#export data %s hist %s" % (dataFile, histFile))

#Base name of the images that are generated: the name of the input file
#without its folder, or 'fastainfo' when reading from stdin
if not options.infile or options.infile == '-':
    filebasename = 'fastainfo'
else:
    filebasename = os.path.basename(options.infile)

#drop everything from the last dot onwards; a name without a dot is
#left as it is
filebasename = filebasename.rsplit('.', 1)[0]

#Generate the graphs with the requested engine. The plot scripts are
#templates that are filled in with % locals(), so every %(name)s in a
#template must exist as a variable at the moment it is filled in.
graphengine = options.graphengine.lower()

if graphengine == 'gnuplot':
    #only fields with a histogram can be plotted
    plotCols = [c for c in args if c in exportHist]
    noplots = len(plotCols)

    #name used in the title of the plot; the template refers to it,
    #leaving it undefined results in a KeyError
    if options.infile and options.infile != '-':
        infilename = options.infile
    else:
        infilename = 'stdin'

    #create gnuplot script
    H,plotFile = tempfile.mkstemp()
    F = os.fdopen(H, 'w')
    try:
        F.write('''
    set terminal png nocrop enhanced size 1000,600 
    set output "%(filebasename)s.stat.png"
    set logscale y
    set yrange [0.1:]
    set style fill solid
    set boxwidth 0.9
    set multiplot layout %(noplots)d, 1 title "Stats of %(infilename)s"
    ''' % locals())

        #every field takes two columns in the histogram file: the bin
        #edge (column i) and the count (column j)
        i = 1
        for a in plotCols:
            j = i + 1
            F.write('''
        plot "%(histFile)s" using %(i)d:%(j)d w filledcurve above x1 t "%(a)s"
        ''' % locals())
            i += 2
    finally:
        F.close()

    os.system("gnuplot %s" % plotFile)

    sys.stdout.write("#gnuplot script %s\n" % (plotFile))

elif graphengine == 'r':
    #only fields for which statistics were stored can be plotted
    plotCols = [c for c in args if storeStats.get(c)]
    breaks = options.histogramsteps
    Rscript = '''
        D <- data.frame(read.table("%(dataFile)s", header=T))        
        '''% locals()

    #scatter plots of all fields against each other
    if len(plotCols) > 1:
        Rscript += '''

        png("%(filebasename)s.all.png", width=1000, 800, pointsize=15)
        plot(D,main="All vs All %(args)s plot for %(filebasename)s")
        dev.off()

        ''' % locals()

    #per field a histogram and a sorted curve, both with the n10, n50
    #and n90 marked
    for a in plotCols:
        n10 = storeStats[a]['n10']
        n50 = storeStats[a]['n50']
        n90 = storeStats[a]['n90']
        nosq = storeStats[a]['no']
        #offset of the n10/n50/n90 labels: 1/40th of the data range.
        #40.0 prevents an integer division that would reduce the
        #offset to 0 for integer data with a range below 40
        sht = (storeStats[a]['max'] - storeStats[a]['min'] ) / 40.0
        Rscript += '''

        png("%(filebasename)s.%(a)s.hist.png", width=900, 700, pointsize=15)
        hgr <- hist(D$%(a)s, xlab="%(a)s", breaks=%(breaks)s, col="#ffcc66", main="")
        maxy <- max(hgr$counts) 
        title("%(a)s histogram for %(filebasename)s")
        abline(v="%(n10)f",col="darkred", lwd=5)
        abline(v="%(n50)f",col="darkblue", lwd=5)
        abline(v="%(n90)f",col="darkgreen", lwd=5)
        text(y=maxy,x=%(n10)f + %(sht)f, bg="white",labels="n10", cex=1)
        text(y=maxy,x=%(n50)f + %(sht)f, bg="white",labels="n50", cex=1)
        text(y=maxy,x=%(n90)f + %(sht)f, bg="white",labels="n90", cex=1)
        dev.off()

        png("%(filebasename)s.%(a)s.curve.png", width=900, 700, pointsize=15)
        thisd <- D$%(a)s
        plot(sort(thisd), type="b", ylab="%(a)s", xlab="index", col="black", main="")
        title("%(a)s curve for %(filebasename)s")
        abline(h="%(n10)f",col="darkred", lwd=3)
        abline(h="%(n50)f",col="darkblue", lwd=3)
        abline(h="%(n90)f",col="darkgreen", lwd=3)
        text(x=%(nosq)d,y=%(n10)f + %(sht)f, bg="white",labels="n10", cex=1)
        text(x=%(nosq)d,y=%(n50)f + %(sht)f, bg="white",labels="n50", cex=1)
        text(x=%(nosq)d,y=%(n90)f + %(sht)f, bg="white",labels="n90", cex=1)
        dev.off()


        ''' % locals()
    
    H,RFile = tempfile.mkstemp()
    F = os.fdopen(H, 'w')
    try:
        F.write(Rscript)
    finally:
        F.close()
    #the output of R is appended to R.out and R.err in the current folder
    os.system("R --vanilla < %(RFile)s >> R.out 2>> R.err" % locals())
    l.debug("#R script %s" % (RFile))

else:
    #report an unknown engine instead of silently not making any graph
    l.error("unknown graph engine '%s', no graphs generated" %
            options.graphengine)
