#!python
# ChromBPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import argparse
import json
import os
import subprocess
import sys

import numpy
import torch

from bpnetlite.attributions import calculate_attributions
from bpnetlite.bpnet import BPNet
from bpnetlite.chrombpnet import ChromBPNet
from bpnetlite.io import PeakGenerator
from bpnetlite.io import extract_loci
from bpnetlite.marginalize import marginalization_report

# Text shown at the top of `chrombpnet --help`. argparse re-flows whitespace,
# so the tab-indented continuation lines are collapsed when displayed.
desc = """ChromBPNet is a neural network that builds off the original BPNet
	architecture by explicitly learning bias in the signal tracks themselves.
	Specifically, for ATAC-seq and DNase-seq experiments, the cutting enzymes
	have a soft sequence bias (though this is much stronger for Tn5, the
	enzyme for ATAC-seq). Accordingly, ChromBPNet is a pair of neural networks
	where one models the bias explicitly and one models the accessibility
	explicitly. This tool provides functionality for training the combination
	of the bias model and accessibility model and making predictions using it.
	After training, the accessibility model can be used with the `bpnet`
	tool."""

# Help text for the sub-command selector itself.
_help = """Must be either 'fit', 'predict', 'interpret', 'marginalize',
	or 'pipeline'."""


# Read in the arguments. Every sub-command takes exactly one required option,
# -p/--parameters, naming a JSON file; the chosen sub-command is stored in
# `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd')

train_parser = subparsers.add_parser("fit", help="Fit a ChromBPNet model.")
train_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for fitting the model.")

predict_parser = subparsers.add_parser("predict",
	help="Make predictions using a trained ChromBPNet model.")
predict_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for making predictions.")

interpret_parser = subparsers.add_parser("interpret",
	help="Make interpretations using a trained ChromBPNet model.")
interpret_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for calculating attributions.")

marginalize_parser = subparsers.add_parser("marginalize",
	help="Run marginalizations given motifs.")
marginalize_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for running marginalizations.")

pipeline_parser = subparsers.add_parser("pipeline",
	help="Run each step on the given files.")
pipeline_parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters used for each step.")


###
# Default Parameters
###

# Defaults for `chrombpnet fit`. A default of None marks a value that
# `merge_parameters` requires from the user's JSON file, unless the key is one
# of the entries that function treats as optional.
default_fit_parameters = dict(
	n_filters=256,
	n_layers=8,
	profile_output_bias=True,
	count_output_bias=True,
	name=None,
	batch_size=64,
	in_window=2114,
	out_window=1000,
	max_jitter=128,
	reverse_complement=True,
	max_epochs=50,
	validation_iter=100,
	lr=0.001,
	alpha=10,
	beta=0.5,
	verbose=False,
	bias_model=None,

	min_counts=None,
	max_counts=None,

	# chr2-chr7, chr9, chr11-chr22 and chrX; chr8 and chr10 are kept for
	# validation, and chr1 and chrY appear in neither list.
	training_chroms=['chr{}'.format(index) for index in (2, 3, 4, 5, 6, 7, 9,
		11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)] + ['chrX'],
	validation_chroms=['chr8', 'chr10'],

	sequences=None,
	loci=None,
	negatives=None,
	signals=None,
	random_state=None,

	# Fit bias model: entries that are not None replace the top-level value
	# of the same name when the bias model's parameter file is assembled.
	bias_fit_parameters=dict(
		n_filters=None,
		n_layers=4,
		alpha=None,
		max_counts=None,
		loci=None,
		verbose=None,
		random_state=None,
	),
)

# Defaults for the predict command. No code in the visible part of this file
# reads this dictionary.
default_predict_parameters = dict(
	batch_size=64,
	in_window=2114,
	out_window=1000,
	verbose=False,
	# chr2-chr7, chr9, chr11-chr22 and chrX
	chroms=['chr{}'.format(index) for index in (2, 3, 4, 5, 6, 7, 9, 11, 12,
		13, 14, 15, 16, 17, 18, 19, 20, 21, 22)] + ['chrX'],
	sequences=None,
	loci=None,
	model=None,
	profile_filename='y_profile.npz',
	counts_filename='y_counts.npz',
)

# Defaults for the interpret command. No code in the visible part of this file
# reads this dictionary.
default_interpret_parameters = dict(
	batch_size=1,
	in_window=2114,
	out_window=1000,
	verbose=False,
	# chr2-chr7, chr9, chr11-chr22 and chrX
	chroms=['chr{}'.format(index) for index in (2, 3, 4, 5, 6, 7, 9, 11, 12,
		13, 14, 15, 16, 17, 18, 19, 20, 21, 22)] + ['chrX'],
	sequences=None,
	loci=None,
	model=None,
	output='profile',
	ohe_filename='ohe.npz',
	attr_filename='attr.npz',
	n_shuffles=20,
	random_state=0,
)

# Defaults for the marginalize command. No code in the visible part of this
# file reads this dictionary.
default_marginalize_parameters = dict(
	batch_size=64,
	in_window=2114,
	out_window=1000,
	verbose=False,
	# chr2-chr7, chr9, chr11-chr22 and chrX
	chroms=['chr{}'.format(index) for index in (2, 3, 4, 5, 6, 7, 9, 11, 12,
		13, 14, 15, 16, 17, 18, 19, 20, 21, 22)] + ['chrX'],
	sequences=None,
	motifs=None,
	loci=None,
	n_loci=None,
	shuffle=False,
	model=None,
	output_filename='marginalize/',
	random_state=0,
	minimal=True,
)

# Defaults for `chrombpnet pipeline`. Top-level entries are shared settings;
# each nested `*_parameters` dictionary groups the settings of one step.
default_pipeline_parameters = dict(
	# Model architecture parameters
	n_filters=256,
	n_layers=8,
	profile_output_bias=True,
	count_output_bias=True,
	in_window=2114,
	out_window=1000,
	name=None,
	model=None,
	bias_model=None,
	accessibility_model=None,
	verbose=False,

	# Data parameters
	batch_size=64,
	max_jitter=128,
	reverse_complement=True,
	max_epochs=50,
	validation_iter=100,
	lr=0.001,
	alpha=10,
	beta=0.5,
	min_counts=0,
	max_counts=99999999,

	sequences=None,
	loci=None,
	negatives=None,
	signals=None,

	# chr2-chr7, chr9, chr11-chr22 and chrX for training; chr8 and chr10
	# for validation.
	training_chroms=['chr{}'.format(index) for index in (2, 3, 4, 5, 6, 7, 9,
		11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)] + ['chrX'],
	validation_chroms=['chr8', 'chr10'],

	# Fit bias model
	bias_fit_parameters=dict(
		n_filters=None,
		n_layers=4,
		alpha=None,
		max_counts=None,
		loci=None,
		verbose=None,
		random_state=None,
	),

	# Fit accessibility model
	chrombpnet_fit_parameters=dict(
		batch_size=64,
		sequences=None,
		loci=None,
		signals=None,
		verbose=None,
		random_state=None,
	),

	# Predict parameters
	predict_parameters=dict(
		batch_size=64,
		chroms=['chr8', 'chr10'],
		profile_filename=None,
		counts_filename=None,
		sequences=None,
		loci=None,
		signals=None,
		verbose=None,
	),

	# Interpret parameters
	interpret_parameters=dict(
		batch_size=1,
		chroms=['chr8', 'chr10'],
		output='profile',
		loci=None,
		ohe_filename=None,
		attr_filename=None,
		n_shuffles=None,
		verbose=None,
		random_state=None,
	),

	# Modisco parameters
	modisco_motifs_parameters=dict(
		n_seqlets=100000,
		output_filename=None,
		verbose=None,
	),

	# Modisco report parameters
	modisco_report_parameters=dict(
		motifs=None,
		output_folder=None,
		verbose=None,
	),

	# Marginalization parameters
	marginalize_parameters=dict(
		loci=None,
		n_loci=100,
		shuffle=False,
		output_folder=None,
		motifs=None,
		minimal=True,
		verbose=None,
		random_state=None,
	),
)


###
# Commands
###


def merge_parameters(parameters, default_parameters):
	"""Load user parameters from a JSON file and fill in missing defaults.

	Keys present in the JSON file always win, even when their value is null.
	A key that is absent from the file takes its default value, except that
	an absent key whose default is None is treated as required and raises an
	error, unless it is one of 'bias_model', 'min_counts' or 'max_counts'.
	Only top-level keys are merged; nested dictionaries are taken whole from
	whichever side supplies them.


	Parameters
	----------
	parameters: str
		Path of the JSON file holding the user-provided parameters.

	default_parameters: dict
		The default parameters for the operation.


	Returns
	-------
	params: dict
		The dictionary loaded from the file with the defaults filled in.


	Raises
	------
	ValueError
		If a required parameter is missing from the JSON file.
	"""

	# Keys that are allowed to remain None when the user leaves them out.
	nullable = ('bias_model', 'min_counts', 'max_counts')

	with open(parameters, "r") as handle:
		merged = json.load(handle)

	for key in default_parameters:
		if key in merged:
			continue

		fallback = default_parameters[key]
		if fallback is None and key not in nullable:
			raise ValueError("Must provide value for '{}'".format(key))

		merged[key] = fallback

	return merged


# Pull the arguments
args = parser.parse_args()


def _run_tool(tool, command, parameters_filename):
	"""Run `<tool> <command> -p <parameters_filename>` as a subprocess.

	The command is given to `subprocess.run` as an argument list rather than
	being interpolated into a shell string, so parameter filenames containing
	spaces or shell metacharacters are passed through unchanged.


	Parameters
	----------
	tool: str
		The executable to run, e.g. 'bpnet' or 'chrombpnet'.

	command: str
		The sub-command of the tool, e.g. 'fit' or 'pipeline'.

	parameters_filename: str
		Path of the JSON parameter file handed to the tool via `-p`.


	Returns
	-------
	returncode: int
		The exit status of the subprocess; 0 indicates success.
	"""

	return subprocess.run([tool, command, "-p", parameters_filename]).returncode


def _write_json(filename, parameters):
	"""Write a parameter dictionary to `filename` as sorted, indented JSON."""

	with open(filename, 'w') as outfile:
		outfile.write(json.dumps(parameters, sort_keys=True, indent=4))


def _build_training_data(parameters, loci):
	"""Build the training PeakGenerator for the given loci.

	Every setting other than `loci` is read from the merged fit parameters;
	no control tracks are used.


	Parameters
	----------
	parameters: dict
		The merged parameters of the fit command.

	loci: str or list
		The loci value handed to PeakGenerator: either the peak file alone
		or a list holding the peak file and the negatives file.


	Returns
	-------
	training_data: object
		Whatever PeakGenerator returns for these settings.
	"""

	return PeakGenerator(
		loci=loci,
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		chroms=parameters['training_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=parameters['max_jitter'],
		reverse_complement=parameters['reverse_complement'],
		min_counts=parameters['min_counts'],
		max_counts=parameters['max_counts'],
		random_state=parameters['random_state'],
		batch_size=parameters['batch_size'],
		verbose=parameters['verbose']
	)


# Fit ChromBPNet model
if args.cmd == "fit":
	parameters = merge_parameters(args.parameters, default_fit_parameters)

	if parameters['verbose']:
		print("Training loci: {}".format(parameters['loci']))

	# The peaks-only generator is always built first because the bias model
	# threshold below is derived from the signal at the peaks.
	training_data = _build_training_data(parameters, parameters['loci'])

	trimming = (parameters['in_window'] - parameters['out_window']) // 2

	# No bias model was given, so fit one with `bpnet fit` first.
	if parameters['bias_model'] is None:
		# Smallest total signal (summed over dims 1 and 2) of any training
		# example drawn from the peaks.
		min_counts = training_data.dataset.signals.sum(dim=(1, 2)).min().item()
		if parameters['verbose']:
			print("Minimum counts in training peaks: {}".format(min_counts))

		# Start from the top-level fit settings and let every non-None entry
		# of 'bias_fit_parameters' override the value of the same name. An
		# explicit null for that section is treated as "no overrides".
		bias_fit_parameters = {key: parameters[key] for key in
			default_fit_parameters}
		bias_overrides = bias_fit_parameters['bias_fit_parameters'] or {}
		for parameter, value in bias_overrides.items():
			if value is not None:
				bias_fit_parameters[parameter] = value

		del bias_fit_parameters['negatives'], bias_fit_parameters['beta']

		bias_json = '{}.chrombpnet.bias.fit.json'.format(parameters['name'])

		# The bias model only sees examples whose counts are at most `beta`
		# times the smallest count found in the peaks.
		bias_fit_parameters['max_counts'] = min_counts * parameters['beta']
		bias_fit_parameters['name'] = parameters['name'] + '.bias'
		parameters['bias_model'] = bias_fit_parameters['name'] + '.torch'

		_write_json(bias_json, bias_fit_parameters)

		# Stop here on failure: otherwise the load below would either crash
		# with a less helpful error or silently pick up a stale model file
		# left over from an earlier run.
		status = _run_tool("bpnet", "fit", bias_json)
		if status != 0:
			raise RuntimeError("'bpnet fit -p {}' failed with exit status {} "
				"while fitting the bias model".format(bias_json, status))

	# When negatives are given, train the full model on peaks and negatives.
	if parameters['negatives'] is not None:
		training_data = _build_training_data(parameters,
			[parameters['loci'], parameters['negatives']])

	# Validation examples come from the peaks only and are not jittered.
	valid_sequences, valid_signals = extract_loci(
		sequences=parameters['sequences'],
		signals=parameters['signals'],
		controls=None,
		loci=parameters['loci'],
		chroms=parameters['validation_chroms'],
		in_window=parameters['in_window'],
		out_window=parameters['out_window'],
		max_jitter=0,
		verbose=parameters['verbose']
	)

	# NOTE(review): this loads a fully pickled model object. Recent PyTorch
	# releases default `torch.load` to weights_only=True, which may reject
	# such files -- confirm against the installed PyTorch version.
	bias = torch.load(parameters['bias_model'], map_location='cpu').cuda().eval()
	accessibility = BPNet(n_filters=parameters['n_filters'],
		n_layers=parameters['n_layers'], n_control_tracks=0, n_outputs=1,
		alpha=parameters['alpha'],
		name=parameters['name'] + '.accessibility',
		trimming=trimming).cuda()

	model = ChromBPNet(bias=bias, accessibility=accessibility,
		name=parameters['name'])

	optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

	model.fit(training_data, optimizer, X_valid=valid_sequences,
		y_valid=valid_signals, max_epochs=parameters['max_epochs'],
		validation_iter=parameters['validation_iter'],
		batch_size=parameters['batch_size'])


# Make predictions / interpretations from the full ChromBPNet model, or
# marginalize motifs. These commands are handed straight to the `bpnet` tool
# with the same parameter file, and its exit status becomes this script's.
elif args.cmd in ('predict', 'interpret', 'marginalize'):
	sys.exit(_run_tool("bpnet", args.cmd, args.parameters))


# Run everything for ChromBPNet
elif args.cmd == 'pipeline':
	parameters = merge_parameters(args.parameters, default_pipeline_parameters)
	model_name = parameters['name']

	# Step 1: Fit a ChromBPNet model to the provided data
	if parameters['verbose']:
		print("Step 1: Fitting a ChromBPNet model")

	if parameters['model'] is None:
		fit_json = '{}.chrombpnet.fit.json'.format(model_name)
		parameters['model'] = model_name + '.torch'

		# Take every fit setting from the pipeline parameters, falling back
		# to the fit default for keys the pipeline defaults do not define
		# (e.g. 'random_state').
		fit_parameters = {key: parameters.get(key, value) for key, value in
			default_fit_parameters.items()}

		# Apply the step-specific overrides. 'chrombpnet_fit_parameters' is
		# the section named in the pipeline defaults; 'fit_parameters' is
		# still honoured for parameter files that use that name.
		for section in ('chrombpnet_fit_parameters', 'fit_parameters'):
			for parameter, value in (parameters.get(section) or {}).items():
				if value is not None:
					fit_parameters[parameter] = value

		_write_json(fit_json, fit_parameters)

		# The remaining steps all need the fitted models, so stop on failure.
		status = _run_tool("chrombpnet", "fit", fit_json)
		if status != 0:
			raise RuntimeError("'chrombpnet fit -p {}' failed with exit "
				"status {}".format(fit_json, status))

	# Default model filenames follow the names used by the fit command.
	if parameters['bias_model'] is None:
		parameters['bias_model'] = model_name + '.bias.torch'

	if parameters['accessibility_model'] is None:
		parameters['accessibility_model'] = (model_name +
			'.accessibility.torch')

	# The fitting sections are not needed by `bpnet pipeline`; drop whichever
	# of them are present.
	for section in ('bias_fit_parameters', 'chrombpnet_fit_parameters',
		'fit_parameters'):
		parameters.pop(section, None)

	# Run the bpnet pipeline three times: with the full ChromBPNet model, the
	# accessibility model, and the bias model. Each run gets its own
	# parameter file, model file and output name.
	pipeline_runs = [
		('{}.chrombpnet.pipeline.json', parameters['model'], model_name),
		('{}.chrombpnet.accessibility.pipeline.json',
			parameters['accessibility_model'], model_name + '.accessibility'),
		('{}.chrombpnet.bias.pipeline.json', parameters['bias_model'],
			model_name + '.bias'),
	]

	any_failed = False
	for template, model_filename, run_name in pipeline_runs:
		pipeline_json = template.format(model_name)
		parameters['model'] = model_filename
		parameters['name'] = run_name

		_write_json(pipeline_json, parameters)

		# The three runs are independent, so a failure in one is reported
		# but does not prevent the others from running.
		status = _run_tool("bpnet", "pipeline", pipeline_json)
		if status != 0:
			any_failed = True
			print("'bpnet pipeline -p {}' failed with exit status {}".format(
				pipeline_json, status), file=sys.stderr)

	if any_failed:
		sys.exit(1)

