#!python
# coding=utf-8

#  PINTS: Peak Identifier for Nascent Transcript Starts
#  Copyright (c) 2019-2025 Yu Lab.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

import argparse
import datetime
import logging
import os
import sys

from pints import __version__

# Third-party and project imports are grouped in one try block so that a missing
# package results in a short install hint instead of a traceback.
try:
    import pysam
    import numpy as np
    from Bio import SeqIO
    from pints.io_engine import log_assert

    from importlib.util import spec_from_loader, module_from_spec
    from importlib.machinery import SourceFileLoader
    # The visualizer is loaded from a file named "pints_visualizer" (no ".py" suffix)
    # located in the same folder as this script; presumably that missing suffix is why
    # it is loaded through SourceFileLoader rather than a regular import statement.
    spec = spec_from_loader("pints_visualizer", SourceFileLoader(
        "pints_visualizer", 
        os.path.join(os.path.split(__file__)[0], "pints_visualizer")))
    visualizer = module_from_spec(spec)
    # Run the file's code so that its top-level names become attributes of `visualizer`
    spec.loader.exec_module(visualizer)
except ImportError as e:
    # Reduce a message like "No module named 'xxx'" to the bare name for the hint below
    missing_package = str(e).replace("No module named '", "").replace("'", "")
    # sys.exit with a string prints it to stderr and exits with a non-zero status
    sys.exit("Please install %s first!" % missing_package)

# Root logging configuration: INFO and above go to a stream handler (stderr by default),
# formatted as "<logger name> - <timestamp> - <level>: <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%d-%b-%y %H:%M:%S",
    format="%(name)s - %(asctime)s - %(levelname)s: %(message)s",
    handlers=[logging.StreamHandler()],
)
# Logger used by the functions and the command line entry point of this script
logger = logging.getLogger("PINTS - Normalizer")


def spikein_normalization(bam_files, spike_ins, exp_type, output_dir, cell_nums=(), spikein_scale=1000000, **kwargs):
    """
    Normalize samples using spike-in counts and cell counts

    For every bam file, reads whose mapping quality reaches ``mapq_threshold`` are counted,
    together with the subset mapped to a reference whose name is a record ID in the
    spike-in fasta file. The per-sample scale is
    ``spikein_scale / (spike-in reads * cell number)`` and is handed to ``bam_to_bigwig``
    as ``normalization_factor``. Samples whose spike-in ratio deviates from the median by
    more than 2.5 scaled MADs are reported with a warning; they are still normalized.

    Parameters
    ----------
    bam_files: list or tuple
        a list of paths of bam files
    spike_ins: str
        path to fasta file which contains all spike-in sequences
    exp_type: str
        experiment type, forwarded unchanged to ``bam_to_bigwig``
    output_dir: str
        output will be written to this folder
    cell_nums: None or list
        None or list of cell nums per sample. When it is None, or when its length differs
        from the number of bam files, 1 is used for every sample.
    spikein_scale: int
        Spike-in scale (numerator of the per-sample scale)
    kwargs
        mapq_threshold: int (required)
            minimum mapping quality for a read to be counted
        chromosome_startswith: str (required)
            forwarded unchanged to ``bam_to_bigwig``
        Any other key is accepted, but not used by this function.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        if ``mapq_threshold`` or ``chromosome_startswith`` is missing from kwargs
    """
    log_assert(os.path.exists(spike_ins), "Cannot access spike in file %s" % spike_ins, logger)
    # Read the required options once, up front: a missing option fails before any bam file
    # is scanned, and the lookup is not repeated for every read.
    mapq_threshold = kwargs["mapq_threshold"]
    chromosome_startswith = kwargs["chromosome_startswith"]

    n_samples = len(bam_files)
    spike_in_counts = np.zeros(n_samples, dtype=int)
    total_read_counts = np.zeros(n_samples, dtype=int)
    if cell_nums is not None and len(cell_nums) == n_samples:
        cell_nums = np.asarray(cell_nums)
    else:
        if cell_nums is not None and len(cell_nums) > 0:
            # Cell numbers were given but cannot be paired with the samples; say so
            # instead of dropping them silently.
            logger.warning("Got %d cell number(s) for %d sample(s), cell numbers will be ignored"
                           % (len(cell_nums), n_samples))
        cell_nums = np.ones(n_samples)

    spike_in_IDs = {spikein.id for spikein in SeqIO.parse(spike_ins, "fasta")}
    for sample_index, bam in enumerate(bam_files):
        # The context manager closes the bam handle once the sample has been counted
        with pysam.AlignmentFile(bam) as alignments:
            for read in alignments:
                if read.mapping_quality >= mapq_threshold:
                    total_read_counts[sample_index] += 1
                    if read.reference_name in spike_in_IDs:
                        spike_in_counts[sample_index] += 1

    logger.info("Total reads per sample:")
    logger.info("\t".join(map(str, total_read_counts)))
    logger.info("Spike-in reads per sample:")
    logger.info("\t".join(map(str, spike_in_counts)))
    log_assert(bool(np.all(spike_in_counts != 0)),
               "Sample must contain spike-in reads to be normalized, otherwise please use RPM",
               logger)
    logger.info("Percent of Spike-ins per sample:")
    spikein_percents = spike_in_counts / total_read_counts
    logger.info("\t".join("{:.3%}".format(x) for x in spikein_percents))

    # Outlier check on the spike-in ratios: absolute deviation from the median,
    # measured in units of the MAD scaled by 1.4826.
    med = np.median(spikein_percents)
    deviations = np.abs(spikein_percents - med)
    mad = 1.4826 * np.median(deviations)
    if mad > 0:
        problematic_samples = deviations / mad > 2.5
    else:
        # With a zero MAD the ratio is inf for any non-zero deviation and undefined for a
        # zero one; flag the former directly and skip the division by zero.
        problematic_samples = deviations > 0
    if problematic_samples.any():
        # Samples are reported with 1-based indexes
        logger.warning("Be cautious for sample(s) %s (spike-in ratios)" % ",".join(
            map(str, np.flatnonzero(problematic_samples) + 1)))

    spikein_scales = spikein_scale / (spike_in_counts * cell_nums)
    logger.info("Spike-in scales per sample:")
    logger.info("\t".join(map(str, spikein_scales)))

    for bam, scale in zip(bam_files, spikein_scales):
        logger.info("Normalizing %s" % bam)
        # Name the output after the bam file itself (folder and extension dropped), so the
        # result always lands in output_dir, also when the bam path is absolute.
        sample_name = os.path.splitext(os.path.basename(bam))[0]
        # NOTE(review): `bam_to_bigwig` is neither defined nor imported in the visible part
        # of this script -- confirm where it is supposed to come from (the `visualizer`
        # module loaded at import time looks like a candidate). An `rc` entry in kwargs is
        # not forwarded here either -- confirm whether that is intended.
        bam_to_bigwig(bam, exp_type=exp_type, output_pref=os.path.join(output_dir, sample_name),
                      normalization_factor=scale, chromosome_startswith=chromosome_startswith)
        logger.info("Done: normalizing %s" % bam)


def RPM_normalization(bam_files, exp_type, output_dir, rc=False, **kwargs):
    """
    RPM normalization

    Every bam file is handed to ``bam_to_bigwig`` once, with ``rpm=True`` and a
    normalization factor of 1.

    Parameters
    ----------
    bam_files: list or tuple
        a list of paths of bam files
    exp_type: str
        experiment type, forwarded unchanged to ``bam_to_bigwig``
    output_dir: str
        output will be written to this folder
    rc: bool
        whether the parser needs to calc reverse complement
    kwargs
        extra keyword arguments, forwarded unchanged to ``bam_to_bigwig``

    Returns
    -------
    None
    """
    for bam in bam_files:
        logger.info("Working on %s" % bam)
        # One conversion per bam file; a second identical call would only redo the same
        # work with the same arguments.
        # NOTE(review): `bam_to_bigwig` is neither defined nor imported in the visible part
        # of this script -- confirm where it is supposed to come from (the `visualizer`
        # module loaded at import time looks like a candidate).
        bam_to_bigwig(bam, exp_type=exp_type, output_dir=output_dir, normalization_factor=1, rpm=True, rc=rc, **kwargs)


if __name__ == "__main__":
    # Command line entry point: parse the options, attach a log file in the output folder,
    # check the inputs, then dispatch to the requested normalization method.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-m", "--method", help="Normalization method, acceptable values: RPM or Spikein")
    parser.add_argument("-b", "--bam", nargs="+", required=True)
    parser.add_argument("-f", "--fasta-spike-in")
    parser.add_argument("-e", "--exp-type")
    parser.add_argument("-o", "--output-dir", default=".",
                        help="Folder for the outputs and the log file")
    parser.add_argument("-c", "--cell-numbers", nargs="+", required=False, type=int)
    parser.add_argument("-r", "--reverse-complement", action="store_true", dest="seq_reverse_complement",
                        required=False, default=False,
                        help="Set this switch if reads in this library represent the reverse complement of nascent "
                             "RNAs, like PROseq")
    parser.add_argument("--mapq-threshold", action="store", dest="mapq_threshold",
                        type=int, required=False, default=30, help="Minimum mapping quality")
    parser.add_argument("--chromosome-start-with", action="store", dest="chromosome_startswith",
                        type=str, required=False, default="chr",
                        help="Only keep reads mapped to chromosomes with this prefix")
    parser.add_argument("-v", "--version", action="version", version=__version__)
    args = parser.parse_args()

    # The spike-in fasta is optional for the parser, but the Spikein method cannot run
    # without it; report that as a usage error.
    if args.method == "Spikein" and args.fasta_spike_in is None:
        parser.error("--fasta-spike-in is required when --method is Spikein")

    # The log file is created inside the output folder, so the folder has to exist first
    os.makedirs(args.output_dir, exist_ok=True)
    # Timestamp layout: year_month_day_hour_minute_second (%m is the month, %M the minute)
    DEFAULT_PREFIX = "normalizing_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "{0}.log".format(DEFAULT_PREFIX))))
    logger.info("PINTS version: {0}".format(__version__))
    logger.info("Command: {0}".format(" ".join(sys.argv)))

    for bam in args.bam:
        log_assert(os.path.exists(bam), "Cannot access bam file %s" % bam, logger)
    logger.info("Total samples to be normalized: %d" % len(args.bam))

    if args.method == "RPM":
        RPM_normalization(bam_files=args.bam, exp_type=args.exp_type, output_dir=args.output_dir,
                          rc=args.seq_reverse_complement, mapq_threshold=args.mapq_threshold,
                          chromosome_startswith=args.chromosome_startswith)
    elif args.method == "Spikein":
        spikein_normalization(bam_files=args.bam, spike_ins=args.fasta_spike_in, exp_type=args.exp_type,
                              output_dir=args.output_dir, cell_nums=args.cell_numbers,
                              mapq_threshold=args.mapq_threshold, chromosome_startswith=args.chromosome_startswith,
                              rc=args.seq_reverse_complement)
    else:
        logger.error("Not supported normalization method")
        # Nothing was normalized, so report the failure through the exit status as well
        sys.exit(1)
