#!/home/sam/anaconda3/bin/python
import matplotlib.pyplot as plt 
import os,sys 
import matplotlib.gridspec as gridspec
import argparse 
import numpy as np
import warnings
warnings.filterwarnings("ignore")

import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astroplan import FixedTarget, Observer, EclipsingSystem
from astroplan import (PrimaryEclipseConstraint, is_event_observable, AtNightConstraint, AltitudeConstraint, LocalTimeConstraint)
from astroplan.plots import dark_style_sheet, plot_airmass, plot_sky

description = '''Transit prediction'''

# Argument parser. The help texts describe how main() actually interprets each
# value: t_zero is converted with Time(..., format='jd'), period is multiplied
# by u.day and width by u.hour.
parser = argparse.ArgumentParser('predict', description=description)

parser.add_argument('-a',
                    '--t_zero',
                    help='The reference mid-transit epoch as a Julian Date (JD).',
                    default=2457500.60872, type=float)

parser.add_argument('-b',
                    '--period',
                    help='The orbital period in days.',
                    default=4.70744, type=float)

parser.add_argument('-c',
                    '--width',
                    help='The total transit duration (ingress to egress) in hours.',
                    default=5.8, type=float)

parser.add_argument('-d',
                    '--ntransits',
                    help='The number of transits to predict.',
                    default=10, type=int)

parser.add_argument('-e',
                    '--ra',
                    help='The ICRS right ascension in degrees.',
                    default=205.96079385700, type=float)

parser.add_argument('-f',
                    '--dec',
                    help='The ICRS declination in degrees.',
                    default=-12.35998970810, type=float)

parser.add_argument('--plot', action="store_true", default=False,
                    help="Save an airmass plot (PNG) for each predicted epoch")
parser.add_argument('--complete', action="store_true", default=False,
                    help="Only report transits observable in full, from ingress to egress")


parser.add_argument('-g',
                    '--observatory',
                    help='The observatory site name, as accepted by astroplan Observer.at_site.',
                    default='SAAO')

parser.add_argument('-j',
                    '--name',
                    help='The target name.',
                    default='Star 1')

# The date is interpreted as 12:00 on that day (see main()).
parser.add_argument('-k',
                    '--date',
                    help='The date from which to calculate. If not supplied, will default to now. Should be supplied as "2017-01-01"',
                    default='now')

def _find_observable_transits(eclipsetarget, observatory, fixedtarget, constraints,
                              obs_time, ntransits, complete, max_doublings=10):
    """Predict upcoming transits and keep those that satisfy the constraints.

    Transits are predicted in a batch starting at ``ntransits`` eclipses. When too
    few pass the observability test, the batch size is doubled, up to
    ``max_doublings`` times. This stops the search from running forever for a
    target that never meets the constraints.

    Parameters
    ----------
    eclipsetarget : astroplan.EclipsingSystem
    observatory : astroplan.Observer
    fixedtarget : astroplan.FixedTarget
    constraints : list
        astroplan constraints passed to ``is_event_observable``.
    obs_time : astropy.time.Time
        Time from which to start predicting.
    ntransits : int
        Minimum number of observable transits required.
    complete : bool
        If True, keep only transits observable from ingress to egress.
        Otherwise keep transits whose mid-transit time is observable.
    max_doublings : int
        How many times the batch size may be doubled before giving up.

    Returns
    -------
    midtransit_times : numpy.ndarray
        Kept mid-transit times (at least ``ntransits`` of them).
    ingressegress : astropy.time.Time
        Matching (ingress, egress) pairs for the kept transits.
    complete_mask : numpy.ndarray
        For each kept transit, whether ingress through egress is observable.

    Raises
    ------
    RuntimeError
        If fewer than ``ntransits`` observable transits are found.
    """
    n_eclipses = ntransits
    for _ in range(max_doublings + 1):
        ingressegress = eclipsetarget.next_primary_ingress_egress_time(obs_time, n_eclipses=n_eclipses)
        midtransit_times = np.array(eclipsetarget.next_primary_eclipse_time(obs_time, n_eclipses=n_eclipses))

        # [0] selects the result row for our single target.
        mid_observable = is_event_observable(constraints, observatory, fixedtarget, times=midtransit_times)[0]
        fully_observable = is_event_observable(constraints, observatory, fixedtarget,
                                               times_ingress_egress=ingressegress)[0]

        keep = fully_observable if complete else mid_observable
        if np.count_nonzero(keep) >= ntransits:
            return midtransit_times[keep], ingressegress[keep], fully_observable[keep]

        # Not enough observable epochs yet: look twice as far ahead.
        n_eclipses *= 2

    raise RuntimeError('fewer than {} observable transits found within the next {} eclipses; '
                       'check the coordinates, observatory and constraints'.format(ntransits, n_eclipses // 2))


def _print_epoch_summary(name, observatory, skycoord, midtransit_times, ingressegress,
                         complete_mask, sun_set_times, sun_rise_times):
    """Print the per-epoch table of sunset, ingress, mid-transit, egress and sunrise.

    One row is printed per entry of ``sun_set_times``. The airmass (sec z) of the
    target at each of the five times is shown in square brackets.
    """
    print('-------------------------------------------------------------------------------------------------------------------------------------------------')
    print('| Summary of Epochs for {:>15}                                                                                                         |'.format(name))
    print('| All times in UTC with airmass given in square brackets                                                                                        |')
    print('|------------------------------------------------------------------------------------------------------------------------------------------------')
    print('|{:>5} |    {:}    |     {:}       |        {:}        |        {:}        |        {:}        |      {:}      |    {:>7}     |'.format('Epoch', 'date', 'sunset', 'in', 'mid', 'out', 'sunrise', 'Complete transit'))
    print('|------|------------|------------------|------------------|-------------------|-------------------|-------------------|-------------------------|')
    # zip stops at the shorter sequence; the sun times hold exactly the epochs to print.
    for i, (mid_time, set_time, rise_time) in enumerate(zip(midtransit_times, sun_set_times, sun_rise_times)):
        ingress_time = Time(ingressegress[i][0], format='jd')
        egress_time = Time(ingressegress[i][1], format='jd')

        mid_datetime = mid_time.datetime
        set_datetime = set_time.datetime
        rise_datetime = rise_time.datetime
        in_datetime = ingress_time.datetime
        out_datetime = egress_time.datetime

        airmasses = [observatory.altaz(t, skycoord).secz
                     for t in (set_time, ingress_time, mid_time, egress_time, rise_time)]

        print('| {:>3}  | {:>4} {:0>2} {:0>2} | {:0>2}:{:0>2}:{:0>2.0f} [{:0>5.2f}] | {:0>2}:{:0>2}:{:0>2.0f} [{:0>5.2f}] | {:0>2}:{:0>2}:{:0>2.0f} [{:0>5.2f}]  | {:0>2}:{:0>2}:{:0>2.0f} [{:0>5.2f}]  | {:0>2}:{:0>2}:{:0>2.0f} [{:0>5.2f}]  |           {:}         |'.format(i+1,
                                 mid_datetime.year, mid_datetime.month, mid_datetime.day,
                                 set_datetime.hour, set_datetime.minute, set_datetime.second, airmasses[0],
                                 in_datetime.hour, in_datetime.minute, in_datetime.second, airmasses[1],
                                 mid_datetime.hour, mid_datetime.minute, mid_datetime.second, airmasses[2],
                                 out_datetime.hour, out_datetime.minute, out_datetime.second, airmasses[3],
                                 rise_datetime.hour, rise_datetime.minute, rise_datetime.second, airmasses[4],
                                 complete_mask[i]
                                 ))


def _print_target_rise_set(observatory, fixedtarget, midtransit_times, ntransits):
    """Print the target's rise time (before) and set time (after) each mid-transit.

    A failure for one epoch prints a placeholder row instead of silently
    truncating the rest of the table.
    """
    print('|-----------------------------------------------------------------------------------------------------------------------------------------------|')
    print('| Target set and rise times for each epoch                                                                                                      |')
    print('|-----------------------------------------------------------------------------------------------------------------------------------------------|')
    print('|{:} |   {:}   |    {:}   |                                                                                                                  |'.format('Epoch', 'rise', 'set'))
    print('|-----------------------------------------------------------------------------------------------------------------------------------------------|')
    for i in range(ntransits):
        try:
            rise_time = observatory.target_rise_time(midtransit_times[i], fixedtarget, which='previous').datetime
            set_time = observatory.target_set_time(midtransit_times[i], fixedtarget, which='next').datetime
            row = '| {:>3}  | {:0>2}:{:0>2}:{:0>2.0f} | {:0>2}:{:0>2}:{:0>2.0f} |                                                                                                                  |'.format(i+1,
                    rise_time.hour, rise_time.minute, rise_time.second,
                    set_time.hour, set_time.minute, set_time.second)
        except Exception:
            # Best-effort column. NOTE(review): rise/set may be undefined (e.g. a
            # target that never rises or never sets); the exact exception
            # astroplan/astropy raise in that case is not visible here.
            row = '| {:>3}  | {:>8} | {:>8} |                                                                                                                  |'.format(i+1, '--:--:--', '--:--:--')
        print(row)
    print('-------------------------------------------------------------------------------------------------------------------------------------------------')


def _plot_epochs(name, width, observatory, fixedtarget, midtransit_times, ntransits):
    """Save one airmass plot per epoch as '<name>_Epoch_<n>.png'.

    The in-transit window (mid-transit +/- half the duration) is hatched in blue.
    """
    half_duration = width * u.hour / 2
    for i in range(ntransits):
        fig = plt.figure()
        ax1 = plt.gca()

        plot_airmass(fixedtarget, observatory, midtransit_times[i], style_sheet=dark_style_sheet,
                     ax=ax1, brightness_shading=True, altitude_yaxis=True)

        ax1.fill_between([(midtransit_times[i] - half_duration).datetime,
                          (midtransit_times[i] + half_duration).datetime],
                         [3, 3], color='blue', alpha=0.2, hatch="X", edgecolor="b")

        ax1.set_title('Gray is nightitme\nBlue hatch is in transit')

        plt.savefig('{:}_Epoch_{:}.png'.format(name, i+1))
        plt.close(fig)


def main(argv=None):
    """Predict observable transits for the command-line target and report them.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` means ``sys.argv[1:]``, as argparse does.
    """
    args = parser.parse_args(argv)
    if args.ntransits < 1:
        parser.error('--ntransits must be at least 1')

    observatory = Observer.at_site(args.observatory)

    skycoord = SkyCoord(args.ra, args.dec, frame='icrs', unit='deg')
    fixedtarget = FixedTarget(skycoord, name=args.name)

    eclipsetarget = EclipsingSystem(primary_eclipse_time=Time(args.t_zero, format='jd'),
                                    orbital_period=args.period * u.day,
                                    duration=args.width * u.hour,
                                    name=args.name)

    # Start time of the prediction: now, or 12:00 on the requested date.
    if args.date == 'now':
        obs_time = Time.now()
    else:
        obs_time = Time(args.date + ' 12:00')

    constraints = [AtNightConstraint.twilight_civil(), AltitudeConstraint(min=30*u.deg)]

    try:
        midtransit_times, ingressegress, complete_mask = _find_observable_transits(
            eclipsetarget, observatory, fixedtarget, constraints,
            obs_time, args.ntransits, args.complete)
    except RuntimeError as err:
        sys.exit('predict: {}'.format(err))

    # Only the first ntransits epochs are reported, so only compute those.
    reported = midtransit_times[:args.ntransits]
    sun_set_times = [observatory.sun_set_time(t, which="previous") for t in reported]
    sun_rise_times = [observatory.sun_rise_time(t, which="next") for t in reported]

    _print_epoch_summary(args.name, observatory, skycoord, midtransit_times, ingressegress,
                         complete_mask, sun_set_times, sun_rise_times)
    _print_target_rise_set(observatory, fixedtarget, midtransit_times, args.ntransits)

    if args.plot:
        _plot_epochs(args.name, args.width, observatory, fixedtarget, midtransit_times, args.ntransits)


if __name__ == '__main__':
    main()