#!/usr/bin/env python
# tf-modisco command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>, Ivy Raine <ivy.ember.raine@gmail.com>

from typing import List, Literal, Union
import h5py
import hdf5plugin
import argparse
import modiscolite

import numpy as np

from modiscolite.util import calculate_window_offsets


# Description shown at the top of `--help`. argparse's default formatter
# re-wraps the description, so the embedded newlines and tabs only affect the
# layout of this source file, not the rendered text.
desc = """TF-MoDISco is a motif detection algorithm that takes in nucleotide
	sequence and the attributions from a neural network model and returns motifs
	that are repeatedly enriched for attribution score across the examples.
	This tool will take in one-hot encoded sequence, the corresponding
	attribution scores, and a few other parameters, and return the motifs."""

# Top-level parser. All functionality is exposed through sub-commands; the
# selected sub-command name is stored in `args.cmd` and one must be given.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(help="Must be either 'motifs', 'report', 'convert', 'convert-backward', 'meme', 'seqlet-bed', or 'seqlet-fasta'.", required=True, dest='cmd')

# Sub-command `motifs`: takes sequences and attributions (or a legacy HDF5
# file) plus tuning parameters. ArgumentDefaultsHelpFormatter makes `--help`
# print each option's default value.
motifs_parser = subparsers.add_parser(
	"motifs",
	help="Run TF-MoDISco and extract the motifs.",
	formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# One row per option: (short flag, long flag, keyword arguments forwarded to
# `add_argument`). Rows are registered in order, which is also the order in
# which they are listed by `--help`.
_motifs_options = (
	("-s", "--sequences", dict(
		type=str,
		help="A .npy or .npz file containing the one-hot encoded sequences.",
	)),
	("-a", "--attributions", dict(
		type=str,
		help="A .npy or .npz file containing the hypothetical attributions, i.e., the attributions for all nucleotides at all positions.",
	)),
	("-i", "--h5py", dict(
		type=str,
		help="A legacy h5py file containing the one-hot encoded sequences and shap scores.",
	)),
	("-n", "--max_seqlets", dict(
		type=int,
		required=True,
		help="The maximum number of seqlets per metacluster.",
	)),
	("-l", "--n_leiden", dict(
		type=int,
		default=2,
		help="The number of Leiden clusterings to perform with different random seeds.",
	)),
	("-w", "--window", dict(
		type=int,
		default=400,
		help="The window surrounding the peak center that will be considered for motif discovery.",
	)),
	("-z", "--size", dict(
		type=int,
		default=20,
		help="The size of the seqlet cores, corresponding to `sliding_window_size`.",
	)),
	("-t", "--trim_size", dict(
		type=int,
		default=30,
		help="The size to trim to, corresponding to `trim_to_window_size`.",
	)),
	("-f", "--seqlet_flank_size", dict(
		type=int,
		default=5,
		help="The size of the flanks to add to each seqlet, corresponding to `flank_size`.",
	)),
	("-g", "--initial_flank_to_add", dict(
		type=int,
		default=10,
		help="The size of the flanks to add to each pattern initially, corresponding to `initial_flank_to_add`.",
	)),
	("-j", "--final_flank_to_add", dict(
		type=int,
		default=0,
		help="The size of the flanks to add to each pattern at the end, corresponding to `final_flank_to_add`.",
	)),
	("-o", "--output", dict(
		type=str,
		default="modisco_results.h5",
		help="The path to the output file.",
	)),
	("-v", "--verbose", dict(
		action="store_true",
		default=False,
		help="Controls the amount of output from the code.",
	)),
)
for _short_flag, _long_flag, _option_kwargs in _motifs_options:
	motifs_parser.add_argument(_short_flag, _long_flag, **_option_kwargs)

# Sub-command `report`: reads a results HDF5 file and writes a report into an
# output directory. Defaults are printed in `--help`.
report_parser = subparsers.add_parser(
	"report",
	formatter_class=argparse.ArgumentDefaultsHelpFormatter,
	help="Create a HTML report of the results.",
)

# Required input file and output directory.
report_parser.add_argument(
	"-i", "--h5py",
	required=True,
	type=str,
	help="An HDF5 file containing the output from modiscolite.",
)
report_parser.add_argument(
	"-o", "--output",
	required=True,
	type=str,
	help="A directory to put the output results including the html report.",
)

# Optional switches; stored as `args.write_tomtom`, `args.suffix`,
# `args.meme_db` and `args.n_matches`.
report_parser.add_argument(
	"-t", "--write-tomtom",
	default=False,
	action="store_true",
	help="Write the TOMTOM results to the output directory if flag is given.",
)
report_parser.add_argument(
	"-s", "--suffix",
	default="./",
	type=str,
	help="The suffix to add to the beginning of images. Should be equal to the output if using a Jupyter notebook.",
)
report_parser.add_argument(
	"-m", "--meme_db",
	default=None,
	type=str,
	help="A MEME file containing motifs.",
)
report_parser.add_argument(
	"-n", "--n_matches",
	default=3,
	type=int,
	help="The number of top TOMTOM matches to include in the report.",
)

# Sub-command `convert`: old-format HDF5 in, new-format HDF5 out.
convert_parser = subparsers.add_parser(
	"convert",
	help="Convert an old h5py to the new format.",
)
convert_parser.add_argument(
	"-i", "--h5py",
	required=True,
	type=str,
	help="An HDF5 file formatted in the original way.",
)
convert_parser.add_argument(
	"-o", "--output",
	required=True,
	type=str,
	help="An HDF5 file formatted in the new way.",
)

# Sub-command `convert-backward`: the reverse direction, new format in and
# old format out.
convertback_parser = subparsers.add_parser(
	"convert-backward",
	help="Convert a new h5py to the old format.",
)
convertback_parser.add_argument(
	"-i", "--h5py",
	required=True,
	type=str,
	help="An HDF5 file formatted in the new way.",
)
convertback_parser.add_argument(
	"-o", "--output",
	required=True,
	type=str,
	help="An HDF5 file formatted in the old way.",
)

# Sub-command `meme`: export motifs from a results file in MEME format.
# RawTextHelpFormatter keeps the line breaks of the multi-line help strings
# below, so the option list under `--datatype` is printed as written.
meme_parser = subparsers.add_parser("meme", help="""Output a MEME file from a
modisco results file to stdout (default) and/or to a file (if specified).""",
	formatter_class=argparse.RawTextHelpFormatter)
# The results file is the only input of this sub-command, so it is required
# here just like for the other sub-commands that read a results file;
# otherwise `None` would be handed on as the file name.
meme_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file containing the output from modiscolite.")
# The raw string is converted with `modiscolite.util.MemeDataType`, and the
# accepted values are the members of that type.
meme_parser.add_argument("-t", "--datatype", type=modiscolite.util.MemeDataType,
	choices=list(modiscolite.util.MemeDataType), required=True,
	help="""A case-sensitive string specifying the desired data of the output file.
The options are as follows:
- 'PFM':      The position-frequency matrix.
- 'CWM':      The contribution-weight matrix.
- 'hCWM':     The hypothetical contribution-weight matrix; hypothetical
              contribution scores are the contributions of nucleotides not encoded
              by the one-hot encoding sequence.
- 'CWM-PFM':  The softmax of the contribution-weight matrix.
- 'hCWM-PFM': The softmax of the hypothetical contribution-weight matrix."""
)
# Optional output path; when omitted `args.output` is None.
meme_parser.add_argument("-o", "--output", type=str, help="The path to the output file.")
meme_parser.add_argument("-q", "--quiet", action="store_true", default=False,
	help="Suppress output to stdout.")

# Sub-command `seqlet-bed`: export the seqlets of a results file as BED.
# This parser uses argparse's default formatter, which re-wraps help text, so
# the line breaks inside the triple-quoted strings are rendered as spaces.
seqlet_bed_parser = subparsers.add_parser("seqlet-bed", help="""Output a BED
file of seqlets from a modisco results file to stdout (default) and/or to a
file (if specified).""")
seqlet_bed_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file containing the output from modiscolite.")
seqlet_bed_parser.add_argument("-o", "--output", type=str, default=None,
	help="The path to the output file.")
# NOTE: a backslash line-continuation inside a string literal joins the two
# lines with no space ("start andend"); a triple-quoted string avoids that.
seqlet_bed_parser.add_argument("-p", "--peaksfile", type=str, required=True,
	help="""The path to the peaks file. This is to compute the absolute start and
end positions of the seqlets within a reference genome, as well as the chroms.""")
# The raw string is turned into a list (or '*') by
# `convert_arg_chroms_to_list` at dispatch time.
seqlet_bed_parser.add_argument("-c", "--chroms", type=str, required=True,
	help="""A comma-delimited list of chromosomes, or '*', denoting which
chromosomes to process. Should be the same set of chromosomes used during
interpretation. '*' will use every chr in the provided peaks file.
Examples: 'chr1,chr2,chrX' || '*' || '1,2,X'.""")
seqlet_bed_parser.add_argument("-q", "--quiet", action="store_true", default=False,
	help="Suppress output to stdout.")
# No default is given, so `args.windowsize` is None unless the flag is used.
seqlet_bed_parser.add_argument("-w", "--windowsize", type=int,
	help="""Optional. This is for backwards compatibility for older modisco h5
files that don't contain the window size as an attribute. This should be set
to the size of the window around the peak center that was used for motif
discovery.""")

# Sub-command `seqlet-fasta`: export the seqlets of a results file as FASTA.
# This parser uses argparse's default formatter, which re-wraps help text, so
# the line breaks inside the triple-quoted strings are rendered as spaces.
seqlet_fasta_parser = subparsers.add_parser("seqlet-fasta", help="""Output a FASTA
file of seqlets from a modisco results file to stdout (default) and/or to a
file (if specified).""")
seqlet_fasta_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file containing the output from modiscolite.")
seqlet_fasta_parser.add_argument("-o", "--output", type=str, default=None,
	help="The path to the output file.")
# NOTE: a backslash line-continuation inside a string literal joins the two
# lines with no space ("start andend"); a triple-quoted string avoids that.
seqlet_fasta_parser.add_argument("-p", "--peaksfile", type=str, required=True,
	help="""The path to the peaks file. This is to compute the absolute start and
end positions of the seqlets within a reference genome, as well as the chroms.""")
seqlet_fasta_parser.add_argument("-s", "--sequences", type=str, required=True,
	help="A .npy or .npz file containing the one-hot encoded sequences.")
# The raw string is turned into a list (or '*') by
# `convert_arg_chroms_to_list` at dispatch time.
seqlet_fasta_parser.add_argument("-c", "--chroms", type=str, required=True,
	help="""A comma-delimited list of chromosomes, or '*', denoting which
chromosomes to process. Should be the same set of chromosomes used during
interpretation. '*' will use every chr in the provided peaks file.
Examples: 'chr1,chr2,chrX' || '*' || '1,2,X'.""")
seqlet_fasta_parser.add_argument("-q", "--quiet", action="store_true", default=False,
	help="Suppress output to stdout.")
# No default is given, so `args.windowsize` is None unless the flag is used.
seqlet_fasta_parser.add_argument("-w", "--windowsize", type=int,
	help="""Optional. This is for backwards compatibility for older modisco h5
files that don't contain the window size as an attribute. This should be set
to the size of the window around the peak center that was used for motif
discovery.""")

def convert_arg_chroms_to_list(chroms: str) -> Union[List[str], Literal['*']]:
	"""Parse the value of the ``--chroms`` command-line option.

	Parameters
	----------
	chroms: str
		Either the wildcard ``'*'`` or a comma-delimited list of chromosome
		names, e.g. ``'chr1,chr2,chrX'``. Whitespace around the whole value
		and around individual names is ignored.

	Returns
	-------
	list of str or '*'
		The string ``'*'`` when the wildcard was given, otherwise the
		chromosome names in the order they were listed.
	"""
	chroms = chroms.strip()
	if chroms == "*":
		# The wildcard is passed through unchanged.
		return '*'

	# Strip every entry so that an input such as "chr1, chr2" does not produce
	# the name " chr2", which could never match a chromosome name.
	return [name.strip() for name in chroms.split(",")]

def _load_array(path: Union[str, None], flag: str) -> np.ndarray:
	"""Load an array from a ``.npy`` or ``.npz`` file given on the command line.

	Parameters
	----------
	path: str or None
		The file name supplied by the user, or None if the option was omitted.
	flag: str
		The name of the command-line option, used only in error messages.

	Returns
	-------
	numpy.ndarray
		The array stored in the ``.npy`` file, or the entry named ``'arr_0'``
		of the ``.npz`` archive.

	Raises
	------
	ValueError
		If the option was not given or the file name does not end in
		``npy`` or ``npz``.
	"""
	if path is None:
		raise ValueError("{} must be provided when --h5py is not given.".format(flag))

	extension = path[-3:]
	if extension == 'npy':
		return np.load(path)
	if extension == 'npz':
		# `numpy.savez` stores positional arrays under 'arr_0', 'arr_1', ...
		return np.load(path)['arr_0']

	raise ValueError("{} must be a .npy or .npz file, got '{}'.".format(flag, path))


# Pull the arguments
args = parser.parse_args()

if args.cmd == "motifs":
	if args.h5py is not None:
		# A legacy HDF5 file takes precedence over --sequences/--attributions.
		# The context manager closes the file even if reading fails.
		with h5py.File(args.h5py, 'r') as scores:
			try:
				# Layout 1: 'hyp_scores' / 'input_seqs', windowed along axis 1.
				center = scores['hyp_scores'].shape[1] // 2
				start, end = calculate_window_offsets(center, args.window)

				attributions = scores['hyp_scores'][:, start:end, :]
				sequences = scores['input_seqs'][:, start:end, :]
			except KeyError:
				# Layout 2: 'shap/seq' / 'raw/seq', windowed along axis 2 and
				# then transposed so the windowed axis becomes axis 1.
				center = scores['shap']['seq'].shape[2] // 2
				start, end = calculate_window_offsets(center, args.window)

				attributions = scores['shap']['seq'][:, :, start:end].transpose(0, 2, 1)
				sequences = scores['raw']['seq'][:, :, start:end].transpose(0, 2, 1)

	else:
		sequences = _load_array(args.sequences, "--sequences")
		attributions = _load_array(args.attributions, "--attributions")

		# The arrays are windowed along axis 2 and then transposed so that the
		# windowed axis becomes axis 1.
		center = sequences.shape[2] // 2
		start, end = calculate_window_offsets(center, args.window)

		sequences = sequences[:, :, start:end].transpose(0, 2, 1)
		attributions = attributions[:, :, start:end].transpose(0, 2, 1)

	# After windowing, axis 1 must span at least the requested window.
	if sequences.shape[1] < args.window:
		raise ValueError(
			"Window ({}) cannot be longer than the sequences "
			"(shape after windowing: {})".format(args.window, sequences.shape))

	sequences = sequences.astype('float32')
	attributions = attributions.astype('float32')

	pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco(
		hypothetical_contribs=attributions,
		one_hot=sequences,
		max_seqlets_per_metacluster=args.max_seqlets,
		sliding_window_size=args.size,
		flank_size=args.seqlet_flank_size,
		trim_to_window_size=args.trim_size,
		initial_flank_to_add=args.initial_flank_to_add,
		final_flank_to_add=args.final_flank_to_add,
		target_seqlet_fdr=0.05,
		n_leiden_runs=args.n_leiden,
		verbose=args.verbose)

	# The window size is passed along with the patterns when saving.
	modiscolite.io.save_hdf5(args.output, pos_patterns, neg_patterns, args.window)

elif args.cmd == 'meme':
	modiscolite.io.write_meme_from_h5(args.h5py, args.datatype, args.output, args.quiet)

elif args.cmd == 'seqlet-bed':
	modiscolite.io.write_bed_from_h5(args.h5py, args.peaksfile, args.output, convert_arg_chroms_to_list(args.chroms), args.windowsize, args.quiet)

elif args.cmd == 'seqlet-fasta':
	modiscolite.io.write_fasta_from_h5(args.h5py, args.peaksfile, args.sequences, args.output, convert_arg_chroms_to_list(args.chroms), args.windowsize, args.quiet)

elif args.cmd == 'report':
	modiscolite.report.report_motifs(args.h5py, args.output,
		img_path_suffix=args.suffix, meme_motif_db=args.meme_db,
		is_writing_tomtom_matrix=args.write_tomtom,
		top_n_matches=args.n_matches)

elif args.cmd == 'convert':
	modiscolite.io.convert(args.h5py, args.output)

elif args.cmd == 'convert-backward':
	modiscolite.io.convert_new_to_old(args.h5py, args.output)
