#!/usr/bin/python
__version__ = '1.7'

##----PACKAGE------##
import argparse
import time
import sys
import os
from argparse import RawTextHelpFormatter
from pkg_resources import resource_filename
from multiprocessing import Pool

##----MAIN---------##
def main():
	"""Parse the command line, run the selected mode and print the closing banner."""
	args = get_args()
	mode = args.which

	# 'which' is recorded by each sub-parser through set_defaults(which=...)
	if mode == 'more':
		run_mode_more(args)
	elif mode == 'one':
		run_mode_one(args)

	##----CODA---------##
	footer_rows = (
		('> END', ' ', 1, 1),
		('', '-', 0, 0),
	)
	for title, filler, show_time, show_date in footer_rows:
		make_ornament(title, 100, filler, show_time, show_date)
#	print ('\t|' + ' '*47 + '__.' + ' '*48 + '|\n'
#'\t|' + ' '*32 + '___.  ____.   |  |  __. __.__.   __.' + ' '*30 + '|\n'
#'\t|' + ' '*30 + '_/ ___\ \__  \  |  | <   y  |\  \ /  /' + ' '*30 + '|\n'
#'\t|' + ' '*30 + '\  c___  /  a \_|  l__\___  | >  x  <' + ' '*31 + '|\n'
#'\t|' + ' '*31 + '\_____>(______/|____//_____|/__/ \__\\' + ' '*30 + '|\n'
#'\t|' + '~'*42 + 'www.calyx.biz' + '~'*43 + '|\n\n')

##----FUNCTION-----##
#---get args---
def get_args():
	"""Build the command-line parser, parse sys.argv and print the run banner.

	Two sub-commands are defined:
	  one   analyse a single sample (-R, -D and -O are mandatory)
	  more  analyse several samples described by an input table (-I),
	        or write an example table (-E); exactly one of the two is required

	Returns the argparse namespace; its 'which' attribute is 'one' or 'more'.
	"""
	tool = os.path.basename(sys.argv[0])
	author = 'Yingxiang Li'
	email = 'xlccalyx@gmail.com'
	date = 'Apr 15, 2016'
	# update dates: 041316, 042116
	home = 'www.calyx.biz'

	description = (
		'\ttool:   ' + tool + ' v' + __version__ +
		'\n\tdate:   ' + date +
		'\n\tauthor: ' + author + ' (' + email + ')' +
		'\n\thome:   ' + home +
		'\n\tMUST-install (NOT guaranteed on other versions):' +
		'\n\t        bwa: 0.7.5a; fastqc: v0.11.2; samtools: 1.3; java: 1.7.0_95'
	)
	parser = argparse.ArgumentParser(description=description, prog=tool, formatter_class=RawTextHelpFormatter)

	parser.add_argument('-V', '--version', action='version', version='%(prog)s v' + __version__)

	subparser = parser.add_subparsers(help='select 1 of 2 modes! CIpipe mode -h for further help!')

#---subparser for mode one
	subparser_one = subparser.add_parser('one', help='mode one, for one sample analysis.')
	subparser_one.set_defaults(which='one')
	subparser_one.add_argument('-R', '--reference', help='sample reference file, fasta format. (eg: my_ref.fa)', required=True)
	subparser_one.add_argument('-D', '--data', help='sample data directory, fastq-ONLY. one file for single end, two files for paired end. (eg: my_data/)', required=True)
	subparser_one.add_argument('-O', '--output', help='output directory, will be created if not exists. (eg: my_output/)', required=True)

	subparser_one.add_argument('-N', '--name', help='sample name, default is name of output directory. (eg: my_sample)', default='NoName')
	subparser_one.add_argument('-RK', '--rank', help='sample rank. (eg: 1)', default='')

	# numeric thresholds stay strings: they are only pasted into shell commands
	subparser_one.add_argument('-P', '--pvalue', help='minimal p value, default: 0.05.', default='0.05')
	subparser_one.add_argument('-B', '--basequality', help='minimal base quality, default: 30.', default='30')
	subparser_one.add_argument('-A', '--varfreq', help='minimal variant frequency, default: 0.0001.', default='0.0001')

	subparser_one.add_argument('-T', '--target', help='CRISPR target position. indel in target range will be picked out, multiple targets separated by \';\', default: NoTarget. (eg: gene1:100;gene2:200)', default='NoTarget')
	subparser_one.add_argument('-US', '--upstream', help='up stream distance from CRISPR target position, default: 20.', default='20')
	subparser_one.add_argument('-DS', '--downstream', help='down stream distance from CRISPR target position, default: 10.', default='10')

	# store_false switches: the step runs unless the flag is given
	subparser_one.add_argument('-F', '--fastqc', help='fastq quality control by FastQC, default: ON. -F will turn OFF.', action='store_false', default=True)
	subparser_one.add_argument('-X', '--index', help='build reference index by BWA, default: ON. -X will turn OFF.', action='store_false', default=True)

	subparser_one.add_argument('-U', '--unlimited', help='no read depth limit in mpileup by SAMtools, default: OFF.', action='store_true')
	subparser_one.add_argument('-VI', '--indel', help='search for indel by VarScan, default: ON. -VI will turn OFF.', action='store_false', default=True)
	subparser_one.add_argument('-VS', '--snp', help='search for SNP by VarScan, default: OFF.', action='store_true')
	subparser_one.add_argument('-VC', '--consensus', help='search for consensus call by VarScan, default: OFF.', action='store_true')
	subparser_one.add_argument('-VR', '--readcount', help='search for read counts by VarScan, default: OFF.', action='store_true')

#---subparser for mode more
	subparser_more_raw = subparser.add_parser('more', help='mode more, for multiple samples and advanced analysis. run: \'CIpipe more -E\' first.')
	subparser_more_raw.set_defaults(which='more')
	# exactly one of -E / -I must be given; without this a bare 'more' reached
	# the input-table check with input=None and died with a TypeError
	subparser_more = subparser_more_raw.add_mutually_exclusive_group(required=True)
	subparser_more.add_argument('-E', '--example', help='create example input data. modify the example.input.tab to fit your data.', default=False, action='store_true')
	subparser_more.add_argument('-I', '--input', help='information table of all input data. all settings should be in it. (eg. example.input.tab)')

#---head
	args = parser.parse_args()

	# single-argument call form prints the same on Python 2 and Python 3
	print('\n\n\t' + ' '.join(sys.argv[:]) + '\n')
	make_ornament('', 100, '-', 0, 0)
	make_ornament('tool:   ' + tool + ' v' + __version__, 100, ' ', 0, 0)
	make_ornament('author: ' + author + ' (' + email + ')', 100, ' ', 0, 0)
	make_ornament('', 100, '-', 0, 0)
	make_ornament('> BEGIN', 100, ' ', 1, 1)

	return args

#---run mode one---
def run_mode_one(args):
	"""Run the whole single-sample pipeline for one set of arguments.

	args needs the attributes defined by the 'one' sub-parser (reference, data,
	output, name, rank, ...). Steps run in order: bwa index, FastQC, bwa map,
	sam to bam, sort & index, mpileup, VarScan calls, result tables and plots.
	The mapping / SAMtools steps are mandatory: when one of them reports
	failure, a hint is printed and the remaining steps are skipped.
	Returns None.
	"""
	name = run_preset_one(args)
	if not name:
		make_ornament('please fix the problems above and re-try!', 100, ' ', 0, 0)
		return

	sample_dir = os.path.normpath(args.output) + '/' + name + '/'
	if os.path.exists(sample_dir):
		make_ornament('WARNING! output directory exists.', 100, ' ', 1, 0)

	#-output & log directory
	output_dir = make_dir(sample_dir)
	log_dir = make_dir(output_dir + 'log/')
	# a stale 'done' marker from an earlier run must not survive a re-run
	done_one_file = log_dir + 'done'
	if os.path.isfile(done_one_file):
		os.remove(done_one_file)

	#-fastq file
	# os.path.join keeps the paths valid when -D is given without a trailing '/'
	fastq_names = sorted(os.listdir(args.data))
	fastq1_file = os.path.join(args.data, fastq_names[0])
	fastq2_file = os.path.join(args.data, fastq_names[1]) if len(fastq_names) > 1 else ''

	#---bwa index
	run_bwa_index(args, log_dir)

	#---fastqc quality control
	run_fastqc_quality_control(args, output_dir, fastq1_file, fastq2_file, log_dir)

	#---bwa map
	if not run_bwa_map(fastq2_file, output_dir, name, args, fastq1_file, log_dir):
		make_ornament('please fix BWA map and re-try!', 100, ' ', 0, 0)
		return

	#---samtools sam to bam
	if not run_samtools_sam_to_bam(args, output_dir, name, log_dir):
		make_ornament('please fix SAMtools sam to bam and re-try!', 100, ' ', 0, 0)
		return

	#---samtools sort index
	if not run_samtools_sort_index(args, output_dir, name, log_dir):
		make_ornament('please fix SAMtools sort and re-try!', 100, ' ', 0, 0)
		return

	#---samtools mpileup
	if not run_samtools_mpileup(args, output_dir, name, log_dir):
		make_ornament('please fix SAMtools mpileup and re-try!', 100, ' ', 0, 0)
		return

	# NOTE(review): the package name is taken from the script name here, while
	# run_preset_more hard-codes 'CIpipe' - confirm both resolve to the same package
	varscan = resource_filename(os.path.basename(sys.argv[0]), 'VarScan.v2.3.9.jar')
	#---varscan indel
	run_varscan_indel(args, output_dir, name, varscan, log_dir)
	#---varscan snp
	run_varscan_snp(args, output_dir, name, varscan, log_dir)
	#---varscan consensus call
	run_varscan_consensus_call(args, output_dir, name, varscan, log_dir)
	#---varscan read count
	run_varscan_read_count(args, output_dir, name, varscan, log_dir)

	#---result
	result_dir = make_dir(output_dir + 'result/')
	run_samtools_flagstat(args, output_dir, name, log_dir)
	run_map_infor(args, output_dir, name, result_dir)
	run_indel_brief_result(args, output_dir, name, result_dir)
	run_indel_near_target(args, output_dir, name, result_dir)
	plot_indel_detail(args, output_dir, name, result_dir, log_dir)

	#---finish information
	# report the resolved name: args.name is still 'NoName' when -N was omitted
	make_ornament('CONGRATS! \'' + name + '\' is finished!', 100, ' ', 1, 0)
	write_content(done_one_file, sys.argv[:])

#---run preset one mode--
def run_preset_one(args):
	"""Validate the inputs of mode one and settle the sample name.

	Checks the reference extension and that the data directory holds one or
	two fastq files and nothing else. Returns False after printing the reason
	when a check fails; otherwise returns the sample name, which is args.name
	or, when that is still 'NoName', the last component of args.output.
	"""
	if not args.reference.endswith(('.fa', '.fasta')):
		make_ornament('ABORT! -R file. should be fa(sta) format!', 100, ' ', 1, 0)
		return False

	if not os.path.isdir(args.data):
		make_ornament('ABORT! -D data. should be a directory!', 100, ' ', 1, 0)
		return False

	entries = os.listdir(args.data)
	fastq_count = sum(1 for entry in entries if entry.endswith(('fq', 'fastq')))

	if fastq_count > 2:
		make_ornament('ABORT! -D data. no more than 2 fastq-ONLY files!', 100, ' ', 1, 0)
		return False
	if fastq_count == 0:
		make_ornament('ABORT! -D data. no fastq file in the directory!', 100, ' ', 1, 0)
		return False
	if fastq_count < len(entries):
		make_ornament('ABORT! -D data. remove fastq-NOT files!', 100, ' ', 1, 0)
		return False

	if args.name != 'NoName':
		return args.name

	# no explicit name: fall back to the last non-empty component of -O
	sample_name = [part for part in args.output.split('/') if part][-1]
	make_ornament('WARNING! directory name will be assigned as output name.', 100, ' ', 1, 0)
	return sample_name

#---run mode more---
def run_mode_more(args):
	"""Run mode 'more': write the example table, or process every listed sample.

	With args.example set, the packaged example.input.tab is copied into the
	current directory. Otherwise the input table is validated by
	run_preset_more, every sample is run through run_mode_one in a process
	pool, the per-reference indel matrices are written and a 'done' marker is
	left in the batch log directory. Returns None.
	"""
	if args.example:
		import shutil
		# NOTE(review): package name comes from the script name here but is
		# hard-coded as 'CIpipe' in run_preset_more - confirm they agree
		example_input_file = resource_filename(os.path.basename(sys.argv[0]), 'example.input.tab')
		try:
			# copying in-process avoids a shell string that breaks on paths with
			# spaces, and lets a failed copy be reported instead of ignored
			shutil.copy(example_input_file, '.')
		except EnvironmentError as error:
			make_ornament('ABORT! example.input.tab could not be created: %s' % error, 100, ' ', 1, 0)
			return
		make_ornament('example.input.tab created in current dir, please modify it!', 100, ' ', 0, 0)
		return

	preset_more = run_preset_more(args)
	if not preset_more:
		make_ornament('please fix the problems above and re-try!', 100, ' ', 0, 0)
		return

	log_dir, thread_number, args_more = preset_more
	make_ornament('\'CIpipe more\' is running, more details are in batch/sample/log/!', 100, ' ', 1, 0)
	pool = Pool(thread_number)
	try:
		pool.map(run_mode_one, args_more)
	finally:
		# release the worker processes even when a sample raised
		pool.close()
		pool.join()

	run_indel_matrix(args_more)

	make_ornament('CONGRATS! \'CIpipe more\' is finished!', 100, ' ', 1, 0)
	write_content(log_dir + 'done', sys.argv[:])

#---run preset more mode--
def run_preset_more(args):
	"""Validate the mode 'more' input table and prepare the per-sample arguments.

	The table is tab separated: the first column is the parameter name, the
	remaining columns are its values. Its parameter names must match, in order,
	those of the packaged example.input.tab.

	Returns False after printing the reason when validation fails, otherwise a
	tuple (log_dir, thread_number, args_more) where args_more holds one
	argument object per sample name.
	"""
	# args.input is None when -I was not supplied; isfile(None) would raise
	if not args.input or not os.path.isfile(args.input):
		make_ornament('ABORT! -I input. should be input file!', 100, ' ', 1, 0)
		return False

	with open(resource_filename('CIpipe', 'example.input.tab'), 'rU') as default_handle:
		input_table_default = default_handle.readlines()
	with open(args.input, 'rU') as input_handle:
		input_table = input_handle.readlines()

	input_key_default = [row.split('\t')[0].lstrip() for row in input_table_default]
	input_key = [row.split('\t')[0].lstrip() for row in input_table]
	if input_key != input_key_default:
		make_ornament('ABORT! input.tab parameter names are not default!', 100, ' ', 1, 0)
		return False

	input_value = [row.rstrip().split('\t')[1:] for row in input_table]
	input_dict = dict(zip(input_key, input_value))

	# <output><batch>/ holds every sample directory plus the batch log directory
	output_dir = make_dir(input_dict['output'][0] + input_dict['batch'][0] + '/')
	log_dir = make_dir(output_dir + input_dict['batch'][0] + '.log/')
	thread_number = int(input_dict['thread'][0])
	group_order = get_group_order(input_dict['group'])
	args_more = [get_args_one(input_dict, name, group_order) for name in input_dict['name']]
	return (log_dir, thread_number, args_more)

#args_dict = {'input':'/data/tongji1/liyx/CSIA/Test/Input/GB.Input.tab'}
#args=get_class_from_dict(**args_dict)

#---get args one from mode more---
def get_args_one(input_dict, name, group_order):
	"""Build the argument object of one sample from the mode 'more' table.

	For every parameter the value is picked by the length of its value list:
	one value per sample -> the value at the sample's position; one value per
	group -> the value of the first group listing the sample's 1-based rank;
	anything else -> the first value, with 'ON' / 'OFF' turned into
	True / False. 'output' gets the batch directory appended and 'rank' is set
	to the sample's 1-based position.
	"""
	names = input_dict['name']
	sample_index = names.index(name)
	# position of the first group containing this sample's rank
	# (ValueError when no group lists it)
	name_group = [sample_index + 1 in ranks for ranks in group_order].index(True)

	args_one_dict = {}
	for key, values in input_dict.items():
		if len(values) == len(names):
			chosen = values[sample_index]
		elif len(values) == len(input_dict['group']):
			chosen = values[name_group]
		elif values[0] in ('ON', 'OFF'):
			chosen = values[0] != 'OFF'
		else:
			chosen = values[0]
		args_one_dict[key] = chosen

	args_one_dict['output'] = args_one_dict['output'] + input_dict['batch'][0] + '/'
	args_one_dict['rank'] = str(sample_index + 1)
	return get_class_from_dict(**args_one_dict)

#--common--
class get_class_from_dict:
	"""Plain attribute bag: every keyword argument becomes an instance attribute."""

	def __init__(self, **entries):
		vars(self).update(entries)

def make_dir(dir):
	"""Create the directory (with parents) if needed and return the cleaned path.

	Surrounding whitespace and trailing backslashes are stripped from the path
	before use; the cleaned path is what gets returned.
	"""
	dir = dir.strip().rstrip("\\")
	try:
		os.makedirs(dir)
	except OSError:
		# an existing path is fine - another pool worker may have created it
		# between a check and the call; anything else is a real error
		if not os.path.exists(dir):
			raise
	return dir

def write_content(content_file, content):
	"""Overwrite content_file with content (a string or an iterable of strings).

	Nothing is added between the items; callers supply their own newlines.
	"""
	# the with-block closes the handle even when writelines raises
	with open(content_file, 'w') as output:
		output.writelines(content)

def run_bash_command(log_dir, command_name, command):
	"""Save command as <log_dir><command_name>.sh and run it with bash.

	stdout and stderr of the script go to <log_dir><command_name>.log. The exit
	status is not inspected; callers check for the expected output files.
	"""
	import subprocess

	script_base = make_dir(log_dir) + command_name
	command_file = script_base + '.sh'
	# build the log name from the base; replacing '.sh' in the full path would
	# also rewrite any '.sh' inside the directory or command name
	log_file = script_base + '.log'
	write_content(command_file, command)

	with open(log_file, 'w') as log_handle:
		try:
			# argument list instead of a shell string: paths with spaces are safe
			subprocess.call(['bash', command_file], stdout=log_handle, stderr=subprocess.STDOUT)
		except OSError as error:
			# bash could not be started; leave the reason in the log and carry on
			log_handle.write('could not run bash: %s\n' % error)

def make_ornament(title, width=100, ornament_type=' ', show_time=1, show_date=0):
	"""Print one banner row: tab, '|', title, padding, optional time stamp, '|'.

	title          text at the start of the row
	width          nominal row width used to size the padding
	ornament_type  character repeated as padding
	show_time      1 appends ' @ <time>' before the closing '|'
	show_date      with show_time == 1, anything but 0 also puts the date in the stamp
	"""
	if show_time == 1:
		if show_date == 0:
			stamp_format, reserved = "%X", 11
		else:
			stamp_format, reserved = "%Y-%m-%d %X", 22
		tail = ' @ ' + time.strftime(stamp_format, time.localtime())
	else:
		tail, reserved = '', 0

	# a negative repeat count simply yields no padding
	padding = ornament_type * (width - 2 - len(title) - reserved)
	ornament = '\t|' + title + padding + tail + '|'
	# single-argument call form prints the same on Python 2 and Python 3
	print(ornament)

def get_process_time(function_name, is_finish=0, width=100, indent=16):
	"""Print a time-stamped progress row for a pipeline step.

	The text before the first ':' of function_name is right-aligned to the
	indent column; the row ends with '-running' (is_finish == 0) or '-done'.
	"""
	prefix_length = len(function_name.split(':')[0])
	label = ' ' * (indent - prefix_length) + function_name
	status = '  -running' if is_finish == 0 else ' -done'
	make_ornament(label + ' ' * (width - 23 - len(label)) + status, width)

def get_absolute_file(file):
	"""Return the absolute path of an existing file.

	Relative paths are resolved against the current working directory. When
	the path is not an existing regular file the string
	'WRONG file or directory!' is returned instead.
	"""
	# os.path.abspath replaces the former guess based on shared path
	# components, which mangled absolute paths lying outside the cwd
	absolute_file = os.path.abspath(file)
	if os.path.isfile(absolute_file):
		return absolute_file
	return 'WRONG file or directory!'

def get_file_size(file):
	"""Return the size of file as '<number> <unit>' (B, KB, MB, GB, TB, PB).

	The size is divided by 1024 (rounded to one decimal) for as long as its
	text form has five or more characters; the value reported is the last one
	seen before the final division. Sizes whose text form is shorter than five
	characters are reported unchanged in bytes.
	"""
	file_size = os.path.getsize(file)
	unit = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
	unit_order = 0
	# start from the raw size so small files, which never enter the loop,
	# no longer hit an unbound variable
	former_file_size = file_size
	former_unit_order = unit_order
	# the second condition keeps the unit index inside the list
	while len(str(file_size)) >= 5 and unit_order < len(unit):
		former_file_size = file_size
		former_unit_order = unit_order
		file_size = round(file_size / 1024.0, 1)
		unit_order += 1
	return str(former_file_size) + ' ' + unit[former_unit_order]

def add_thousand_separator(int_number):
	"""Return int_number (anything int() accepts) as text with ',' every three digits."""
	whole = int(int_number)
	return '{:,}'.format(whole)

def make_initial_upper(word):
	"""Return word with its first character upper-cased and the rest lower-cased.

	An empty string is returned unchanged.
	"""
	# slicing instead of word[0] keeps the empty string from raising IndexError
	return word[:1].upper() + word[1:].lower()

def get_group_order(group):
	"""Expand group descriptions into lists of sample ranks.

	group is a list of strings such as '1,3-5': comma separated ranks, where
	'a-b' stands for every rank from a to b inclusive. Returns one list of ints
	per entry, e.g. ['1,3-5', '2'] -> [[1, 3, 4, 5], [2]].
	"""
	group_order = []
	for group_one in group:
		ranks = []
		for token in group_one.split(','):
			bounds = token.split('-')
			if len(bounds) == 1:
				ranks.append(int(token))
			else:
				# extend() takes the range directly; 'list + range' only
				# works where range returns a list (Python 2)
				ranks.extend(range(int(bounds[0]), int(bounds[1]) + 1))
		group_order.append(ranks)
	return group_order

#---run bwa index--
def run_bwa_index(args, log_dir):
	"""Build the BWA index of args.reference unless indexing is switched off."""
	if not args.index:
		return

	step_title = 'bwa: index -' + args.rank
	get_process_time(step_title)
	# script / log are named after the reference file without its extension
	refer_name = os.path.basename(os.path.splitext(args.reference)[0])
	run_bash_command(log_dir, 'BWA_Index.' + refer_name, 'bwa index -a bwtsw ' + args.reference)
	get_process_time(step_title, 1)

#---run FastQC quality control--
def run_fastqc_quality_control(args, output_dir, fastq1_file, fastq2_file, log_dir):
	"""Run FastQC on the fastq file(s) into <output_dir>FastQC/ unless switched off.

	Prints a warning when the FastQC directory is still empty afterwards.
	"""
	if not args.fastqc:
		return

	step_title = 'fastqc: quality control -' + args.rank
	get_process_time(step_title)
	quality_control_dir = make_dir(output_dir + 'FastQC/')
	# fastq2_file is '' for single-end data, which just leaves a trailing blank
	fastqc_command = ' '.join(['fastqc -q --extract -o', quality_control_dir, fastq1_file, fastq2_file])
	run_bash_command(log_dir, 'FastQC_QualiyControl', fastqc_command)
	get_process_time(step_title, 1)

	if not os.listdir(quality_control_dir):
		make_ornament('WARNING! no fastqc result! please check FastQC_QualiyControl.log!', 100, ' ', 1, 0)

#---run bwa map--
def run_bwa_map(fastq2_file, output_dir, name, args, fastq1_file, log_dir):
	"""Map the reads with 'bwa mem' into <output_dir>BWA/<name>.sam.

	fastq2_file is '' for single-end data. Returns True when a non-empty SAM
	file exists afterwards, otherwise prints an ABORT message and returns False.
	"""
	is_pair = 'single' if fastq2_file == '' else 'pair'
	get_process_time('bwa: map ' + is_pair + ' -' + args.rank)
	map_file = make_dir(output_dir + 'BWA/') + name + '.sam'
	# \\t stays a literal backslash-t in the script: bwa expands it in the read group
	bwa_map = 'bwa mem -t 10 -R "@RG\\tID:%s.BWA_map.%s\\tLB:bwa\\tPL:NA\\tSM:%s" %s %s %s > %s' % (
		name, is_pair, name, args.reference, fastq1_file, fastq2_file, map_file)
	run_bash_command(log_dir, 'BWA_Map', bwa_map)
	get_process_time('bwa: map -' + args.rank, 1)

	# the '>' redirect creates the file even when bwa fails, so an empty file
	# has to count as "no result" as well
	if not os.path.isfile(map_file) or os.path.getsize(map_file) == 0:
		make_ornament('ABORT! no bwa result! please check BWA_Map.log!', 100, ' ', 1, 0)
		return False
	return True
	
#---samtools: sam to bam--
def run_samtools_sam_to_bam(args, output_dir, name, log_dir):
	"""Convert <output_dir>BWA/<name>.sam into <output_dir>SAMtools/<name>.bam.

	Returns True when the BAM file exists afterwards, otherwise prints an
	ABORT message and returns False.
	"""
	step_title = 'samtools: sam to bam -' + args.rank
	get_process_time(step_title)
	sam_path = output_dir + 'BWA/' + name + '.sam'
	bam_path = make_dir(output_dir + 'SAMtools/') + name + '.bam'
	run_bash_command(log_dir, 'SAMtools_SamToBam', 'samtools view -bhS %s -o %s' % (sam_path, bam_path))
	get_process_time(step_title, 1)

	if os.path.isfile(bam_path):
		return True
	make_ornament('ABORT! no bam result! please check SAMtools_SamToBam.log!', 100, ' ', 1, 0)
	return False

#---samtools: sort & index--
def run_samtools_sort_index(args, output_dir, name, log_dir):
	"""Sort <name>.bam into <name>.sort.bam and index the sorted file.

	Both files live in <output_dir>SAMtools/. Returns True when the sorted BAM
	exists afterwards, otherwise prints an ABORT message and returns False.
	"""
	get_process_time('samtools: sort & index -' + args.rank)
	bam_base = output_dir + 'SAMtools/' + name
	bam_file = bam_base + '.bam'
	# built from the base name, as the mpileup and flagstat steps do; a
	# replace('.bam', ...) on the whole path would also rewrite a '.bam'
	# inside the directory or sample name
	sort_bam_file = bam_base + '.sort.bam'
	samtools_sort = 'samtools sort ' + bam_file + ' -o ' + sort_bam_file
	samtools_index = 'samtools index ' + sort_bam_file
	run_bash_command(log_dir, 'SAMtools_Sort', samtools_sort)
	run_bash_command(log_dir, 'SAMtools_Index', samtools_index)
	get_process_time('samtools: sort & index -' + args.rank, 1)

	if not os.path.isfile(sort_bam_file):
		make_ornament('ABORT! no bam sort result! please check SAMtools_Sort.log!', 100, ' ', 1, 0)
		return False
	return True

#---samtools: mpileup--
def run_samtools_mpileup(args, output_dir, name, log_dir):
	"""Run 'samtools mpileup' on the sorted BAM into <output_dir>SAMtools/<name>.mpu.

	With args.unlimited set, '-d10000000' is passed to mpileup. Returns True
	when the pileup file exists afterwards, otherwise prints an ABORT message
	and returns False.
	"""
	# conditional expressions instead of indexing a list with the flag
	title_suffix = ' (unlimited: True)' if args.unlimited else ''
	depth_option = ' -d10000000' if args.unlimited else ''

	get_process_time('samtools: mpileup' + title_suffix + ' -' + args.rank)
	sample_base = output_dir + 'SAMtools/' + name
	sort_bam_file = sample_base + '.sort.bam'
	# same path the VarScan steps read; a replace() on the whole path would
	# also rewrite a '.sort.bam' inside the directory or sample name
	mpileup_file = sample_base + '.mpu'
	samtools_mpileup = 'samtools mpileup%s -f %s %s > %s' % (depth_option, args.reference, sort_bam_file, mpileup_file)
	run_bash_command(log_dir, 'SAMtools_Mpileup', samtools_mpileup)
	get_process_time('samtools: mpileup -' + args.rank, 1)

	if not os.path.isfile(mpileup_file):
		make_ornament('ABORT! no samtools mpileup result! please check SAMtools_Mpileup.log!', 100, ' ', 1, 0)
		return False
	return True

#---varscan: indel--
def run_varscan_indel(args, output_dir, name, varscan, log_dir):
	"""Call indels with VarScan pileup2indel unless switched off.

	Reads <output_dir>SAMtools/<name>.mpu and writes
	<output_dir>VarScan/<name>.indel.tab; warns when that file is missing.
	"""
	if not args.indel:
		return

	thresholds = (args.basequality, args.varfreq, args.pvalue)
	get_process_time('varscan: indel (base quality:%s,var freq:%s,pvalue:%s)' % thresholds + ' -' + args.rank)
	pileup_path = output_dir + 'SAMtools/' + name + '.mpu'
	result_path = make_dir(output_dir + 'VarScan/') + name + '.indel.tab'
	command = 'java -jar %s pileup2indel %s --min-avg-qual %s --min-var-freq %s --p-value %s > %s' % (
		(varscan, pileup_path) + thresholds + (result_path,))
	run_bash_command(log_dir, 'VarScan_Indel', command)
	get_process_time('varscan: indel -' + args.rank, 1)

	if not os.path.isfile(result_path):
		make_ornament('WARNING! no varscan indel result! please check VarScan_Indel.log!', 100, ' ', 1, 0)

#---varscan: snp--
def run_varscan_snp(args, output_dir, name, varscan, log_dir):
	"""Call SNPs with VarScan pileup2snp when requested.

	Reads <output_dir>SAMtools/<name>.mpu and writes
	<output_dir>VarScan/<name>.snp.tab; warns when that file is missing.
	"""
	if not args.snp:
		return

	thresholds = (args.basequality, args.varfreq, args.pvalue)
	get_process_time('varscan: snp (base quality:%s,var freq:%s,pvalue:%s)' % thresholds + ' -' + args.rank)
	pileup_path = output_dir + 'SAMtools/' + name + '.mpu'
	result_path = make_dir(output_dir + 'VarScan/') + name + '.snp.tab'
	command = 'java -jar %s pileup2snp %s --min-avg-qual %s --min-var-freq %s --p-value %s > %s' % (
		(varscan, pileup_path) + thresholds + (result_path,))
	run_bash_command(log_dir, 'VarScan_Snp', command)
	get_process_time('varscan: snp -' + args.rank, 1)

	if not os.path.isfile(result_path):
		make_ornament('WARNING! no varscan snp result! please check VarScan_Snp.log!', 100, ' ', 1, 0)

#---varscan: consensus call--
def run_varscan_consensus_call(args, output_dir, name, varscan, log_dir):
	"""Make consensus calls with VarScan pileup2cns when requested.

	Reads <output_dir>SAMtools/<name>.mpu and writes
	<output_dir>VarScan/<name>.consensus.tab; warns when that file is missing.
	"""
	if not args.consensus:
		return

	thresholds = (args.basequality, args.pvalue)
	get_process_time('varscan: consensus call (base quality:%s,pvalue:%s)' % thresholds + ' -' + args.rank)
	pileup_path = output_dir + 'SAMtools/' + name + '.mpu'
	result_path = make_dir(output_dir + 'VarScan/') + name + '.consensus.tab'
	command = 'java -jar %s pileup2cns %s --min-avg-qual %s --p-value %s > %s' % (
		(varscan, pileup_path) + thresholds + (result_path,))
	run_bash_command(log_dir, 'VarScan_ConseCall', command)
	get_process_time('varscan: consensus call -' + args.rank, 1)

	if not os.path.isfile(result_path):
		make_ornament('WARNING! no varscan consensus call result! please check VarScan_ConseCall.log!', 100, ' ', 1, 0)

#---varscan: read count--
def run_varscan_read_count(args, output_dir, name, varscan, log_dir):
	"""Collect read counts with VarScan readcounts when requested.

	Reads <output_dir>SAMtools/<name>.mpu and writes
	<output_dir>VarScan/<name>.read.tab; warns when that file is missing.
	"""
	if not args.readcount:
		return

	get_process_time('varscan: read count (base quality:%s)' % (args.basequality) + ' -' + args.rank)
	pileup_path = output_dir + 'SAMtools/' + name + '.mpu'
	result_path = make_dir(output_dir + 'VarScan/') + name + '.read.tab'
	command = 'java -jar %s readcounts %s --min-base-qual %s --output-file %s' % (
		varscan, pileup_path, args.basequality, result_path)
	run_bash_command(log_dir, 'VarScan_ReadCount', command)
	get_process_time('varscan: read count -' + args.rank, 1)

	if not os.path.isfile(result_path):
		make_ornament('WARNING! no varscan read counts result! please check VarScan_ReadCount.log!', 100, ' ', 1, 0)

#---run samtools flagstat--
def run_samtools_flagstat(args, output_dir, name, log_dir):
	"""Write 'samtools flagstat' of the sorted BAM to <output_dir>SAMtools/<name>.flagstat.txt.

	Prints a warning when the flagstat file is missing afterwards.
	"""
	get_process_time('samtools: flagstat -' + args.rank)
	sample_base = '%sSAMtools/%s' % (output_dir, name)
	sort_bam_file = sample_base + '.sort.bam'
	# same path run_map_infor reads; a replace() on the whole path would also
	# rewrite a '.sort.bam' inside the directory or sample name
	flagstat_file = sample_base + '.flagstat.txt'
	samtools_flagstat = 'samtools flagstat %s > %s' % (sort_bam_file, flagstat_file)
	run_bash_command(log_dir, 'SAMtools_FlagStat', samtools_flagstat)
	get_process_time('samtools: flagstat -' + args.rank, 1)

	if not os.path.isfile(flagstat_file):
		make_ornament('WARNING! no samtools flagstat! please check SAMtools_FlagStat.log!', 100, ' ', 1, 0)

#---run map infor
def _find_flagstat_line(flagstat_content, keyword, fallback_index):
	"""Return the first line containing keyword, else the line at fallback_index, else None."""
	for line in flagstat_content:
		if keyword in line:
			return line
	if fallback_index < len(flagstat_content):
		return flagstat_content[fallback_index]
	return None

def run_map_infor(args, output_dir, name, result_dir):
	"""Summarise the flagstat report into <result_dir><name>.map.infor.txt.

	The table has one header row and one sample row: sample name, the count
	from the 'paired in sequencing' line, the count from the 'properly paired'
	line and the ratio printed on that line. A warning is printed and nothing
	is written when the flagstat file is missing or cannot be read that way.
	"""
	flagstat_file = '%sSAMtools/%s.flagstat.txt' % (output_dir, name)

	if not os.path.isfile(flagstat_file):
		make_ornament('WARNING! no samtools flagstat! get map infor stopped!', 100, ' ', 1, 0)
		return

	with open(flagstat_file, 'rU') as handle:
		flagstat_content = handle.readlines()

	# look the rows up by their label; the fixed positions 5 and 8 are only a
	# fallback, since they point at other rows when the report has extra lines
	read_line = _find_flagstat_line(flagstat_content, 'paired in sequencing', 5)
	proper_line = _find_flagstat_line(flagstat_content, 'properly paired', 8)
	if read_line is None or proper_line is None or 'paired (' not in proper_line:
		# e.g. an empty report left behind by a failed flagstat run
		make_ornament('WARNING! unexpected samtools flagstat content! get map infor stopped!', 100, ' ', 1, 0)
		return

	get_process_time('get: map infor -' + args.rank)
	map_infor_file = result_dir + name + '.map.infor.txt'
	read_number = add_thousand_separator(read_line.split(' ')[0])
	proper_number = add_thousand_separator(proper_line.split(' ')[0])
	ratio = proper_line.split('paired (')[1].split(' :')[0]
	map_infor = [
		'sample\tread_number\tproperly_mapped_number\tratio\n',
		name + '\t' + read_number + '\t' + proper_number + '\t' + ratio + '\n',
	]
	write_content(map_infor_file, map_infor)
	get_process_time('get: map infor -' + args.rank, 1)

#---run indel brief result--
def run_indel_brief_result(args, output_dir, name, result_dir):
	"""Write the first seven columns of the VarScan indel table to the result dir.

	Reads <output_dir>VarScan/<name>.indel.tab and writes
	<result_dir><name>.indel.brief.tab, tab separated. Prints a warning and
	does nothing when the indel table is missing.
	"""
	indel_file = output_dir + 'VarScan/' + name + '.indel.tab'

	if not os.path.isfile(indel_file):
		make_ornament('WARNING! no varscan indel result! collect indel brief result stopped!', 100, ' ', 1, 0)
		return

	get_process_time('get: indel brief result -' + args.rank)
	# the with-block closes the handle as soon as the rows are read
	with open(indel_file, 'rU') as handle:
		indel_brief_result = ['\t'.join(row.split()[:7]) + '\n' for row in handle]
	indel_brief_result_file = result_dir + name + '.indel.brief.tab'
	write_content(indel_brief_result_file, indel_brief_result)
	get_process_time('get: indel brief result -' + args.rank, 1)

#---run indel near target--
def run_indel_near_target(args, output_dir, name, result_dir):
	"""Collect the indels lying inside the window around each CRISPR target.

	args.target is 'NoTarget' (nothing is done) or a ';' separated list of
	'chromosome:position' entries. For every entry the rows of
	<output_dir>VarScan/<name>.indel.tab on that chromosome with a position in
	[position - args.upstream, position + args.downstream] are copied, after
	the table's header row, into <result_dir><name>.indel.target.tab.
	"""
	if args.target == 'NoTarget':
		return

	indel_file = output_dir + 'VarScan/' + name + '.indel.tab'
	if not os.path.isfile(indel_file):
		make_ornament('WARNING! no varscan indel result! indel near target search stopped!', 100, ' ', 1, 0)
		return

	get_process_time('get: indel near target (%s, -%snt, +%snt)' % (args.target, args.upstream, args.downstream) + ' -' + args.rank)
	with open(indel_file, 'rU') as handle:
		indel_content = handle.readlines()
	indel_near_target_file = result_dir + name + '.indel.target.tab'

	# slicing keeps an empty indel table from raising IndexError
	indel_near_target = indel_content[:1]
	# split each data row once; the header row is never matched against a target
	indel_rows = [(row, row.split()) for row in indel_content[1:]]

	for target_position in args.target.split(';'):
		target_position = target_position.strip()
		if not target_position:
			# tolerate a trailing or doubled ';'
			continue
		# rsplit: the position is whatever follows the last ':'
		target_chromo, target_cut = target_position.rsplit(':', 1)
		target_upstream = int(target_cut) - int(args.upstream)
		target_downstream = int(target_cut) + int(args.downstream)
		for row, fields in indel_rows:
			# compare the whole chromosome column: a prefix test would let
			# 'chr1' also pick up rows of 'chr10'
			if len(fields) > 1 and fields[0] == target_chromo and target_upstream <= int(fields[1]) <= target_downstream:
				indel_near_target.append(row)

	write_content(indel_near_target_file, indel_near_target)
	get_process_time('get: indel near target -' + args.rank, 1)

#---plot indel detail--
def plot_indel_detail(args, output_dir, name, result_dir, log_dir):
	"""Draw one PDF per CRISPR target showing the indels found around it.

	args.target is 'NoTarget' (nothing is done) or a ';' separated list of
	'chromosome:position' entries. For every entry the matching rows of
	<result_dir><name>.indel.target.tab and the reference sequence of the
	window are handed to plot_indel_detail_one, which writes
	<log_dir><name>.indel.<target>.r and <result_dir><name>.indel.<target>.pdf.
	"""
	if args.target == 'NoTarget':
		return

	indel_near_target_file = result_dir + name + '.indel.target.tab'
	if not os.path.isfile(indel_near_target_file):
		make_ornament('WARNING! no indel near target result! plot indel detail stopped!', 100, ' ', 1, 0)
		return

	get_process_time('plot: indel detail (%s, -%snt, +%snt)' % (args.target, args.upstream, args.downstream) + ' -' + args.rank)
	reference_dict = get_fasta_dict(args.reference)
	with open(indel_near_target_file, 'rU') as handle:
		indel_near_target = handle.readlines()
	# split each row once for the comparisons below
	indel_rows = [(row, row.split()) for row in indel_near_target]

	for target_position in args.target.split(';'):
		target_position = target_position.strip()
		if not target_position:
			# tolerate a trailing or doubled ';'
			continue
		# rsplit: the position is whatever follows the last ':'
		target_chromo, target_cut = target_position.rsplit(':', 1)
		if target_chromo not in reference_dict:
			make_ornament('WARNING! \'%s\' is not in the reference! plot skipped!' % (target_chromo), 100, ' ', 1, 0)
			continue

		# a negative start would make the slice below count from the sequence end
		target_upstream = max(int(target_cut) - int(args.upstream), 0)
		target_downstream = int(target_cut) + int(args.downstream)
		# whole-column comparison ('chr1' must not match 'chr10'); the isdigit
		# test drops the header row
		indel_near_target_one = [
			row for row, fields in indel_rows
			if len(fields) > 1 and fields[0] == target_chromo and fields[1].isdigit()
			and target_upstream <= int(fields[1]) <= target_downstream
		]
		reference_sequence = reference_dict[target_chromo][target_upstream:target_downstream].upper()
		R_file = log_dir + name + '.indel.' + target_position + '.r'
		pdf_file = result_dir + name + '.indel.' + target_position + '.pdf'
		plot_indel_detail_one(indel_near_target_one, reference_sequence, target_cut, target_upstream, target_downstream, pdf_file, target_chromo, R_file)

	get_process_time('plot: indel detail -' + args.rank, 1)

#---plot indel detail one--
def plot_indel_detail_one(indel_near_target_one, reference_sequence, target_cut, target_upstream, target_downstream, pdf_file, target_chromo, R_file):
	"""Generate an R script that draws the indels of one target and run it.

	indel_near_target_one  whitespace separated indel rows; column 2 is read as
	                       the position, column 4 as the call ('*/-...' marks a
	                       deletion) and column 7 is printed in red beside the row
	reference_sequence     reference bases of the plotted window
	target_cut, target_upstream, target_downstream
	                       cut position and window bounds, written into the script
	pdf_file               PDF path handed to R's pdf()
	target_chromo          text used as the plot title
	R_file                 path the script is written to before 'Rscript' runs it

	The script text is collected piece by piece in the list R, written with
	write_content and executed through os.system.
	"""
	# distinct indel positions, ascending
	position_type = list(set([int(x.split()[1]) for x in indel_near_target_one]))
	position_type.sort()
	R = []
	#--R variables describing the window and the number of rows to draw
	R.append('reference_sequence = \'%s\'\n' % (reference_sequence))
	R.append('reference_length = %s\n' % (str(len(reference_sequence))))
	R.append('target_cut = %s\n' % (str(target_cut)))
	R.append('target_upstream = %s\n' % (str(target_upstream)))
	R.append('target_downstream = %s\n' % (str(target_downstream)))
	R.append('position_type = %s\n' % (str(len(position_type))))
	R.append('position_number = %s\n' % (str(len(indel_near_target_one))))
	R.append('insertion_number = %s\n' % (str(len([x for x in indel_near_target_one if x.split()[3].startswith('*/+')]))))	
	R.append('deletion_number = %s\n' % (str(len([x for x in indel_near_target_one if x.split()[3].startswith('*/-')]))))	
	R.append('total_height=position_number*2 + position_type*2 + insertion_number\n')
	R.append('character_cex=8\ncharacter_cex_small=7\nadd_height=0.5\nadd_width=0.45\n')
	#--R helper text_sequence(): draws a string left or right of a coordinate,
	#  either as a whole (separate = 0) or one character per x unit (separate = 1)
	R.append('''text_sequence <- function(direction, separate=1, sequence, cord_x, cord_y, cex, color='black'){
		if (direction == 'left'){
			if (separate == 1){
				for (i in nchar(sequence):1){
					text(cord_x - nchar(sequence) + i, cord_y, substr(sequence, i, i), pos = 2, cex = cex, col = color)
				}
			}else{
				text(cord_x, cord_y, sequence, pos = 2, cex = cex, col = color)
			}
		}else{
			if (separate == 1){
				for (i in 1:nchar(sequence)){
					text(cord_x + i - 1, cord_y, substr(sequence, i, i), pos = 4, cex = cex, col = color)
				}
			}else{
				text(cord_x, cord_y, sequence, pos = 4, cex = cex, col = color)
			}
		}
	}\n''')
	#--page set-up, title, position ruler with red tick marks, reference row
	R.append('pdf(\'%s\', width = reference_length + 31, height = total_height + 8)\n' % (pdf_file))
	R.append('par(family=\'mono\', cex=1, mar=c(0, 0, 0, 0), oma=c(0, 0, 0, 0))\n')
	R.append('plot(0, 0, xlim = c(-11, reference_length + 20), ylim = c(0, total_height + 8), type = \'n\', xlab = \'\', ylab = \'\', axes = FALSE)\n')
	R.append('text((reference_length + 31)/2 - 11, total_height  + 6, \'%s\', pos = 3, font = 2, cex = character_cex + 1)\n' % (target_chromo))
	R.append('text_sequence(\'left\', 0, \'position:\', 0, total_height  + 3, character_cex)\n')
	R.append('text_sequence(\'right\', 0, target_upstream + 1, 1, total_height + 3, character_cex_small)\n')
	R.append('text_sequence(\'right\', 0, target_cut, target_cut - target_upstream, total_height + 3, character_cex_small)\n')
	R.append('text_sequence(\'right\', 0, target_downstream, target_downstream - target_upstream, total_height + 3, character_cex_small)\n')
	R.append('text_sequence(\'right\', 0, \'|\', 1, total_height + 2, character_cex_small - 2.5, \'red\')\n')
	R.append('text_sequence(\'right\', 0, \'|\', target_cut - target_upstream, total_height + 2, character_cex_small - 2.5, \'red\')\n')
	R.append('text_sequence(\'right\', 0, \'|\', target_downstream - target_upstream, total_height + 2, character_cex_small - 2.5, \'red\')\n')
	R.append('text_sequence(\'left\', 0, \'reference:\', 0, total_height + 1, character_cex)\n')
	R.append('text_sequence(\'right\', 1, reference_sequence, 1, total_height + 1, character_cex)\n')
	#--one row per indel; indel_line_order is the row's offset below total_height
	indel_line_order = 0
	for type in position_type:
		# two extra lines separate the rows of different positions
		indel_line_order += 2
		indel_position_type = [x for x in indel_near_target_one if x.split()[1] == str(type)]
		for indel in indel_position_type:
			indel_split = indel.split()
			if indel_split[3].startswith('*/-'):			
				# deletion: position label, column 7 in red, the word 'deletion',
				# the reference with the deleted bases blanked and red dashes over the gap
				R.append('text_sequence(\'left\', 0, \'%s:\', 0, total_height - %s, character_cex)\n' % (str(type), str(indel_line_order)))
				R.append('text_sequence(\'left\', 0, \'%s\', reference_length + 7, total_height - %s, character_cex, \'red\')\n' % (indel_split[6], str(indel_line_order)))
				R.append('text_sequence(\'right\', 0, \'deletion\', reference_length + 8, total_height - %s, character_cex)\n' % (str(indel_line_order)))
				# the call is '*/-' followed by the deleted bases
				deletion_length = len(indel_split[3]) - 3
				indel_sequence = reference_sequence[:(type - target_upstream)] + ' '*deletion_length + reference_sequence[(type - target_upstream + deletion_length):]
				R.append('text_sequence(\'right\', 1, \'%s\', 1, total_height - %s, character_cex)\n' % (indel_sequence, str(indel_line_order)))
				indel_sequence_delete = ' '*(type - target_upstream) + '-'*deletion_length + ' '*(target_downstream - type - deletion_length)
				R.append('text_sequence(\'right\', 1, \'%s\', 1, total_height - %s, character_cex, \'red\')\n' % (indel_sequence_delete, str(indel_line_order)))
				indel_line_order += 2
			else:
				# every other call is drawn as an insertion: the full reference,
				# the inserted bases one line lower in red and a red '^' at the site
				indel_sequence = reference_sequence
				R.append('text_sequence(\'left\', 0, \'%s:\', 0, total_height - %s, character_cex)\n' % (str(type), str(indel_line_order)))
				R.append('text_sequence(\'left\', 0, \'%s\', reference_length + 7, total_height - %s, character_cex, \'red\')\n' % (indel_split[6], str(indel_line_order)))
				R.append('text_sequence(\'right\', 0, \'insertion\', reference_length + 8, total_height - %s, character_cex)\n' % (str(indel_line_order)))
				R.append('text_sequence(\'right\', 1, \'%s\', 1, total_height - %s, character_cex)\n' % (indel_sequence, str(indel_line_order)))
				# the call is '*/+' followed by the inserted bases
				indel_sequence_insertion = indel_split[3][3:]
				R.append('text_sequence(\'right\', 1, \'%s\', %s - target_upstream + 0.5, total_height - %s - 1, character_cex, \'red\')\n' % (indel_sequence_insertion, str(type), str(indel_line_order)))
				R.append('text_sequence(\'right\', 1, \'^\', %s - target_upstream + add_width, total_height - %s - add_height, character_cex + 0.5, \'red\')\n' % (str(type), str(indel_line_order)))
				# an insertion needs one more line than a deletion
				indel_line_order += 3
	R.append('garbage = dev.off()\n')
	write_content(R_file, R)
	# NOTE(review): the script path is passed to the shell unquoted - paths
	# containing spaces would presumably break here; confirm before relying on it
	os.system('Rscript ' + R_file)

#write_content('/Users/liyx/Downloads/test.r', R)

#---get fasta dict--
def get_fasta_dict(fasta_file):
	"""Read a FASTA file into a dict mapping header text to sequence.

	The key is the whole header line without the leading '>'; the value is the
	record's sequence lines joined together. Blank lines are skipped, as are
	sequence lines that come before the first header. When a header occurs
	more than once, the last record wins.
	"""
	# read().splitlines() copes with \n, \r\n and \r line ends without relying
	# on the 'U' open mode
	with open(fasta_file, 'r') as handle:
		lines = handle.read().splitlines()

	fasta_name = []
	fasta_chunks = []
	for line in lines:
		line = line.rstrip()
		if not line:
			# a blank line used to raise IndexError on line[0]
			continue
		if line[0] == '>':
			fasta_name.append(line[1:])
			fasta_chunks.append([])
		elif fasta_chunks:
			fasta_chunks[-1].append(line)

	# joining once per record avoids rebuilding the growing string line by line
	fasta_sequence = [''.join(chunks) for chunks in fasta_chunks]
	return dict(zip(fasta_name, fasta_sequence))

#---run indel matrix--
def run_indel_matrix(args_more):
	"""Merge the VarScan indel tables of all samples into one matrix per reference.

	args_more is the list of per-sample argument objects of mode 'more'. For
	every reference (file name without extension) a table
	<output><batch>.result/<batch>.indel.<reference>.mat is written: one row
	per indel (Chrom, Position, Ref, Cons) and, for every sample that has an
	indel table, its Reads1 / Reads2 / VarFreq columns, filled with 'na' where
	the sample lacks that indel. Samples without an indel table are reported
	and left out.
	"""
	get_process_time('run: indel matrix')
	if not args_more:
		# nothing to merge
		get_process_time('run: indel matrix', 1)
		return

	batch = args_more[0].batch
	result_dir = make_dir(args_more[0].output + batch + '.result/')

	# every reference once, in first-seen order; the former per-sample list
	# rebuilt and rewrote the same matrix once for each sample
	refer_name_all = []
	for args_one in args_more:
		refer_name = os.path.basename(os.path.splitext(args_one.reference)[0])
		if refer_name not in refer_name_all:
			refer_name_all.append(refer_name)

	for refer_name in refer_name_all:
		indel_matrix_file = result_dir + batch + '.indel.' + refer_name + '.mat'
		indel_matrix_head = 'Chrom\tPosition\tRef\tCons'
		indel_matrix_dict = {}
		# number of sample column groups already present in the matrix
		n = 0
		for args_one in args_more:
			if os.path.basename(os.path.splitext(args_one.reference)[0]) != refer_name:
				continue
			indel_file = args_one.output + args_one.name + '/VarScan/' + args_one.name + '.indel.tab'
			if not os.path.isfile(indel_file):
				# the sample adds no columns, so it must not be counted in n either
				make_ornament('WARNING! no %s varscan indel result!' % (args_one.name), 100, ' ', 1, 0)
				continue

			indel_matrix_head = indel_matrix_head + '\t%s_Reads1\t%s_Reads2\t%s_VarFreq' % (args_one.name, args_one.name, args_one.name)
			with open(indel_file, 'rU') as handle:
				indel = handle.readlines()
			# key: first four columns, value: the three columns after them
			indel_dict = {}
			for row in indel[1:]:
				fields = row.split()
				indel_dict['\t'.join(fields[0:4])] = '\t'.join(fields[4:7])

			# indels known from earlier samples but absent here get a 'na' group
			for key in indel_matrix_dict:
				if key not in indel_dict:
					indel_matrix_dict[key].append('na\tna\tna')
			for key, value in indel_dict.items():
				if key in indel_matrix_dict:
					indel_matrix_dict[key].append(value)
				else:
					# new indel: pad the columns of the earlier samples
					indel_matrix_dict[key] = ['na\tna\tna']*n + [value]
			n += 1

		indel_matrix = [indel_matrix_head + '\n'] + [key + '\t' + '\t'.join(indel_matrix_dict[key]) + '\n' for key in indel_matrix_dict]
		write_content(indel_matrix_file, indel_matrix)
	get_process_time('run: indel matrix', 1)
#	return indel_matrix

##----PROCESS------##
if __name__ == '__main__':
    # script entry point: Ctrl-C ends the run with a short note instead of a traceback
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupted me! ;-) Bye!\n")
        raise SystemExit(0)

##----TEST--------##
