#!python
# ChromBPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import argparse

from bpnetlite.bpnet import BPNet
from bpnetlite.chrombpnet import ChromBPNet

from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci
from bpnetlite.attributions import calculate_attributions

import json

# Long-form description of the tool; passed to argparse below as the
# `description` of the top-level parser, so it is what users see in `--help`.
desc = """ChromBPNet is a neural network that builds off the original BPNet
	architecture by explicitly learning bias in the signal tracks themselves.
	Specifically, for ATAC-seq and DNAse-seq experiments, the cutting enzymes
	have a soft sequence bias (though this is much stronger for Tn5, the
	enzyme for ATAC-seq). Accordingly, ChromBPNet is a pair of neural networks
	where one models the bias explicitly and one models the accessibility
	explicitly. This tool provides functionality for training the combination
	of the bias model and accessibility model and making predictions using it.
	After training, the accessibility model can be used using the `bpnet`
	tool."""



# Build the command-line interface. There is one sub-command per workflow
# ('train', 'predict', 'bias'); the chosen one is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(
	dest='cmd',
	required=True,
	help="Must be either 'bias' or 'train' or 'predict'.",
)

# Register the sub-commands in the order they should be listed in `--help`.
train_parser = subparsers.add_parser("train",
	help="Train a ChromBPNet model.")
predict_parser = subparsers.add_parser("predict",
	help="Make predictions using a trained ChromBPNet model.")
bias_parser = subparsers.add_parser("bias",
	help="Train a bias model.")

# Every sub-command takes exactly one required option: the path to a JSON
# file holding its parameters. Only the help text differs between them.
for _subparser, _parameters_help in (
	(train_parser,
		"A JSON file containing the parameters for training the model."),
	(predict_parser,
		"A JSON file containing the parameters for making predictions."),
	(bias_parser,
		"A JSON file containing the parameters for training the model."),
):
	_subparser.add_argument("-p", "--parameters", type=str, required=True,
		help=_parameters_help)


# Parse the command line; argparse exits with a usage message on bad input.
args = parser.parse_args()

# Default chromosome splits. Each command copies these lists into its own
# defaults so that the shared constants are never mutated.
_TRAINING_CHROMS = ['chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7',
	'chr8', 'chr9', 'chr10', 'chr12', 'chr13', 'chr14', 'chr16',
	'chr18', 'chr19', 'chr20', 'chr22']
_VALIDATION_CHROMS = ['chr4', 'chr15', 'chr21']


def _load_parameters(filename, defaults, optional=()):
	"""Read a JSON parameter file and fill in missing entries from defaults.

	Parameters
	----------
	filename: str
		Path to the JSON file given on the command line.

	defaults: dict
		Mapping from parameter name to default value. A default of None
		marks the parameter as required unless its name is in `optional`.

	optional: tuple or set of str, optional
		Names of parameters whose None default is a legitimate value
		(meaning "decide later") rather than a marker for "required".

	Returns
	-------
	parameters: dict
		The parsed JSON object with every key in `defaults` present.

	Raises
	------
	ValueError
		If a required parameter is absent from the JSON file.
	"""

	with open(filename, "r") as infile:
		parameters = json.load(infile)

	for parameter, value in defaults.items():
		if parameter in parameters:
			continue

		if value is None and parameter not in optional:
			raise ValueError("Must provide value for '{}'".format(parameter))

		parameters[parameter] = value

	return parameters


if args.cmd == "train":
	# Train the accessibility model on top of a frozen, pre-trained bias
	# model. Entries set to None must be supplied in the JSON file.
	default_parameters = {
		'n_filters': 64,
		'n_layers': 8,
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'max_jitter': 128,
		'reverse_complement': True,
		'max_epochs': 250,
		'validation_iter': 100,
		'lr': 0.001,
		'alpha': 1,
		'verbose': False,

		'min_counts': 0,
		'max_counts': 99999999,
		'bias_model': None,

		'training_chroms': list(_TRAINING_CHROMS),
		'validation_chroms': list(_VALIDATION_CHROMS),
		'sequences': None,
		'loci': None,
		'signals': None,
		'random_state': None
	}

	parameters = _load_parameters(args.parameters, default_parameters)

	###

	# Training batches are drawn from the training chromosomes, using the
	# jitter / reverse-complement settings from the parameter file.
	training_data = PeakGenerator(
		loci=parameters['loci'],
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# The validation set is extracted once, with no jitter, from the
	# held-out chromosomes.
	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# Number of positions removed from each side of the input so that the
	# output spans `out_window` positions.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	# Load the bias model onto the CPU first so that a checkpoint saved on
	# another device can still be read, then move it to the GPU in eval mode.
	bias = torch.load(parameters['bias_model'], map_location='cpu').cuda().eval()
	accessibility = BPNet(n_filters=parameters['n_filters'],
		n_layers=parameters['n_layers'], n_control_tracks=0, n_outputs=1,
		alpha=parameters['alpha'],
		trimming=trimming).cuda()

	model = ChromBPNet(bias=bias, accessibility=accessibility,
		name="chrombpnet.{}.{}".format(parameters['n_filters'],
			parameters['n_layers']))

	optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])

	model.fit_generator(training_data, optimizer, X_valid=valid_sequences,
		y_valid=valid_signals, max_epochs=parameters['max_epochs'],
		validation_iter=parameters['validation_iter'],
		batch_size=parameters['batch_size'])

elif args.cmd == 'predict':
	# Predict profiles and counts for a set of loci with a trained model,
	# writing each to a compressed .npz file.
	default_parameters = {
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'verbose': False,
		'chroms': list(_TRAINING_CHROMS),
		'sequences': None,
		'loci': None,
		'model': None,
		'profile_filename': 'y_profile.npz',
		'count_filename': 'y_count.npz'
	}

	parameters = _load_parameters(args.parameters, default_parameters)

	# Load onto the CPU first (as the 'train' command does for the bias
	# model) so a checkpoint saved on a different GPU index still loads.
	model = torch.load(parameters['model'], map_location='cpu').cuda()

	# Pass the window sizes through explicitly; previously they were given
	# defaults but never handed to extract_loci, so user-supplied values
	# were silently ignored.
	X = extract_loci(
		sequences=parameters['sequences'],
		loci=parameters['loci'],
		chroms=parameters['chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	).cuda()

	y_profiles, y_counts = model.predict(X, batch_size=parameters['batch_size'])

	numpy.savez_compressed(parameters['profile_filename'], y_profiles)
	numpy.savez_compressed(parameters['count_filename'], y_counts)

elif args.cmd == 'bias':
	# Train a stand-alone bias model on the 'negatives' loci, restricted to
	# those whose total counts fall below a threshold.
	default_parameters = {
		'n_filters': 256,
		'n_layers': 4,
		'batch_size': 64,
		'in_window': 2114,
		'out_window': 1000,
		'max_jitter': 128,
		'reverse_complement': True,
		'max_epochs': 250,
		'validation_iter': 100,
		'lr': 0.001,
		'alpha': 1,
		'beta': 0.5,
		'verbose': False,

		'min_counts': 0,
		'max_counts': None,

		'training_chroms': list(_TRAINING_CHROMS),
		'validation_chroms': list(_VALIDATION_CHROMS),
		'sequences': None,
		'peaks': None,
		'negatives': None,
		'signals': None,
		'random_state': None
	}

	# `max_counts` may legitimately be left out: None means "derive the
	# threshold from the peaks" (handled just below), not "required".
	parameters = _load_parameters(args.parameters, default_parameters,
		optional=('max_counts',))

	###

	if parameters['max_counts'] is None:
		# Use the smallest total count seen at any training-chromosome
		# peak as the upper bound for what counts as background.
		_, train_signals = extract_loci(
			sequences=parameters['sequences'],
			signals=parameters['signals'],
			controls=None,
			loci=parameters['peaks'],
			chroms=parameters['training_chroms'],
			in_window=parameters['in_window'],
			out_window=parameters['out_window'],
			max_jitter=0,
			verbose=parameters['verbose']
		)

		parameters['max_counts'] = train_signals.sum(dim=-1).min()

	# Train on the training chromosomes so the validation chromosomes stay
	# held out. Note that `beta` scales the count threshold whether it was
	# derived above or supplied by the user.
	training_data = PeakGenerator(
		loci=parameters['negatives'],
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts']*parameters['beta'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)

	# Validation uses the same 'negatives' loci, without jitter, on the
	# held-out chromosomes.
	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		loci=parameters['negatives'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# Number of positions removed from each side of the input so that the
	# output spans `out_window` positions.
	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	model = BPNet(n_filters=parameters['n_filters'],
		n_layers=parameters['n_layers'], n_outputs=1, n_control_tracks=0,
		alpha=parameters['alpha'],
		trimming=trimming).cuda()

	optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])

	model.fit_generator(training_data, optimizer, X_valid=valid_sequences,
		y_valid=valid_signals, max_epochs=parameters['max_epochs'],
		validation_iter=parameters['validation_iter'],
		batch_size=parameters['batch_size'])

