#!python
# coding=utf-8

#  PINTS: Peak Identifier for Nascent Transcript Starts
#  Copyright (c) 2019-2025 Yu Lab.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

import argparse
import os
import pyBigWig
import pybedtools
import numpy as np
import pandas as pd


# Usage example that is handed to argparse as the help epilog when this file
# is run as a script.
# NOTE: this is a regular (non-raw) string literal, so every backslash that is
# immediately followed by a newline acts as a line continuation and is dropped
# from the stored text; the command below is therefore kept as a single line.
__EXAMPLE__ = """
Example: 
pints_de -b C1_1_bidirectional_peaks.bed C2_1_bidirectional_peaks.bed \
    -p C1_R1_pl.bw C1_R2_pl.bw C2_R1_pl.bw C2_R2_pl.bw \
    -m C1_R1_mn.bw C1_R2_mn.bw C2_R1_mn.bw C2_R2_mn.bw \
    -c C1 C1 C2 C2 -r 1 2 1 2 -s c1c2.csv -d 
"""


def prepare_regions(primary_regions, additional_regions=None, max_dist=0):
    """
    Build the union of regions from the primary and additional region files.

    All input files are concatenated, sorted, and merged so that features
    lying within ``max_dist`` of each other collapse into one region.

    Parameters
    ----------
    primary_regions : Sequence[str]
        A list/tuple of primary regions (e.g. bidirectional peaks)
    additional_regions : Optional[Sequence[str]]
        A list/tuple of additional regions (e.g. unidirectional peaks)
    max_dist : int
        Maximum distance between features allowed for features to be merged.
        Default is 0 (overlapping & book-ended features are merged.)

    Returns
    -------
    merged_regions : pybedtools.BedTool
        Merged regions

    Raises
    ------
    ValueError
        If no region file is provided at all.
    """
    all_files = list(primary_regions)
    if additional_regions:
        all_files.extend(additional_regions)
    if not all_files:
        raise ValueError("At least one region file is required.")

    bed_tools = [pybedtools.BedTool(f) for f in all_files]
    first, others = bed_tools[0], bed_tools[1:]
    # Concatenation is only meaningful with more than one file; a lone input
    # goes straight to sort/merge instead of being passed to cat() by itself.
    combined = first.cat(*others, postmerge=False) if others else first
    return combined.sort().merge(d=max_dist).saveas()


def get_counts(regions, primary_signal, additional_signal=None):
    """
    Get read counts for regions in certain library

    For every region, the per-base values of each signal file are read, NaN
    values are replaced by 0, absolute values are taken, and everything is
    summed into a single number for that region.

    Parameters
    ----------
    regions : pybedtools.BedTool
        Regions
    primary_signal : str
        Path to a bigWig file storing all primary signals (e.g. signal on the forward strand)
    additional_signal : Optional[str]
        Path to a bigWig file storing all additional signals (e.g. signal on the reverse strand)

    Returns
    -------
    vec_counts : numpy.ndarray
        Integer vector with one entry per region, in the iteration order of
        ``regions``. Sums are stored in an integer array.
    """
    vec_counts = np.zeros(len(regions), dtype=int)
    bws = []
    try:
        bws.append(pyBigWig.open(primary_signal))
        if additional_signal:
            bws.append(pyBigWig.open(additional_signal))

        for i, row in enumerate(regions):
            # The first three fields of a region are used as chrom, start, end.
            chrom, start, end = row[0], int(row[1]), int(row[2])
            vec_counts[i] = np.sum(
                [np.abs(np.nan_to_num(bw.values(chrom, start, end))) for bw in bws]
            )
    finally:
        # Release the file handles even if reading a region fails.
        for bw in bws:
            bw.close()

    return vec_counts


def check_file(in_file):
    """
    Make sure a path exists on the file system.

    Parameters
    ----------
    in_file : str
        Path to check

    Returns
    -------
    None

    Raises
    ------
    IOError
        If ``in_file`` does not exist.
    """
    if os.path.exists(in_file):
        return
    message = "File {} does not exist".format(in_file)
    raise IOError(message)


def get_deseq2_template(counts_file, condition_names):
    """
    Print a boilerplate R script for running DESeq2 on a counts file.

    The generated script rebuilds the sample sheet by splitting the column
    names of the counts table on underscores; the resulting fields are named
    after ``condition_names`` followed by a final ``rep`` column.

    Parameters
    ----------
    counts_file : str
        Path to the counts file
    condition_names : Sequence[str]
        Names of conditions. The sequence is not modified.

    Returns
    -------
    None
        The template is written to standard output.
    """
    design = " + ".join(condition_names)
    # Work on a new list so the caller's sequence is left untouched (and so
    # that tuples are accepted as well).
    column_names = list(condition_names) + ["rep"]
    flatten_condition_names = '"' + '", "'.join(column_names) + '"'
    tpl = """
library(DESeq2)
cts <- read.csv("{counts_file}", row.names=1)
coldata <- as.data.frame(do.call(rbind, strsplit(colnames(cts), "_")))
colnames(coldata) <- c({flatten_condition_names})
rownames(coldata) <- colnames(cts)
dds <- DESeqDataSetFromMatrix(countData = cts,
                              colData = coldata,
                              design = ~ {design})
dds <- DESeq(dds)
res <- results(dds)
norm_counts <- counts(dds, normalized=TRUE)
""".format(
        counts_file=counts_file,
        flatten_condition_names=flatten_condition_names,
        design=design)

    print(tpl)


def main(bidirectional_calls, pl_bws, conditions, replications, save_to,
         unidirectional_calls=None, mn_bws=None, condition_names=None,
         max_dist=0, print_deseq2=True):
    """
    Build a region-by-library count matrix and save it as a CSV file.

    The regions are the merged union of all peak calls; each library
    contributes one column labelled ``<conditions>_<replicate>``.

    Parameters
    ----------
    bidirectional_calls : Sequence[str]
        Bidirectional calls
    pl_bws : Sequence[str]
        Signal files for the forward strand
    conditions : Sequence[str]
        Conditions for each library, e.g. treatment or control.
        If you have multiple conditions, e.g. treatment or control and male or female, concatenate
        conditions with underscore ('_'). So the input should have the following format:
            > treatment_male
            > ...
            > control_female
    replications : Sequence[int]
        Ordinal numbers of replications. For example, if you have two replicates for each condition,
        for the first replicate of cond1_cond2, then put 1 for the library; for the second replicate, put 2.
    save_to : str
        Save the count matrix to a file
    unidirectional_calls : Optional[Sequence[str]]
        Unidirectional calls
    mn_bws : Optional[Sequence[str]]
        Signal files for the reverse strand
    condition_names : Optional[Sequence[str]]
        Names of conditions. Defaults to Cond1, ..., Condn.
    max_dist : int
        Maximum distance between TREs allowed for TREs to be merged.
    print_deseq2 : bool
        Print boilerplate for running DESeq2

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the inputs have inconsistent lengths, libraries differ in their
        number of conditions, two libraries share the same condition/replicate
        label, or (when the DESeq2 template is requested) the number of
        condition names does not match the number of conditions.
    IOError
        If any input file does not exist.
    """
    # sanity check
    if len({len(pl_bws), len(conditions), len(replications)}) != 1:
        raise ValueError("pl_bws, conditions, and replications must have the same length.")

    for f in pl_bws:
        check_file(f)
    for f in bidirectional_calls:
        check_file(f)

    if unidirectional_calls is not None:
        if len(unidirectional_calls) != len(bidirectional_calls):
            raise ValueError("When specified, unidirectional_calls must have the same length as bidirectional_calls.")
        for f in unidirectional_calls:
            check_file(f)
    if mn_bws is not None:
        if len(mn_bws) != len(pl_bws):
            raise ValueError("When specified, mn_bws must have the same length as pl_bws.")
        for f in mn_bws:
            check_file(f)
    per_sample_conditions = [len(c.split("_")) for c in conditions]
    if len(set(per_sample_conditions)) != 1:
        raise ValueError("Libraries must have the same number of conditions. {}".format(per_sample_conditions))

    if condition_names is None:
        condition_names = ["Cond{}".format(i) for i in range(1, per_sample_conditions[0] + 1)]
    elif print_deseq2 and len(condition_names) != per_sample_conditions[0]:
        # The names are only consumed by the DESeq2 template, where a mismatch
        # would yield R code that cannot label the sample sheet columns.
        raise ValueError(
            "Expected {} condition name(s), but got {}.".format(
                per_sample_conditions[0], len(condition_names)))

    labels = ["{}_{}".format(c, r) for c, r in zip(conditions, replications)]
    # Labels are used as dictionary keys / column names below; a repeated label
    # would silently replace the counts of an earlier library.
    seen = set()
    duplicated = set()
    for label in labels:
        if label in seen:
            duplicated.add(label)
        seen.add(label)
    if duplicated:
        raise ValueError(
            "Each library needs a unique combination of conditions and replicate. "
            "Duplicated: {}".format(sorted(duplicated)))

    # actual logic
    union_of_regions = prepare_regions(bidirectional_calls, unidirectional_calls, max_dist=max_dist)
    counts_mat_dict = dict()

    for i, label in enumerate(labels):
        mn_bw = mn_bws[i] if mn_bws else None
        print("Working on {} using {}".format(label, (pl_bws[i], mn_bw)))
        counts_mat_dict[label] = get_counts(union_of_regions, pl_bws[i], mn_bw)

    region_names = ["{}:{}-{}".format(r[0], r[1], r[2]) for r in union_of_regions]
    counts_mat_df = pd.DataFrame(counts_mat_dict, index=region_names)
    counts_mat_df.to_csv(save_to)

    if print_deseq2:
        # Pass a copy so the caller's sequence can never be altered downstream.
        get_deseq2_template(save_to, list(condition_names))


# Command-line entry point: parse the arguments and forward them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(epilog=__EXAMPLE__)
    parser.add_argument("-b", "--bidirectional-calls", nargs="+", type=str, required=True,
                        help="Bidirectional peak calls for each condition")
    parser.add_argument("-u", "--unidirectional-calls", nargs="+", type=str,
                        help="[optional] Unidirectional peak calls for each condition.")
    parser.add_argument("-p", "--pl-bws", nargs="+", type=str, required=True,
                        help="Forward strand signal (in bw file) for each condition, each replicate")
    parser.add_argument("-m", "--mn-bws", nargs="+", type=str,
                        help="[optional] Reverse strand signal (in bw file) for each condition, each replicate.")
    parser.add_argument("-c", "--conditions", nargs="+", type=str, required=True,
                        help="Conditions for each library. If you have replicates for each condition, "
                             "you need to specify them here and in --replicates. For example, if you have 3 replicates "
                             "for condA, then you need to type `condA condA condA` here.")
    parser.add_argument("-n", "--condition-names", nargs="+", type=str, required=False,
                        help="Names of conditions. For example, if you specify conditions as t1_t2 in `--conditions`, "
                             "you can specify names for t1 and t2 here. If not specified, conditions will be named "
                             "as Cond1, ..., Condn.")
    parser.add_argument("-r", "--replicates", nargs="+", type=int, required=True,
                        help="Ordinal numbers of replications. For example, if you have three replicates for condA, "
                             "you need to type `1 2 3` here.")
    parser.add_argument("--max-dist", type=int, required=False, default=0,
                        help="Maximum distance between TREs allowed for TREs to be merged. "
                             "Default 0: overlapping & book-ended features are merged. "
                             "Negative values enforce the number of b.p. required for overlap. "
                             "See bedtools manual for more details.")
    parser.add_argument("-s", "--save-to", type=str, required=True,
                        help="Save the count matrix (CSV) to this file")
    parser.add_argument("-d", "--deseq2", action="store_true",
                        help="Generate R code for using DESeq2 to identify differentially expressed peaks")

    args = parser.parse_args()

    # Intermediate pybedtools files are written to the current directory.
    pybedtools.set_tempdir(".")
    main(
        bidirectional_calls=args.bidirectional_calls,
        pl_bws=args.pl_bws, conditions=args.conditions, condition_names=args.condition_names,
        replications=args.replicates, save_to=args.save_to,
        unidirectional_calls=args.unidirectional_calls, mn_bws=args.mn_bws,
        # Forward --max-dist; otherwise the option would have no effect.
        max_dist=args.max_dist,
        print_deseq2=args.deseq2
    )
