#!/usr/bin/env python

import os
import sys
import string
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import matplotlib as mpl

inputFile = sys.argv[1]

basename, KMERSIZE, x = inputFile.rsplit('.', 2)
print 'loading', inputFile
print 'basename', basename
KMERSIZE = int(KMERSIZE)
print 'KMERSIZE', KMERSIZE

# A 0   C 1   G 2   T 3
tra = string.maketrans('acgt', '0123')
trb = string.maketrans('0123', 'acgt')

def nt2no(s):
    """Return the integer obtained by reading nucleotide string s as base 4.

    The characters of s are first mapped through the module-level table
    ``tra`` (a -> 0, c -> 1, g -> 2, t -> 3); the resulting digit string is
    then parsed as a base-4 number.
    """
    return int(s.translate(tra), 4)

def baseN(num, b, numerals="0123456789abcdefghijklmnopqrstuvwxyz"):
    """Return the representation of a non-negative integer in base b.

    num      -- non-negative integer to convert
    b        -- target base, at least 2
    numerals -- digit characters; numerals[d] is the character for digit d

    Returns a string without leading zero digits; zero itself is rendered
    as numerals[0].

    Raises ValueError when num is negative or b is smaller than 2 (both
    previously ended in unbounded recursion), and IndexError when a digit
    has no character in numerals.
    """
    if b < 2:
        raise ValueError('base must be at least 2, got %r' % (b,))
    if num < 0:
        raise ValueError('num must be non-negative, got %r' % (num,))
    if num == 0:
        return numerals[0]

    # Collect digits least-significant first, then flip them. Looking every
    # digit up in the caller's numerals keeps custom digit sets consistent
    # across all positions.
    digits = []
    while num:
        num, remainder = divmod(num, b)
        digits.append(numerals[remainder])
    digits.reverse()
    return ''.join(digits)

def no2nt(n):
    """Return the nucleotide string for kmer index n.

    n is written in base 4, right-aligned with spaces to a width of
    KMERSIZE, the digits are mapped back to letters through the module-level
    table ``trb`` (0 -> a, 1 -> c, 2 -> g, 3 -> t), and the padding spaces
    become 'a' (the letter for digit 0).
    """
    width_format = '%%%ds' % KMERSIZE
    padded_digits = width_format % baseN(n, 4)
    letters = padded_digits.translate(trb)
    return letters.replace(' ', 'a')

n = 0

x = np.load(inputFile)
data = x['data']
errs = x['errs']

noReads = np.sum(data[0]) + errs[0]

print 'no reads', noReads

for x in range(5):
    print len(data[x]), np.sum(data[x]), np.std(data[x]), data[x]

stds = np.apply_along_axis(np.std, 1, data)
mins = np.apply_along_axis(np.min, 1, data)
maxs = np.apply_along_axis(np.max, 1, data)

stds2 = np.apply_along_axis(np.std, 0, data)

basename = os.path.basename(inputFile).replace('.npz', '')

def create_fig():
    """Create and return a new pyplot figure of 10 x 5 inches."""
    figure_size = (10, 5)
    return plt.figure(figsize=figure_size)
# Single-series line plots, one figure each:
# (series, title template, x label, y label, output file template).
simple_plots = [
    (stds,
     'Std of kmer (%d) distribution per position',
     'kmer position', 'std',
     '%s.std.kmers.per.position.png'),
    (errs,
     'No invalid kmers (%d) per position',
     'kmer position', 'no of invalid kmers',
     '%s.invalid.kmers.per.position.png'),
    (maxs,
     'Frequency of the most common kmer (%d) per position',
     'kmer position', 'Max kmer frequency',
     '%s.maxs.pos.png'),
    (np.sort(stds2),
     'stdev of kmer (%d) distribution per kmer',
     'different kmers', 'stdev of nmer freq',
     '%s.std.kmers.per.kmer.png'),
]

for series, title_template, x_label, y_label, file_template in simple_plots:
    fig = create_fig()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(title_template % KMERSIZE)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    ax.plot(series)
    plt.savefig(file_template % basename, dpi=200)

legProps = {'size' : 6}
# Most variable kmers: plot every kmer whose standard deviation across
# positions is at least that of the 6th most variable one. Ties at the
# cutoff can add extra lines.
fig = create_fig()
ax = fig.add_subplot(1,1,1)
ax.set_title('frequency per position of the most variable kmers (%d)' % KMERSIZE)
plt.xlabel('position')
plt.ylabel('frequency')

# With fewer than six kmers (e.g. KMERSIZE 1) there is no 6th-largest value,
# so fall back to plotting all of them instead of indexing out of range.
top_n = min(6, len(stds2))
if top_n:
    cutoff = np.sort(stds2)[-top_n]
    for kmer, spread in enumerate(stds2):
        if spread >= cutoff:
            ax.plot(data[:, kmer], label=no2nt(kmer))
plt.legend(prop = legProps, ncol=2)
plt.savefig('%s.freq.most.variable.kmers.png' % basename, dpi=200)

# All kmers: overlay the per-position frequency of every kmer (one column of
# data each) as translucent black lines.
fig = create_fig()
ax = fig.add_subplot(1,1,1)
ax.set_title('frequency per position of all kmers (%d)' % KMERSIZE)
plt.xlabel('position')
plt.ylabel('frequency')

# No cutoff is needed here: every column is drawn.
for kmer in range(data.shape[1]):
    ax.plot(data[:, kmer], color='black', alpha=0.3, linewidth=1)
plt.savefig('%s.freq.all.kmers.png' % basename, dpi=200)

# Ten random kmers, as the title and output file name promise. The kmers are
# distinct, and fewer than ten are drawn when fewer than ten kmers exist.
fig = create_fig()
ax = fig.add_subplot(1,1,1)
ax.set_title('frequency per position of ten random kmers (%d)' % KMERSIZE)
plt.xlabel('position')
plt.ylabel('frequency')

n_kmers = 4 ** KMERSIZE
# random.sample draws without replacement, so no kmer is plotted twice.
for kmer in random.sample(range(n_kmers), min(10, n_kmers)):
    ax.plot(data[:, kmer], label=no2nt(kmer))
plt.legend(prop = legProps, ncol=2)
plt.savefig('%s.freq.ten.random.kmers.png' % basename, dpi=200)
