#!python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import pyfaidx
import argparse

from bpnetlite import BPNet
from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci

from bpnetlite.attributions import calculate_attributions
from bpnetlite.marginalize import marginalization_report

import json

# Fix: "an neural" -> "a neural" in the user-facing description.
desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

# Read in the arguments. Every subcommand takes a single JSON file of
# parameters so each run is fully specified by one file.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help="Must be either 'train', 'predict', 'interpret', or 'marginalize'.", required=True, dest='cmd')

train_parser = subparsers.add_parser("train", help="Train a BPNet model.")
train_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for training the model.")

predict_parser = subparsers.add_parser("predict", help="Make predictions using a trained BPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

interpret_parser = subparsers.add_parser("interpret", help="Make interpretations using a trained BPNet model.")
interpret_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

# Fix: help text previously said "calculating attributions" (copy-paste
# from the interpret subcommand); this parser runs marginalizations.
marginalize_parser = subparsers.add_parser("marginalize", help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")


# Pull the arguments
args = parser.parse_args()

if args.cmd == "train":
	# Read the user-supplied training configuration.
	with open(args.parameters, "r") as infile:
		parameters = json.load(infile)

	# Every supported setting and its default value. A default of None
	# (other than 'controls') marks a key the user must provide.
	default_parameters = {
		'n_filters': 64,
		'n_layers': 8,
		'n_outputs': 2,
		'n_control_tracks': 2,
		'profile_output_bias': True,
		'count_output_bias': True,
		'name': None,

		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'max_jitter': 128,
		'reverse_complement': True,
		'max_epochs': 250,
		'validation_iter': 100,
		'lr': 0.001,
		'alpha': 1,
		'verbose': False,

		'min_counts': 0,
		'max_counts': 99999999,

		'training_chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'validation_chroms': ['chr4', 'chr15', 'chr21'],
		'sequences': None,
		'loci': None,
		'signals': None,
		'controls': None,
		'random_state': None
	}

	# Merge defaults into the user configuration, enforcing required keys.
	for key, default in default_parameters.items():
		if key in parameters:
			continue

		if default is None and key != "controls":
			raise ValueError("Must provide value for '{}'".format(key))

		parameters[key] = default

	# Batch generator over the training chromosomes, with jitter and
	# reverse-complement augmentation applied on the fly.
	train_gen = PeakGenerator(
		loci=parameters['loci'], 
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# Validation loci are extracted once, with no jitter.
	valid_out = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# extract_loci returns a controls element only when controls were given.
	if parameters['controls'] is not None:
		X_valid, y_valid, X_ctl_valid = valid_out
	else:
		X_valid, y_valid = valid_out
		X_ctl_valid = None

	# Number of positions trimmed from each side of the input window to
	# produce the output window.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'], 
		n_layers=parameters['n_layers'],
		n_outputs=parameters['n_outputs'],
		n_control_tracks=parameters['n_control_tracks'],
		profile_output_bias=parameters['profile_output_bias'],
		count_output_bias=parameters['count_output_bias'],
		alpha=parameters['alpha'],
		trimming=trimming,
		name=parameters['name'],
		verbose=parameters['verbose']).cuda()

	optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])

	# Train, periodically evaluating on the held-out validation loci.
	model.fit(train_gen, optimizer, X_valid=X_valid, 
		X_ctl_valid=X_ctl_valid, y_valid=y_valid, 
		max_epochs=parameters['max_epochs'], 
		validation_iter=parameters['validation_iter'], 
		batch_size=parameters['batch_size'])

elif args.cmd == 'predict':
	with open(args.parameters, "r") as infile:
		parameters = json.load(infile)

	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'sequences': None,
		'loci': None,
		'controls': None,
		'model': None,
		'profile_filename': 'y_profile.npz',
		'count_filename': 'y_count.npz'
	}

	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter != "controls":
				raise ValueError("Must provide value for '{}'".format(parameter))

			parameters[parameter] = value

	model = torch.load(parameters['model']).cuda()

	examples = extract_loci(
		sequences=parameters['sequences'],
		controls=parameters['controls'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if parameters['controls'] == None:
		X = examples
		if model.n_control_tracks > 0:
			X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
		else:
			X_ctl = None
	else:
		X, X_ctl = examples

	y_profiles, y_counts = model.predict(X, X_ctl=X_ctl, 
		batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['count_filename'], y_counts)

elif args.cmd == 'interpret':
	with open(args.parameters, "r") as infile:
		parameters = json.load(infile)

	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'sequences': None,
		'loci': None,
		'model': None,
		'output': 'count',
		'ohe_filename': 'ohe.npz',
		'shap_filename': 'shap.npz',
		'random_state':0,
	}

	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter != "controls":
				raise ValueError("Must provide value for '{}'".format(parameter))

			parameters[parameter] = value

	model = torch.load(parameters['model']).cuda()

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if model.n_control_tracks > 0:
		X_ctl = (torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1]),)
	else:
		X_ctl = None

	X_attr = calculate_attributions(model, X, args=X_ctl,
		model_output=parameters['output'], hypothetical=True, 
		random_state=parameters['random_state'],
		verbose=parameters['verbose'])

	numpy.savez_compressed(parameters['ohe_filename'], X)
	numpy.savez_compressed(parameters['shap_filename'], X_attr)

elif args.cmd == 'marginalize':
	with open(args.parameters, "r") as infile:
		parameters = json.load(infile)

	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 
			'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 
			'chr18', 'chr19', 'chr20', 'chr22'],
		'sequences': None,
		'motifs': None,
		'loci': None,
		'n_loci': None,
		'shuffle': False,
		'model': None,
		'output_filename':'marginalize/',
		'random_state':0,
	}

	for parameter, value in default_parameters.items():
		if parameter not in parameters:
			if value is None and parameter != "controls":
				raise ValueError("Must provide value for '{}'".format(parameter))

			parameters[parameter] = value

	model = torch.load(parameters['model']).cuda()

	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	if parameters['shuffle'] == True:
		idxs = numpy.arange(X.shape[0])
		numpy.random.shuffle(idxs)
		X = X[idxs]

	if parameters['n_loci'] is not None:
		X = X[:parameters['n_loci']]

	motifs = pyfaidx.Fasta(parameters['motifs'])
	marginalization_report(model, motifs, X, 
		parameters['output_filename'])
