#! /usr/bin/env python

from Bio import SeqIO
import sys, subprocess, argparse
import glob
import csv, time, pandas,os
import numpy as np
from argparse import RawTextHelpFormatter

def _bam_basename(filename):
    """Return *filename* with a trailing ``.bam`` extension removed.

    ``str.lstrip('.bam')`` must not be used for this: it strips any leading
    run of the characters '.', 'b', 'a', 'm' (so "a.bam" and "mab.bam" both
    become "") and never touches the extension of ordinary names.
    """
    name = str(filename)
    if name.endswith(".bam"):
        return name[:-len(".bam")]
    return name


def format_for_feature(inputfile, mapfile, outputpath):
    """Write one SAF annotation file per BAM file found in *mapfile*.

    Each contig of the FASTA file becomes a single feature spanning the
    whole contig (start 1, end = sequence length, strand '+'), with the
    contig id used as both GeneID and Chr.

    Parameters:
        inputfile:  path of the FASTA file being profiled.
        mapfile:    directory that is scanned for ``*.bam`` files.
        outputpath: directory receiving ``<bam name without .bam>.saf``.

    Nothing is written (and the FASTA is not read) when the directory
    contains no ``.bam`` file.
    """
    import shutil  # function-scope import; only used to duplicate the SAF file

    saf_paths = [os.path.join(outputpath, _bam_basename(name) + ".saf")
                 for name in sorted(os.listdir(mapfile))
                 if name.endswith(".bam")]
    if not saf_paths:
        return

    # The SAF content depends only on the FASTA, so parse it a single time
    # (streaming, to keep memory flat) and copy the result for the other BAMs.
    first_saf = saf_paths[0]
    with open(first_saf, 'w') as handle:
        handle.write('GeneID   Chr     Start   End     Strand\n')
        for record in SeqIO.parse(inputfile, 'fasta'):
            contig = str(record.id)
            fields = [contig, contig, '1', str(len(record.seq)), '+']
            handle.write('\t'.join(fields) + "\n")
    for saf_path in saf_paths[1:]:
        shutil.copyfile(first_saf, saf_path)


def feature_counts(mapfile, outputdir, threads):
    """Run featureCounts on every ``.bam`` file in *mapfile*.

    For each BAM file the matching ``<name>.saf`` annotation in *outputdir*
    (as written by ``format_for_feature``) is used and the counts are
    written to ``<name>.readcounts`` in the same directory. featureCounts is
    invoked with ``-M -O -F SAF -T <threads>``.

    Parameters:
        mapfile:   directory that is scanned for ``*.bam`` files.
        outputdir: directory holding the SAF files and receiving the counts.
        threads:   value passed to featureCounts' ``-T`` option.
    """
    for file_ in sorted(os.listdir(mapfile)):
        if not file_.endswith('.bam'):
            continue
        outname = _bam_basename(file_)
        command = [
            "featureCounts", "-M", "-O", "-F", "SAF",
            "-T", str(threads),
            "-a", os.path.join(outputdir, outname + ".saf"),
            "-o", os.path.join(outputdir, outname + ".readcounts"),
            os.path.join(mapfile, file_),
        ]
        # stdout is discarded rather than piped: a PIPE that nobody reads can
        # block the child once the OS pipe buffer fills up.
        with open(os.devnull, 'w') as devnull:
            status = subprocess.call(command, stdout=devnull)
        if status != 0:
            sys.stderr.write("Warning: featureCounts exited with status %s for %s\n"
                             % (status, file_))

def make_coverage(ids, outfile, outputdir):
    """Combine featureCounts results into a per-contig coverage table.

    Every ``*.readcounts`` file in *outputdir* contributes one column. For
    each contig listed in *ids*, the value is the read count (7th column of
    the featureCounts output) divided by the contig length (6th column),
    formatted with ten decimals. The table is written tab-delimited to
    ``<outfile>.cov``.

    Parameters:
        ids:       path of a file with one contig id per line.
        outfile:   output path prefix; ``.cov`` is appended.
        outputdir: directory that is scanned for ``*.readcounts`` files.

    Contigs from *ids* that do not occur in any count file are left out of
    the table and reported once on stderr. Count files are processed in
    sorted name order and rows follow the order of *ids*, so the output is
    reproducible.
    """
    cov_data = {}
    contig_order = []  # keeps the ids-file order independent of dict ordering
    with open(str(ids), "r") as id_handle:
        for line in id_handle:
            contig = line.rstrip()
            if contig and contig not in cov_data:
                cov_data[contig] = []
                contig_order.append(contig)

    count_filenames = sorted(f for f in os.listdir(outputdir)
                             if f.endswith('.readcounts'))
    for count_name in count_filenames:
        with open(os.path.join(outputdir, count_name), "r") as count_handle:
            for line in count_handle:
                if line.startswith("#"):
                    continue
                data = line.split()
                # Skip blank/truncated lines and the column-header row.
                if len(data) < 7 or data[0] == "Geneid":
                    continue
                values = cov_data.get(data[0])
                if values is None:
                    continue  # contig was not requested in the ids file
                if not values:
                    # First file for this contig: store its length up front.
                    values.extend([data[5], data[6]])
                else:
                    values.append(data[6])

    skipped = 0
    with open(str(outfile) + ".cov", "w") as out_handle:
        for contig in contig_order:
            values = cov_data[contig]
            if not values:
                skipped += 1
                continue
            contig_length = float(values[0])
            coverage = ["%.10f" % (float(count) / contig_length)
                        for count in values[1:]]
            out_handle.write(contig + "\t" + "\t".join(coverage) + "\n")
    if skipped:
        sys.stderr.write("Warning: %d contig id(s) had no read counts and were "
                         "left out of the coverage file\n" % skipped)

def get_log(file_, outfile):
    """Write a log10(x + 1) transformed copy of a coverage table.

    Parameters:
        file_:   path of a tab-delimited table whose first column is a label
                 and whose remaining columns are numeric.
        outfile: path of the tab-delimited output (no header, no index).
    """
    # Text mode works for csv on both Python 2 and 3; the context manager
    # guarantees the handle is closed.
    with open(str(file_), 'r') as handle:
        cov_rows = list(csv.reader(handle, delimiter='\t'))
    for row in cov_rows:
        row[1:] = [np.log10(float(value) + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_rows)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
    
def get_multiple(file_, outfile, num):
    """Write a copy of a coverage table with every value multiplied by *num*.

    Parameters:
        file_:   path of a tab-delimited table whose first column is a label
                 and whose remaining columns are numeric.
        outfile: path of the tab-delimited output (no header, no index).
        num:     multiplier; it is converted with ``int()`` before use.
    """
    factor = int(num)  # hoisted: the conversion does not depend on the row
    # Text mode works for csv on both Python 2 and 3; the context manager
    # guarantees the handle is closed.
    with open(str(file_), 'r') as handle:
        cov_rows = list(csv.reader(handle, delimiter='\t'))
    for row in cov_rows:
        row[1:] = [float(value) * factor for value in row[1:]]
    frame = pandas.DataFrame(cov_rows)
    frame.to_csv(outfile, sep="\t", index=False, header=False)

def get_squareroot(file_, outfile):
    """Write a sqrt(x + 1) transformed copy of a coverage table.

    Parameters:
        file_:   path of a tab-delimited table whose first column is a label
                 and whose remaining columns are numeric.
        outfile: path of the tab-delimited output (no header, no index).
    """
    # NOTE(review): 1 is added before taking the root (mirroring the log
    # transforms); confirm that sqrt(x + 1) rather than sqrt(x) is intended.
    # Text mode works for csv on both Python 2 and 3; the context manager
    # guarantees the handle is closed.
    with open(str(file_), 'r') as handle:
        cov_rows = list(csv.reader(handle, delimiter='\t'))
    for row in cov_rows:
        row[1:] = [np.sqrt(float(value) + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_rows)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
def get_100x_log(file_, outfile):
    """Write a log10(100 * x + 1) transformed copy of a coverage table.

    Parameters:
        file_:   path of a tab-delimited table whose first column is a label
                 and whose remaining columns are numeric.
        outfile: path of the tab-delimited output (no header, no index).
    """
    # Text mode works for csv on both Python 2 and 3; the context manager
    # guarantees the handle is closed.
    with open(str(file_), 'r') as handle:
        cov_rows = list(csv.reader(handle, delimiter='\t'))
    for row in cov_rows:
        row[1:] = [np.log10(float(value) * 100 + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_rows)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
########################################################################################
class Logger(object):
    """File-like object that echoes every write to stdout and to a log file.

    The stream that ``sys.stdout`` refers to at construction time is kept as
    ``terminal``; the log file is opened in append mode.
    """

    def __init__(self, logfile, location):
        """Open ``location/logfile`` for appending and remember sys.stdout."""
        self.terminal = sys.stdout
        self.log = open(os.path.join(location, logfile), "a")

    def write(self, message):
        """Write *message* to both the terminal stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams so buffered output is not lost."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open."""
        self.log.close()

def main():
    """Parse the command line and run the coverage-profile pipeline.

    Steps: write SAF annotations for every BAM file, run featureCounts,
    merge the counts into ``<outCov>.cov`` and optionally write a
    transformed copy next to it. Exits with status 1 when a required
    argument is missing.
    """
    parser = argparse.ArgumentParser(prog='Binsanity-profile', usage='%(prog)s -i fasta_file -s {sam,bam}_file --ids contig_ids.txt -c output_file',description="""
    ***********************************************************************
    ******************************BinSanity********************************
    **                                                                   **
    **  Binsanity-profile is used to generate coverage files for         **
    **  input to BinSanity. This uses Featurecounts to generate a        **
    **  a coverage profile and transforms data for input into Binsanity, **
    **  Binsanity-refine, and Binsanity-wf                               **
    **                                                                   **
    ***********************************************************************
    ***********************************************************************""",formatter_class=RawTextHelpFormatter)
    parser.add_argument("-i", dest="inputFasta", help="Specify fasta file being profiled")
    parser.add_argument("-s", dest="inputMapLoc", help="""
    identify location of BAM files
    BAM files should be indexed and sorted""")
    parser.add_argument("--ids",dest="inputIds",help="""
    Identify file containing contig ids""")
    parser.add_argument("-c",dest="outCov",help="""
    Identify name of output file for coverage information""")
    parser.add_argument("--transform",dest="transform", default = "scale", help ="""
    Indicate what type of data transformation you want in the final file [Default: scale]:
    scale --> Scaled by multiplying by 100 and log transforming
    log --> Log transform
    None --> Raw Coverage Values
    X5 --> Multiplication by 5
    X10 --> Multiplication by 10
    X100 --> Multiplication by 100
    SQR --> Square root
    We recommend using a scaled log transformation for initial testing.
    Other transformations can be useful on a case by case basis""")
    parser.add_argument('-T',dest='Threads',default=1,type=int,help="Specify Number of Threads For Feature Counts [Default: 1]")
    parser.add_argument('-o',dest="outDirectory",default=".",help="Specify directory for output files to be deposited [Default: Working Directory]")
    parser.add_argument('--version', action='version', version='%(prog)s v0.2.6.2')
    args = parser.parse_args()

    if len(sys.argv) < 3:
        parser.print_help()
        return

    # Report every missing required argument at once instead of walking a
    # chain of partially unreachable elif branches.
    required = [
        (args.inputFasta, "Need to identify fasta file using '-i'."),
        (args.inputMapLoc, "Need to identify directory containing BAM files using '-s'"),
        (args.inputIds, "Need to identify file containing contig ids using '--ids'"),
        (args.outCov, "Need to identify the output file for the coverage profile using '-c'"),
    ]
    missing = [message for value, message in required if not value]
    if missing:
        print("\n".join(missing))
        sys.exit(1)

    # Value comparison is not needed for ".": an existing directory is simply
    # left alone. (The original identity test `is not "."` was unreliable.)
    if not os.path.isdir(args.outDirectory):
        os.makedirs(args.outDirectory)
    # File logging is currently disabled:
    # sys.stdout = Logger('Binsanity-Profile.log', args.outDirectory)
    format_for_feature(args.inputFasta, args.inputMapLoc, args.outDirectory)

    print("""
        ******************************************************
                    Contigs formated to generate counts
        ******************************************************
        """)
    feature_counts(args.inputMapLoc, args.outDirectory, args.Threads)

    # NOTE: the coverage file is written to the path given with '-c' as-is;
    # it is not placed inside the '-o' directory.
    make_coverage(args.inputIds, args.outCov, args.outDirectory)

    cov_path = args.outCov + ".cov"
    # transform name -> (function(src, dst), suffix appended to the .cov path)
    transforms = {
        "scale": (get_100x_log, ".x100.lognorm"),
        "log": (get_log, ".lognorm"),
        "X5": (lambda src, dst: get_multiple(src, dst, 5), ".x5"),
        "X10": (lambda src, dst: get_multiple(src, dst, 10), ".x10"),
        "X100": (lambda src, dst: get_multiple(src, dst, 100), ".x100"),
        "SQR": (get_squareroot, ".sqrt"),
    }
    if args.transform in transforms:
        transform, suffix = transforms[args.transform]
        print("Transforming your combined coverage profile")
        transform(cov_path, cov_path + suffix)
    elif args.transform != "None":
        # Previously an unknown value was silently ignored.
        print("Unrecognized --transform value '%s'; only the raw coverage file was written" % args.transform)

    print("""

        ********************************************************
                        Coverage profile produced
        ********************************************************

        """)


if __name__ == "__main__":
    main()
