#!/usr/bin/env python

"""For a given external trigger (GRB, FRB, neutrino, etc...), generate a sky
grid covering its localization error region.

This sky grid will be used by `pycbc_multi_inspiral` to find multi-detector
gravitational wave triggers and calculate the coherent SNRs and related
statistics. 

The grid is constructed following the method described in Section V of
https://arxiv.org/abs/1410.6042.

Please refer to help(pycbc.types.angle_as_radians) for the recommended
configuration file syntax for angle arguments.
"""

import numpy as np
import argparse
import itertools
from scipy.spatial.transform import Rotation as R

import pycbc
from pycbc.detector import Detector
from pycbc.types import angle_as_radians
from pycbc.tmpltbank.sky_grid import SkyGrid


def spher_to_cart(sky_points):
    """Convert spherical coordinates to cartesian coordinates.

    Parameters
    ----------
    sky_points : numpy.ndarray
        Array of shape (N, 2). Column 0 is the azimuthal angle (right
        ascension) and column 1 is the elevation above the equator
        (declination), both in radians.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, 3) holding the x, y, z components of the unit
        vector pointing at each input position.
    """
    longitude = sky_points[:, 0]
    latitude = sky_points[:, 1]
    # cos(latitude) is the length of the projection on the equatorial plane
    cos_latitude = np.cos(latitude)

    xyz = np.empty((len(sky_points), 3))
    xyz[:, 0] = np.cos(longitude) * cos_latitude
    xyz[:, 1] = np.sin(longitude) * cos_latitude
    xyz[:, 2] = np.sin(latitude)
    return xyz


def cart_to_spher(sky_points):
    """Convert cartesian coordinates to spherical coordinates.

    Parameters
    ----------
    sky_points : numpy.ndarray
        Array of shape (N, 3) holding the x, y, z components of unit
        vectors.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, 2). Column 0 is the azimuthal angle, wrapped to
        be non-negative, and column 1 is the elevation above the equator
        in [-pi/2, pi/2], both in radians.
    """
    spher = np.zeros((len(sky_points), 2))
    spher[:, 0] = np.arctan2(sky_points[:, 1], sky_points[:, 0])
    # Floating point rounding (e.g. after a rotation) can push the z
    # component of a unit vector marginally outside [-1, 1], which would make
    # arcsin return NaN; clip it back onto the valid domain first.
    spher[:, 1] = np.arcsin(np.clip(sky_points[:, 2], -1.0, 1.0))
    # arctan2 returns angles in (-pi, pi]; shift the negative ones by 2 pi
    spher[spher[:, 0] < 0, 0] += 2 * np.pi
    return spher

def make_single_det_grid(args):
    """Build the sky grid used when only one detector is analyzed.

    A network made of a single detector cannot localize a source on the
    sky, so the grid contains just one point, placed at the center of the
    input distribution.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line options. The attributes read are `ra`, `dec`,
        `instruments` and `trigger_time`.

    Returns
    -------
    SkyGrid
        Grid constructed from the single (ra, dec) position.
    """
    return SkyGrid(
        [args.ra], [args.dec], args.instruments, args.trigger_time
    )

def make_multi_det_grid(args):
    """Compute a sky grid using the time delays among the network detectors.

    A pattern of concentric rings of points is laid out around the North
    pole, with a spacing set by the requested timing uncertainty, and the
    whole pattern is then rotated onto the external trigger position.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line options. The attributes read are `instruments`,
        `ra`, `dec`, `trigger_time`, `timing_uncertainty` and `sky_error`.
        Note that `instruments` is sorted in place.

    Returns
    -------
    SkyGrid
        Grid constructed from the right ascensions and declinations of the
        rotated points, the sorted instrument list and the trigger time.
    """
    args.instruments.sort()  # Put the ifos in alphabetical order
    detectors = [Detector(ifo) for ifo in args.instruments]
    detector_pairs = list(itertools.combinations(detectors, 2))

    # Calculate the time delay for each detector pair
    tds = [
        det_a.time_delay_from_detector(
            det_b, args.ra, args.dec, args.trigger_time
        )
        for det_a, det_b in detector_pairs
    ]

    # Calculate the light travel time between the detector pairs
    light_travel_times = [
        det_a.light_travel_time_to_detector(det_b)
        for det_a, det_b in detector_pairs
    ]

    # Calculate the required angular spacing between the sky points for each
    # pair; the smallest one over all pairs is used for the whole grid
    ang_spacings = [
        (2 * args.timing_uncertainty) / np.sqrt(ltt ** 2 - td ** 2)
        for ltt, td in zip(light_travel_times, tds)
    ]
    angular_spacing = min(ang_spacings)

    number_of_rings = int(args.sky_error / angular_spacing)

    # Generate the sky grid centered at the North pole: one point at the pole
    # itself, then ring i at an angular distance i * angular_spacing from it,
    # holding int(2 * pi * i) points separated by 1 / i rad in azimuth.
    # Each ring is built as a whole and everything is stacked once at the
    # end, rather than growing the array one row at a time.
    rings = [np.array([[0.0, np.pi / 2]])]
    for i in range(1, number_of_rings + 1):
        number_of_points = int(2 * np.pi * i)
        ring = np.empty((number_of_points, 2))
        ring[:, 0] = np.arange(number_of_points) / i
        ring[:, 1] = np.pi / 2 - i * angular_spacing
        rings.append(ring)
    sky_points = np.vstack(rings)

    # Convert spherical coordinates to cartesian coordinates
    cart = spher_to_cart(sky_points)
    grb_cart = spher_to_cart(np.array([[args.ra, args.dec]], dtype=float))

    north_pole = [0, 0, 1]

    # Rotation vector taking the North pole onto the trigger position: its
    # axis is perpendicular to both directions and its (negative) angle is
    # the angular separation between them.
    axis = np.cross(grb_cart, north_pole)
    axis /= np.linalg.norm(axis)
    angle = -np.arccos(np.dot(grb_cart, north_pole))
    rotation = R.from_rotvec(axis * angle)

    # Rotate the sky grid to the center of the external trigger error box
    rotated = rotation.apply(cart)

    # Convert cartesian coordinates back to spherical coordinates
    spher = cart_to_spher(rotated)

    return SkyGrid(
        spher[:, 0], spher[:, 1], args.instruments, args.trigger_time
    )

# Build the command line interface; the module docstring doubles as the
# program description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument(
    '--ra',
    required=True,
    type=angle_as_radians,
    help=(
        "Right ascension of the center of the external trigger error box. "
        "Use the rad or deg suffix to specify units, otherwise radians "
        "are assumed."
    ),
)
parser.add_argument(
    '--dec',
    required=True,
    type=angle_as_radians,
    help=(
        "Declination of the center of the external trigger error box. "
        "Use the rad or deg suffix to specify units, otherwise radians "
        "are assumed."
    ),
)
parser.add_argument(
    '--instruments',
    required=True,
    type=str,
    nargs="+",
    help="List of instruments to analyze.",
)
parser.add_argument(
    '--sky-error',
    required=True,
    type=angle_as_radians,
    help=(
        "3-sigma confidence radius of the external trigger error box. "
        "Use the rad or deg suffix to specify units, otherwise radians "
        "are assumed."
    ),
)
parser.add_argument(
    '--trigger-time',
    required=True,
    type=int,
    help="Time (in s) of the external trigger",
)
parser.add_argument(
    '--timing-uncertainty',
    default=0.0001,
    type=float,
    help="Timing uncertainty (in s) we are willing to accept",
)
parser.add_argument(
    '--output',
    required=True,
    type=str,
    help="Name of the sky grid",
)

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# A single instrument gets a one-point grid; any larger network gets a grid
# built from the time delays among all of its detectors.
sky_grid = (
    make_single_det_grid(args)
    if len(args.instruments) == 1
    else make_multi_det_grid(args)
)

# Keep the inputs that defined the grid alongside the grid itself.
extra_attributes = dict(
    trigger_ra=args.ra,
    trigger_dec=args.dec,
    sky_error=args.sky_error,
    timing_uncertainty=args.timing_uncertainty,
)

sky_grid.write_to_file(args.output, extra_attributes)
