#!python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import os
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'

import sys
import numpy
import torch
import pyfaidx
import argparse

from bpnetlite import BPNet
from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci

from bpnetlite.attributions import calculate_attributions
from bpnetlite.marginalize import marginalization_report

from bpnetlite.negatives import extract_matching_loci
from bpnetlite.negatives import calculate_gc_genomewide

import pandas
import pyBigWig
import json

torch.backends.cudnn.benchmark = True


desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

_help = """Must be either 'negatives', 'fit', 'predict', 'interpret',
	'marginalize', or 'pipeline'."""


# Read in the arguments
parser = argparse.ArgumentParser(description=desc)

# A sub-command is mandatory; its name is stored on the namespace as `cmd`
# and is what the command dispatch further down the script switches on.
subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd')

# `negatives` is the only command configured through individual flags; every
# other command takes a single JSON parameter file.
negatives_parser = subparsers.add_parser("negatives", 
	help="Sample GC-matched negatives.")
negatives_parser.add_argument("-i", "--peaks", required=True, 
	help="Peak bed file.")
negatives_parser.add_argument("-f", "--fasta", help="Genome FASTA file.")
negatives_parser.add_argument("-b", "--bigwig", required=True, 
	help="GC content bigwig.")
negatives_parser.add_argument("-o", "--output", required=True, 
	help="Output bed file.")
negatives_parser.add_argument("-l", "--bin_width", type=float, default=0.02, 
	help="GC bin width to match.")
negatives_parser.add_argument("-w", "--width", type=int, default=2114, 
	help="Width for calculating GC content.")
# With `store_true` the default has to be False, otherwise the flag is a no-op
# and verbosity could never be switched off.
negatives_parser.add_argument("-v", "--verbose", default=False, 
	action='store_true')

fit_parser = subparsers.add_parser("fit", help="Fit a BPNet model.")
fit_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

predict_parser = subparsers.add_parser("predict", 
	help="Make predictions using a trained BPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

interpret_parser = subparsers.add_parser("interpret", 
	help="Make interpretations using a trained BPNet model.")
interpret_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

marginalize_parser = subparsers.add_parser("marginalize", 
	help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

pipeline_parser = subparsers.add_parser("pipeline", 
	help="Run each step on the given files.")
pipeline_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters used for each step.")


###
# Default Parameters
###

# Chromosome names shared by the defaults below. Each dictionary gets its own
# list copy so that modifying one set of defaults never leaks into another.
_DEFAULT_CHROMS = (
	'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr11', 'chr12',
	'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20',
	'chr21', 'chr22', 'chrX',
)

# Defaults for `bpnet fit`. When merged with a user JSON file, an entry left
# as None here (other than 'controls') has to appear in that file.
default_fit_parameters = dict(
	n_filters=64,
	n_layers=8,
	profile_output_bias=True,
	count_output_bias=True,
	name=None,
	batch_size=64,
	in_window=2114,
	out_window=1000,
	max_jitter=128,
	reverse_complement=True,
	max_epochs=50,
	validation_iter=100,
	lr=0.001,
	alpha=1,
	verbose=False,
	min_counts=0,
	max_counts=99999999,
	training_chroms=list(_DEFAULT_CHROMS),
	validation_chroms=['chr8', 'chr10'],
	sequences=None,
	loci=None,
	signals=None,
	controls=None,
	random_state=None,
)

# Defaults for `bpnet predict`.
default_predict_parameters = dict(
	batch_size=64,
	in_window=2114,
	out_window=1000,
	verbose=False,
	chroms=list(_DEFAULT_CHROMS),
	sequences=None,
	loci=None,
	controls=None,
	model=None,
	profile_filename='y_profile.npz',
	counts_filename='y_counts.npz',
)

# Defaults for `bpnet interpret`.
default_interpret_parameters = dict(
	batch_size=1,
	in_window=2114,
	out_window=1000,
	verbose=False,
	chroms=list(_DEFAULT_CHROMS),
	sequences=None,
	loci=None,
	model=None,
	output='profile',
	ohe_filename='ohe.npz',
	attr_filename='attr.npz',
	n_shuffles=20,
	random_state=0,
)

# Defaults for `bpnet marginalize`.
default_marginalize_parameters = dict(
	batch_size=64,
	in_window=2114,
	out_window=1000,
	verbose=False,
	chroms=list(_DEFAULT_CHROMS),
	sequences=None,
	motifs=None,
	loci=None,
	n_loci=None,
	shuffle=False,
	model=None,
	output_filename='marginalize/',
	random_state=0,
	minimal=True,
)

# Defaults for `bpnet pipeline`. Top-level entries are shared by every step;
# each `*_parameters` section holds values for a single step. Inside those
# sections the pipeline code skips entries that are None, so a None there
# does not override the shared value or the per-command default.
default_pipeline_parameters = {
	# Model architecture parameters
	'n_filters': 64,
	'n_layers': 8,
	'profile_output_bias': True,
	'count_output_bias': True,
	'in_window': 2114,
	'out_window': 1000,
	'name': None,
	'model': None,
	'verbose': False,

	# Data and training parameters
	# ('verbose' used to be listed a second time here; a repeated key in a
	# dict literal silently overwrites the first occurrence.)
	'batch_size': 64,
	'max_jitter': 128,
	'reverse_complement': True,
	'max_epochs': 50,
	'validation_iter': 100,
	'lr': 0.001,
	'alpha': 1,
	'min_counts': 0,
	'max_counts': 99999999,

	'sequences': None,
	'loci': None,
	'signals': None,
	'controls': None,

	# Fit parameters
	'fit_parameters': {
		'batch_size': 64,
		'training_chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 
			'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 
			'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'],
		'validation_chroms': ['chr8', 'chr10'],
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'verbose': None,
		'random_state': None,
	},

	# Predict parameters
	'predict_parameters': {
		'batch_size': 64,
		'chroms': ['chr8', 'chr10'],
		'profile_filename': None,
		'counts_filename': None,
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'verbose': None,
	},

	# Interpret parameters
	'interpret_parameters': {
		'batch_size': 1,
		'chroms': ['chr8', 'chr10'],
		'output': 'profile',
		'loci': None,
		'ohe_filename': None,
		'attr_filename': None,
		'n_shuffles': None,
		'random_state': None,
		'verbose': None,
	},

	# Modisco parameters
	'modisco_motifs_parameters': {
		'n_seqlets': 100000,
		'output_filename': None,
		'verbose': None,
	},

	# Modisco report parameters
	'modisco_report_parameters': {
		'motifs': None,
		'output_folder': None,
		'verbose': None,
	},

	# Marginalization parameters
	'marginalize_parameters': {
		'loci': None,
		'n_loci': 100,
		'shuffle': False,
		'random_state': None,
		'output_folder': None,
		'motifs': None,
		'minimal': True,
		'verbose': None,
	},
}


###
# Commands
###


def merge_parameters(parameters, default_parameters):
	"""Merge the provided parameters with the default parameters.

	Values found in the JSON file always win. A default is only used for a
	key that is absent from the file, and only top-level keys are compared:
	a nested dictionary supplied by the user replaces the default one whole.

	
	Parameters
	----------
	parameters: str
		Name of the JSON file with the provided parameters.

	default_parameters: dict
		The default parameters for the operation. A default of None marks the
		parameter as required, except for 'controls', which is optional.


	Returns
	-------
	params: dict
		The merged set of parameters. Default values are deep-copied, so
		modifying the result never alters `default_parameters`.


	Raises
	------
	ValueError
		If a required parameter (default None, not 'controls') is missing
		from the JSON file.
	"""

	# Local import keeps this function self-contained.
	import copy

	with open(parameters, "r") as infile:
		parameters = json.load(infile)

	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter != "controls":
				raise ValueError("Must provide value for '{}'".format(parameter))

			# Copy rather than alias: defaults contain lists and nested dicts
			# that callers go on to modify in place.
			parameters[parameter] = copy.deepcopy(value)

	return parameters


# Parse the command line; `args.cmd` holds the chosen sub-command
args = parser.parse_args()


# `negatives`: sample background regions whose GC content matches the peaks
if args.cmd == 'negatives':
	genome_given = args.fasta is not None

	if genome_given:
		# Autosomes chr1..chr22 followed by chrX.
		# NOTE(review): presumably this step writes the GC track to
		# `args.bigwig`, which is then read below -- confirm in
		# bpnetlite.negatives.
		gc_chroms = ['chr' + str(number) for number in range(1, 23)]
		gc_chroms.append('chrX')

		calculate_gc_genomewide(args.fasta, args.bigwig, args.width, 
			gc_chroms, args.verbose)

	# Seed numpy's global generator with a fixed value before sampling
	numpy.random.seed(0)

	# Find GC-matched regions and write them as a headerless, tab-separated
	# file without the index column
	extract_matching_loci(args.peaks, args.bigwig, args.width, 
		args.bin_width, args.verbose).to_csv(args.output, sep='\t', 
		header=False, index=False)


# Fit a BPNet model to data
if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	###

	# Training examples come from a generator restricted to the training
	# chromosomes; jittering and reverse-complementing are configured here.
	training_data = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# Validation examples are extracted once, without jitter, from the
	# validation chromosomes.
	valid_data = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# The extraction result has a third element only when controls were given.
	if parameters['controls'] is not None:
		valid_sequences, valid_signals, valid_controls = valid_data
		# One control track per control file, mirroring how `n_outputs` is
		# derived from the signal files below. This was hard-coded to 2.
		# NOTE(review): assumes 'controls' is a list of files with one track
		# each -- confirm against bpnetlite.io.extract_loci.
		n_control_tracks = len(parameters['controls'])
	else:
		valid_sequences, valid_signals = valid_data
		valid_controls = None
		n_control_tracks = 0

	# Half of the difference between the input and output window lengths
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'],
		n_outputs=len(parameters['signals']),
		n_control_tracks=n_control_tracks,
		profile_output_bias=parameters['profile_output_bias'],
		count_output_bias=parameters['count_output_bias'],
		alpha=parameters['alpha'],
		trimming=trimming,
		name=parameters['name'],
		verbose=parameters['verbose']).cuda()

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	if parameters['verbose']:
		print("Training Set Size: ", training_data.dataset.sequences.shape[0])
		print("Validation Set Size: ", valid_sequences.shape[0])

	model.fit(training_data, optimizer, X_valid=valid_sequences, 
		X_ctl_valid=valid_controls, y_valid=valid_signals, 
		max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])


elif args.cmd == 'predict':
	parameters = merge_parameters(args.parameters, default_predict_parameters)

	###

	model = torch.load(parameters['model']).cuda()

	examples = extract_loci(
		sequences=parameters['sequences'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if parameters['controls'] == None:
		X = examples
		if model.n_control_tracks > 0:
			X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
		else:
			X_ctl = None
	else:
		X, X_ctl = examples

	y_profiles, y_counts = model.predict(X, X_ctl=X_ctl, 
		batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['counts_filename'], y_counts)


elif args.cmd == 'interpret':
	parameters = merge_parameters(args.parameters, default_interpret_parameters)

	###

	model = torch.load(parameters['model']).cuda()

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if model.n_control_tracks > 0:
		X_ctl = (torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1]),)
	else:
		X_ctl = None

	X_attr = calculate_attributions(model, X, args=X_ctl,
		model_output=parameters['output'], hypothetical=True, 
		n_shuffles=parameters['n_shuffles'],
		batch_size=parameters['batch_size'],
		random_state=parameters['random_state'],
		verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['attr_filename'], X_attr)


# Marginalize motifs
elif args.cmd == 'marginalize':
	parameters = merge_parameters(args.parameters, 
		default_marginalize_parameters)

	###

	model = torch.load(parameters['model']).cuda()

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		n_loci=parameters['n_loci'],
		verbose=parameters['verbose']
	)

	if parameters['shuffle'] == True:
		idxs = numpy.arange(X.shape[0])
		numpy.random.shuffle(idxs)
		X = X[idxs]

	if parameters['n_loci'] is not None:
		X = X[:parameters['n_loci']]

	marginalization_report(model, parameters['motifs'], X, 
		parameters['output_filename'], minimal=parameters['minimal'])


elif args.cmd == 'pipeline':
	parameters = merge_parameters(args.parameters, default_pipeline_parameters)

	# Step 1: Fit a BPNet model to the provided data
	if parameters['verbose']:
		print("Step 1: Fitting a BPNet model")


	fit_parameters = {key: parameters.get(key, None) for key in 
		default_fit_parameters}
	for parameter, value in parameters['fit_parameters'].items():
		if value is not None:
			fit_parameters[parameter] = value


	if parameters['model'] is None:
		name = '{}.bpnet.fit.json'.format(parameters['name'])
		parameters['model'] = parameters['name'] + '.torch'

		with open(name, 'w') as outfile:
			outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4))

		os.system("bpnet fit -p {}".format(name))

	
	# Step 2: Make predictions for the entire validation set
	if parameters['verbose']:
		print("\nStep 2: Making predictions")

	predict_parameters = {key: parameters.get(key, None) for key in 
		default_predict_parameters if key in parameters}
	for parameter, value in parameters['predict_parameters'].items():
		if value is not None:
			predict_parameters[parameter] = value

	if 'profile_filename' not in predict_parameters:
		predict_parameters['profile_filename'] = '{}.y_profiles.npz'.format(
			parameters['name'])

	if 'counts_filename' not in predict_parameters:
		predict_parameters['counts_filename'] = '{}.y_counts.npz'.format(
			parameters['name'])

	name = '{}.bpnet.predict.json'.format(parameters['name'])
	with open(name, 'w') as outfile:
		outfile.write(json.dumps(predict_parameters, sort_keys=True, indent=4))

	os.system("bpnet predict -p {}".format(name))
	

	# Step 3: Interpret the validation set
	if parameters['verbose']:
		print("\nStep 3: Calculating attributions")

	interpret_parameters = {key: parameters.get(key, None) for key in 
		default_interpret_parameters if key in parameters}
	for parameter, value in parameters['interpret_parameters'].items():
		if value is not None:
			interpret_parameters[parameter] = value

	if 'ohe_filename' not in interpret_parameters:
		interpret_parameters['ohe_filename'] = '{}.ohe.npz'.format(
			parameters['name'])

	if 'attr_filename' not in interpret_parameters:
		interpret_parameters['attr_filename'] = '{}.attr.npz'.format(
			parameters['name'])

	name = '{}.bpnet.interpret.json'.format(parameters['name'])
	with open(name, 'w') as outfile:
		outfile.write(json.dumps(interpret_parameters, sort_keys=True, 
			indent=4))

	os.system("bpnet interpret -p {}".format(name))


	# Step 4: Calculate tf-modisco motifs
	if parameters['verbose']:
		print("\nStep 4: TF-MoDISco motifs")

	modisco_parameters = parameters['modisco_motifs_parameters']

	if modisco_parameters['output_filename'] is None:
		modisco_parameters['output_filename'] = '{}_modisco_results.h5'.format(
			parameters['name'])

	cmd = "modisco motifs -s {} -a {} -n {} -o {}".format(
		interpret_parameters['ohe_filename'], 
		interpret_parameters['attr_filename'],
		modisco_parameters['n_seqlets'],
		modisco_parameters['output_filename'])
	if modisco_parameters['verbose']:
		cmd += ' -v'

	os.system(cmd)


	# Step 5: Generate the tf-modisco report
	modisco_name = "{}_modisco_results.h5".format(parameters['name'])

	report_parameters = parameters['modisco_report_parameters']
	if report_parameters['verbose'] is None:
		report_parameters['verbose'] = parameters['verbose']

	if report_parameters['output_folder'] is None:
		report_parameters['output_folder'] = '{}_modisco/'.format(
			parameters['name'])

	if report_parameters['verbose']:
		print("\nStep 5: TF-MoDISco reports")

	os.system('modisco report -i {0} -o {1} -s {1} -m {2}'.format(
		modisco_parameters['output_filename'], 
		report_parameters['output_folder'], report_parameters['motifs']))
	

	# Step 6: Marginalization experiments
	if parameters['verbose']:
		print("\nStep 6: Run marginalizations")

	marginalize_parameters = {key: parameters.get(key, None) for key in 
		default_marginalize_parameters if key in parameters}
	for parameter, value in parameters['marginalize_parameters'].items():
		if value is not None:
			marginalize_parameters[parameter] = value

	if 'output_filename' not in marginalize_parameters:
		marginalize_parameters['output_filename'] = '{}_marginalize/'.format(
			parameters['name'])

	name = '{}.bpnet.marginalize.json'.format(parameters['name'])

	with open(name, 'w') as outfile:
		outfile.write(json.dumps(marginalize_parameters, sort_keys=True, 
			indent=4))

	os.system("bpnet marginalize -p {}".format(name))
