#!python
# coding=utf-8

#  PINTS: Peak Identifier for Nascent Transcript Starts
#  Copyright (c) 2019-2025 Yu Lab.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

import argparse
import datetime
import logging
import os
import sys

from pints import __version__

# Configure the root logger once, at import time: timestamped records are sent to a
# StreamHandler (which writes to sys.stderr when constructed without arguments).
logging.basicConfig(format="%(name)s - %(asctime)s - %(levelname)s: %(message)s",
                    datefmt="%d-%b-%y %H:%M:%S",
                    level=logging.INFO,
                    handlers=[
                        logging.StreamHandler()
                    ])
# Module-wide logger used by every function in this file.
logger = logging.getLogger("PINTS - Visualizer")

# Third-party / project dependencies are imported inside a guard so that a missing
# package produces a short log message and a non-zero exit instead of a traceback.
try:
    import pyBigWig
    import numpy as np
    from pints.io_engine import get_read_signal
except ImportError as e:
    # Strip the "No module named '...'" wrapper from the exception text so that only
    # the package name is left in the message.
    missing_package = str(e).replace("No module named '", "").replace("'", "")
    logger.error("Please install %s first!" % missing_package)
    sys.exit(-1)


def _strand_atom(cov_dict, normalization_factor, strand_sign):
    """
    Collect the per-base signal of one strand from cached coverage arrays

    Parameters
    ----------
    cov_dict: dict
        Maps each chromosome name to the path of a file readable by ``numpy.load``
        that holds the coverage array of that chromosome
    normalization_factor: float
        Every reported score is multiplied by this factor
    strand_sign: str
        "+" or "-"; scores are negated when it is "-"

    Returns
    -------
    signal: list
        Four parallel lists ``[chroms, starts, ends, scores]``. There is one entry per
        position whose coverage is > 0, reported as the 1-bp interval
        ``[position, position + 1)``. Starts/ends are Python ints, scores Python floats.
    chromosome_sizes: list
        Chromosome sizes in list [(chrom, size), ...], where size is the length of the
        loaded coverage array
    total_signal: int
        Sum of all raw (un-normalized) coverage values on this strand; 0 if cov_dict is empty
    """
    total_signal = 0
    chrom_sizes = []
    chroms = []
    starts = []
    ends = []
    values = []
    negate = strand_sign == "-"
    for chrom, coverage_path in cov_dict.items():
        logger.info("Handling %s (%s)", chrom, strand_sign)
        coverage = np.load(coverage_path)
        total_signal += coverage.sum()
        chrom_sizes.append((chrom, coverage.shape[0]))
        nonzero_loci = np.where(coverage > 0)[0]
        n_loci = nonzero_loci.shape[0]
        logger.info("%d nonzero loci", n_loci)
        # Work on whole arrays instead of appending locus by locus. Casting to float64
        # first keeps the scores floating point regardless of the stored dtype, so the
        # negation below cannot overflow for unsigned coverage arrays.
        scores = coverage[nonzero_loci].astype(np.float64) * normalization_factor
        if negate:
            scores = scores * -1
        chroms.extend([chrom] * n_loci)
        starts.extend(nonzero_loci.tolist())
        ends.extend((nonzero_loci + 1).tolist())
        values.extend(scores.tolist())
        logger.info("Finished %s (%s)", chrom, strand_sign)
    return [chroms, starts, ends, values], chrom_sizes, total_signal


def _ranges_to_bw(chrom_sizes, ranges, save_to):
    """
    Write scored intervals to a bigwig file

    Parameters
    ----------
    chrom_sizes : tuple/list of tuples, ((chrom_name, size), )
        Used as the header of the bigwig file
    ranges : list of lists, [[chroms], [starts], [ends], [scores]]
        Four parallel sequences describing the intervals to be written
    save_to : str
        Path of the bigwig file to be created

    Returns
    -------

    """
    writer = pyBigWig.open(save_to, "w")
    writer.addHeader(chrom_sizes)

    # Coerce coordinates and scores to built-in int / float before handing them over
    interval_starts = [int(pos) for pos in ranges[1]]
    interval_ends = [int(pos) for pos in ranges[2]]
    interval_scores = [float(score) for score in ranges[3]]
    writer.addEntries(ranges[0], interval_starts, ends=interval_ends, values=interval_scores)
    writer.close()


def coverage_dict_to_bigwig(pl_dict, mn_dict, output_pref, normalization_factor=1., rpm_normalization=False):
    """
    Convert coverage dictionary to bigwig files

    Parameters
    ----------
    pl_dict : dict
        Dictionary containing coverage files for all chromosomes on forward strand
    mn_dict : dict
        Dictionary containing coverage files for all chromosomes on reverse strand
    output_pref : str
        Output will be written to output_pref+_pl.bw and output_pref+_mn.bw
    normalization_factor : float
        Normalization factor applied to every score, by default, 1.0
    rpm_normalization : bool
        If True, scores are additionally multiplied by 1e6 / (total signal on both strands).
        This function applies the RPM scale on top of normalization_factor, so pass
        normalization_factor=1 if the two should not be combined. RPM scaling is skipped
        (with a warning) when the total signal is not positive.

    Returns
    -------

    """
    logger.info("Generating signal tracks for forward strand")
    pl_ranges, pl_csize, pl_sum = _strand_atom(pl_dict, normalization_factor=normalization_factor,
                                               strand_sign="+")
    logger.info("Done: Generating signal tracks for forward strand")

    logger.info("Generating signal tracks for reverse strand")
    mn_ranges, mn_csize, mn_sum = _strand_atom(mn_dict, normalization_factor=normalization_factor,
                                               strand_sign="-")
    logger.info("Done: Generating signal tracks for reverse strand")

    # (label used in log messages, file suffix, chromosome sizes, ranges) for each strand
    strands = (
        ("forward", "_pl.bw", pl_csize, pl_ranges),
        ("reverse", "_mn.bw", mn_csize, mn_ranges),
    )

    if rpm_normalization:
        total_signal = pl_sum + mn_sum
        logger.info("%d reads passed filter", total_signal)
        if total_signal > 0:
            rpm_scale = 1000 * 1000 / total_signal
            for label, _, _, ranges in strands:
                logger.info("RPM normalizing (%s strand) with scale factor %.3f", label, rpm_scale)
                ranges[3] = (np.asarray(ranges[3]) * rpm_scale).tolist()
                logger.info("Done: RPM normalizing (%s strand)", label)
        else:
            # Dividing by a zero total would either raise ZeroDivisionError or yield an
            # infinite scale factor; there is nothing to scale in that case anyway.
            logger.warning("No signal on either strand, RPM normalization is skipped")

    for label, suffix, csize, ranges in strands:
        logger.info("Generating bigwig file for %s strand", label)
        bw_fn = output_pref + suffix
        # A strand without any signal is reported instead of being written
        if len(ranges[0]) > 0:
            _ranges_to_bw(csize, ranges, bw_fn)
            logger.info("Done: Bigwig file for %s strand (%s)", label, bw_fn)
        else:
            logger.error("Failed to generate bigwig for %s strand (%s), no signal", label, bw_fn)


def bam_to_bigwig(bam_file, exp_type, output_pref, normalization_factor, rpm=False, chromosome_startswith=None,
                  rc=False, cache=False, filters=None, **kwargs):
    """
    Convert bam file to bigwig file

    Parameters
    ----------
    bam_file: str
        path to the bam file
    exp_type: str
        Passed to ``get_read_signal`` as ``loc_prime``
    output_pref: str
        Prefix to write output
    normalization_factor: float
        Factor that every score is multiplied by; reset to 1 (with a warning) when rpm is set
    rpm: bool
        RPM normalization or not, if set to True, then normalization_factor will be fixed at 1
    chromosome_startswith: None or str
        Only consider chroms start with chromosome_startswith
    rc: bool
        Reverse complementary, default False
    cache : bool
        True if you want to keep genome coverage data from parsed bam, default False
    filters : list, tuple or None
        List of keywords to filter chromosomes; None (default) means no filters
    kwargs
        Forwarded unchanged to ``get_read_signal``

    Returns
    -------

    """
    if rpm and normalization_factor != 1:
        logger.warning("RPM switch is set as True, while normalization_factor is not 1, "
                       "normalization_factor will be ignored.")
        normalization_factor = 1
    # Intermediate coverage files are placed next to the bam file and named after it.
    output_dir = os.path.dirname(bam_file)
    # Drop only the trailing extension; a plain str.replace would also remove the same
    # text anywhere else in the file name.
    output_prefix = os.path.splitext(os.path.basename(bam_file))[0]
    logger.info("Start to parse bam file %s", bam_file)
    filters = filters if filters is not None else set()
    chromosome_coverage_pl, chromosome_coverage_mn, _ = get_read_signal(input_bam=bam_file,
                                                                        loc_prime=exp_type,
                                                                        chromosome_startswith=chromosome_startswith,
                                                                        reverse_complement=rc,
                                                                        output_dir=output_dir,
                                                                        output_prefix=output_prefix,
                                                                        filters=filters,
                                                                        **kwargs)

    logger.info("Bam file parsed")

    coverage_dict_to_bigwig(pl_dict=chromosome_coverage_pl, mn_dict=chromosome_coverage_mn, output_pref=output_pref,
                            normalization_factor=normalization_factor, rpm_normalization=rpm)

    if not cache:
        # Best-effort removal of the per-chromosome coverage files
        housekeeping_files = list(chromosome_coverage_pl.values()) + list(chromosome_coverage_mn.values())
        for hf in housekeeping_files:
            try:
                os.remove(hf)
            except FileNotFoundError:
                # Already gone, nothing to clean up
                pass
            except OSError as err:
                logger.warning("Cannot remove intermediate file %s: %s", hf, err)


def _default_log_prefix(now=None):
    """
    Build the prefix of the log file name, ``visualizing_<year>_<month>_<day>_<hour>_<minute>_<second>``

    Parameters
    ----------
    now : datetime.datetime or None
        Timestamp to be encoded; the current local time is used when None

    Returns
    -------
    prefix : str
        Prefix for the log file
    """
    if now is None:
        now = datetime.datetime.now()
    # %m is the month; %M (minute) only belongs in the time-of-day part
    return "visualizing_" + now.strftime("%Y_%m_%d_%H_%M_%S")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-b", "--bam", action="store", required=True)
    parser.add_argument("-e", "--exp-type", action="store", default=None, dest="bam_parser",
                        help="Type of experiment, acceptable values are: CoPRO/GROcap/GROseq/PROcap/PROseq, or if you "
                             "know the position of RNA ends which you\'re interested on the reads, you can specify "
                             "R_5, R_3, R1_5, R1_3, R2_5 or R2_3")
    parser.add_argument("-r", "--reverse-complement", action="store_true", dest="seq_reverse_complement",
                        required=False, default=False,
                        help="Set this switch if reads in this library represent the reverse complement of nascent "
                             "RNAs, like PROseq")
    parser.add_argument("-c", "--rpm", action="store_true", dest="rpm",
                        required=False, default=False,
                        help="Set this switch if you want to use RPM to normalize the outputs")
    parser.add_argument("--mapq-threshold", action="store", dest="mapq_threshold",
                        type=int, required=False, default=30, help="Minimum mapping quality")
    parser.add_argument("--chromosome-start-with", action="store", dest="chromosome_startswith",
                        type=str, required=False, default="chr",
                        help="Only keep reads mapped to chromosomes with this prefix")
    # The prefix is concatenated with "_pl.bw" / "_mn.bw" when the tracks are written, so
    # it must be a string; requiring it here fails fast instead of after the bam is parsed.
    parser.add_argument("-o", "--output-prefix", action="store", type=str, required=True)
    parser.add_argument("-f", "--filters", action="store", type=str, nargs="*",
                        help="reads from chromosomes whose names contain any matches in filters will be ignored")
    parser.add_argument("-n", "--norm-fact", action="store", default=1., type=float,
                        help="Normalization factor, can be generated by normalizer.py")
    parser.add_argument("-s", "--cache", action="store_true", dest="cache",
                        required=False, default=False,
                        help="Set this switch if you want to reuse the intermediate .npy files.")
    parser.add_argument("-v", "--version", action="version", version=__version__)

    args = parser.parse_args()

    # Messages from this module are additionally written to a time-stamped log file
    # in the current working directory.
    DEFAULT_PREFIX = _default_log_prefix()
    logger.addHandler(logging.FileHandler("{0}.log".format(DEFAULT_PREFIX)))
    logger.info("PINTS version: {0}".format(__version__))
    logger.info("Command: {0}".format(" ".join(sys.argv)))

    bam_to_bigwig(bam_file=args.bam, exp_type=args.bam_parser, output_pref=args.output_prefix,
                  normalization_factor=args.norm_fact, mapq_threshold=args.mapq_threshold,
                  chromosome_startswith=args.chromosome_startswith, rpm=args.rpm,
                  filters=args.filters, cache=args.cache, rc=args.seq_reverse_complement)
