#!/usr/bin/env python3.6
from   __future__ import print_function

# system imports
import os
import re
import sys
import collections
from   operator   import methodcaller
from   functools  import (reduce, partial)

# extensions
import numpy as np
from   pyrap import tables as pt

# own stuff
from jiveplot import (jplotter, command)

# program default(s)
# NoWgt and ScanNo are overridden further down by the command line
# handling ('-weight' clears NoWgt, '-scan #' sets ScanNo).
NoWgt   = True    # do not produce weight plots
ScanNo  = None    # automatic scan selection
# Version string; printed verbatim by the '-v' option
Version = "$Id: standardplots,v 1.1 2014-08-08 15:38:41 jive_cc Exp $"
# Polarization colour scheme as per JIVE standard.
# Single pol data gets coloured black.
# This string is yielded as-is as a plotter command by the per-scan
# channel plot generators.
PolCMap = "ckey p[rr]=2 p[ll]=3 p[rl]=4 p[lr]=5 p[none]=1"

def partition(p, l):
    """Split the items of `l` into two lists, preserving order.

    Returns a tuple (accepted, rejected): `accepted` holds the items for
    which predicate `p` is truthy, `rejected` holds all the others.
    """
    accepted, rejected = [], []
    for x in l:
        # append in place; avoids rebuilding both lists for every item
        (accepted if p(x) else rejected).append(x)
    return (accepted, rejected)


def compose(*fns):
    """Function composition is really great.

    compose(f, g, h)(x) == f(g(h(x))), i.e. the right-most function is
    applied first. With no functions at all the argument is returned as-is.
    """
    def composed(x):
        return reduce(lambda acc, fn: fn(acc), reversed(fns), x)
    return composed


def Map(fn):
    """Curried map: Map(fn)(iterable) is the same as map(fn, iterable)."""
    return partial(map, fn)


# full MS path to plot base name transformation:
# strip trailing slashes, take the basename and remove '.ms' from it
# NOTE(review): the '.ms' pattern is not anchored, so every occurrence of
#               '.ms' in the basename is removed, not just a trailing
#               extension - confirm that this is intended.
mk_basenm = compose(partial(re.sub, r'\.ms', ''), os.path.basename, partial(re.sub, r'/*$', ''))
# CalSrc must be transformed such that it can be fed into "/{0}/i" regex,
# even if CalSrc is comma-separated list of sources.
# So: split CalSrc by ',', escape the individual source names and transform
#     to "(<src>|<src>|....)" as regex alternatives for matching
mk_calsrc = compose("({0})".format, "|".join, Map(re.escape), methodcaller('split', ','))

# We need to capture errors and terminate instead of going on.
# This factory produces error functions that do exactly that; it is meant
# to replace the errorfunction factory from hvutil.
def mkerrf(pfx):
    """Return an error function bound to the prefix `pfx`.

    The returned callable takes the error message as its only argument,
    prints "<pfx> <msg>" and then terminates the program through
    SystemExit with exit code -1.
    """
    def terminate(msg):
        line = "{0} {1}".format(pfx, msg)
        print(line)
        raise SystemExit(-1)
    return terminate
# Monkey-patch: rebind the 'mkerrf' attribute of jplotter's hvutil module to
# our terminating version, so lookups of hvutil.mkerrf from here on get it.
jplotter.hvutil.mkerrf = mkerrf

def get_val(arg, tp=str):
    """Extract option `arg` and its value from sys.argv.

    Returns None if the option isn't present on the command line,
    tp(<value>) if it is. In the latter case both the option and its value
    are removed from sys.argv before the conversion is attempted.

    Raises:
        RuntimeError: the option is the last word on the command line, its
                      value is empty or its value looks like another
                      option (starts with '-').
        ValueError (or whatever `tp` raises): the value could not be
                      converted, e.g. you expect int but the user didn't
                      pass a valid int.
    """
    # Testing for presence explicitly keeps the "option not given" case
    # apart from a ValueError raised by the conversion further down
    if arg not in sys.argv:
        return None
    aidx = sys.argv.index(arg)

    # the option must be followed by a non-empty value ...
    if aidx + 1 >= len(sys.argv) or not sys.argv[aidx + 1]:
        raise RuntimeError("Missing option value to {0}".format(arg))
    aval = sys.argv[aidx + 1]

    # ... which may not start with a '-'!
    if aval.startswith('-'):
        raise RuntimeError("Option {0} expects argument, got another option '{1}'" .format(arg, aval))

    # remove those arguments
    del sys.argv[aidx:aidx + 2]
    return tp(aval)


def chunkert(f, l, cs, verbose=True):
    """Yield (start, count) tuples that cover the range [f, l) in chunks.

    Every chunk holds `cs` items, except the last one which may be shorter.
    Nothing is yielded if f >= l.

    `verbose` is not used; it is kept so existing calls remain valid.

    Raises ValueError (when iteration starts) if `cs` is not positive:
    such a chunk size would never make progress and loop forever.
    """
    if cs <= 0:
        raise ValueError("chunk size must be positive, got {0}".format(cs))
    while f<l:
        n = min(cs, l-f)
        yield (f, n)
        f = f + n

def get_observed_subbands(ms):
    """Given a MS file, it checks which subbands each antenna actually observed.
    Meaning for which subbands each antenna has non-zero data.

    Only the auto-correlation rows of an antenna (ANTENNA1 == ANTENNA2) are
    inspected. The main table is read in chunks of 5000 rows.

    Returns a dictionary with the antenna names as keys,
    and a Python set with the observed subbands (SPECTRAL_WINDOW_ID values)
    as values. Antennas without any non-zero data do not appear as key.
    """
    ants_spws = collections.defaultdict(set)
    with pt.table(ms, readonly=True, ack=False) as mstable:
        with pt.table(mstable.getkeyword('ANTENNA'), readonly=True, ack=False) as ms_ants:
            antenna_names = ms_ants.getcol('NAME')

        with pt.table(mstable.getkeyword('DATA_DESCRIPTION'), readonly=True, ack=False) as ms_spws:
            spw_ids = ms_spws.getcol('SPECTRAL_WINDOW_ID')

        for (start, nrow) in chunkert(0, len(mstable), 5000):
            ants1 = mstable.getcol('ANTENNA1', startrow=start, nrow=nrow)
            ants2 = mstable.getcol('ANTENNA2', startrow=start, nrow=nrow)
            ddids = mstable.getcol('DATA_DESC_ID', startrow=start, nrow=nrow)
            msdata = mstable.getcol('DATA', startrow=start, nrow=nrow)

            # ANTENNA1/ANTENNA2 hold row numbers of the ANTENNA subtable, not
            # names, so the rows must be matched on the antenna's index.
            for (ant_id, antenna) in enumerate(antenna_names):
                auto_rows = (ants1 == ant_id) & (ants2 == ant_id)
                if not auto_rows.any():
                    continue
                # Likewise DATA_DESC_ID is a row number of the
                # DATA_DESCRIPTION subtable; that row tells the subband.
                for (ddid, spw) in enumerate(spw_ids):
                    rows = auto_rows & (ddids == ddid)
                    # Visibilities are complex: judge them by magnitude.
                    # An empty selection has no non-zero data either.
                    if (np.abs(msdata[rows]) > 1e-7).any():
                        ants_spws[antenna].add(spw)

    return ants_spws

def find_best_subband(ants_spws):
    """Given the dictionary with the subbands that each antenna observed
    (format from get_observed_subbands' output), it returns the subband number
    at which most of the antennas are present.

    Raises ValueError if no antenna observed any subband at all, because
    there is nothing to choose from in that case.
    """
    counting = collections.Counter()
    for observed in ants_spws.values():
        counting.update(observed)

    if not counting:
        raise ValueError("no antenna has data in any subband; cannot select a subband")

    # .most_common(1) returns a list with just the most frequent element,
    # as a two-element tuple: the element and the number of repetitions.
    return counting.most_common(1)[0][0]


def Usage(prog, full):
    """Print the usage of the program and exit.

    prog  the program name to show in the usage line
    full  if true, also print the long help text

    Never returns: exits with status 0 if the full help was printed
    and with status 1 otherwise.
    """
    print("Usage: {0} [-h] [-v] [-d] [-weight] [-scan #] <MS> <Refant> <CalSrc>".format(prog))
    if full:
        print("""
    where:
    -h      print this message and exit
    -v      print current version and exit
    -d      enable debug output in case of error
    -weight also produce weight plots (default: {1})
    -scan # select scan number # for plots, rather than
            automatically selected one(s)
    MS      the /path/to/the/measurementset.ms where to take data from
    Refant  which station to use as reference antenna (two-letter station
            code as found in the ANTENNA subtable/vexfile)
    CalSrc  which source to use as calibrator source for calibrator scan
            selection. May be comma-separated list of possible sources.
            Source names will be matched exactly, mind you.

    Produces JIVE standard plots for the MeasurementSet. These plots are:

    *  visibility amplitude/phase versus channel for up to four bands for
       all baselines to Refant for two scans on the Calibrator source;
       one at the beginning of the experiment and one at the end.
       Use '-scan #' to make these plots of a scan of own selection rather
       than automatic selection.
    *  visibility amplitude versus time for the whole experiment.
    *  weight versus time for each station, although since disk recording
       this plot has become mostly useless and it is not produced by
       default. Use '-weight' option to include it.
""".format(prog, not NoWgt))
    # decide on our own argument rather than on a global set by the caller
    sys.exit( 0 if full else 1 )


class Settings(object):
    """The per-run values shared by all plot command generators.

    Attributes, all set by the constructor:
      measurementset  the MS path as passed in
      refant          the reference antenna as passed in
      calsrc          the calibrator source(s) after transformation
                      by mk_calsrc
      myBasename      the result of mk_basenm applied to the MS path
      tempFileName    "/tmp/sptf-<pid>.ps", <pid> being our process id
    """
    def __init__(self, ms, refant, calsrc):
        self.measurementset = ms
        self.refant         = refant
        self.calsrc         = mk_calsrc(calsrc)
        self.myBasename     = mk_basenm( self.measurementset )
        self.tempFileName   = "/tmp/sptf-{0}.ps".format( os.getpid() )

    def cleanup(self):
        """Remove the temporary file; it is fine if it does not exist."""
        try:
            os.unlink( self.tempFileName )
        except FileNotFoundError:
            # nothing was ever written to it (or it's gone already):
            # there is nothing left to clean up
            pass

#####################################################################
##
##                     Inspect command line
##
#####################################################################

# Remove the options that have a value
ScanNo = get_val('-scan', tp=int)

# Now split in options and arguments
# (startswith() also copes with an empty string on the command line)
(opts, args) = partition(lambda opt: opt.startswith('-'), sys.argv[1:])

## Help requested?
longhelp = ('-h' in opts)
if longhelp:
    Usage(os.path.basename(sys.argv[0]), longhelp)

## Maybe print version?
## Must be looked at before counting the arguments, otherwise a plain
## "-v" ends up in the usage message in stead of printing the version
if '-v' in opts:
    print(Version)
    sys.exit( 0 )

## Not three arguments given?
if len(args)!=3:
    Usage(os.path.basename(sys.argv[0]), longhelp)

if '-weight' in opts:
    NoWgt = False

#####################################################################
##
##                     The plots themselves
##
#####################################################################

def open_ms(settings):
    """Yield the commands that open the MS, run indexr and set the
    output file to the temporary file of `settings`."""
    commands = [
        # open MS and run indexr
        "ms {0}".format( settings.measurementset ),
        "indexr",
        "refile {0}".format( settings.tempFileName ),
    ]
    for cmd in commands:
        yield cmd

def mk_anp_chan_cross_plot(settings, scansel, num):
    """Yield the commands for the amplitude+phase versus channel plot of
    the cross baselines to the reference antenna.

    settings  supplies refant, calsrc and myBasename
    scansel   callable; called with settings.calsrc it returns the
              condition for the "scan ... where <condition>" command
    num       sequence number, shown in the progress message and used
              in the name of the output file

    A progress message is printed before the first and after the last
    command.
    """
    print("generating cross plots [anp/channel] calibrator scan [{0}]".format( num ))
    commands = [
        # all cross baselines to refant
        "bl {0}* -auto".format( settings.refant ),
        "fq *;ch none",
        "avt vector;avc none",
        "pt anpchan",
        "y local",
        # select the scan
        "scan mid-30s to mid+30s where {0}".format( scansel(settings.calsrc) ),
        "new all false bl true sb false",
        "multi true",
        "sort bl",
        PolCMap,
        "nxy 2 4",
        "refile {0}-cross-{1}.ps/cps".format( settings.myBasename, num ),
        "pl",
    ]
    for cmd in commands:
        yield cmd
    print("done cross plots on calibrator scan")

def mk_amp_chan_auto_plot(settings, scansel, num):
    """Yield the commands for the amplitude versus channel plot of the
    auto baselines.

    settings  supplies calsrc and myBasename
    scansel   callable; called with settings.calsrc it returns the
              condition for the "scan ... where <condition>" command
    num       sequence number, shown in the progress message and used
              in the name of the output file

    A progress message is printed before the first and after the last
    command.
    """
    print("generating auto plots [amp/channel] calibrator scan [{0}]".format( num ))
    commands = [
        "bl auto",
        "fq */p;ch none",
        "avt scalar;avc none",
        "time none",
        "pt ampchan",
        "y 0 2",
        # select the scan
        "scan mid-30s to mid+30s where {0}".format( scansel(settings.calsrc) ),
        "new all false bl true sb false time true",
        "multi true",
        "sort bl",
        PolCMap,
        "nxy 2 4",
        "refile {0}-auto-{1}.ps/cps".format( settings.myBasename, num ),
        "pl",
    ]
    for cmd in commands:
        yield cmd
    print("done auto plots on calibrator scan")

def mk_amp_time_auto_plot(settings, scansel, num):
    """Yield the commands for the amplitude versus time plot of the auto
    baselines.

    settings  supplies calsrc and myBasename
    scansel   callable; called with settings.calsrc it returns the
              condition for the "scan ... where <condition>" command
    num       sequence number used in the name of the output file

    A progress message is printed before the first and after the last
    command.
    """
    print("generating auto plots [amp/time] calibrator scan")
    commands = [
        "bl auto",
        # select inner 90% of the channels
        "fq *;ch 0.1*last:0.9*last",
        "new all f bl t",
        "avt none;avc vector",
        "pt amptime",
        "y local",
        # select the scan
        "scan start-20m to end+100m where {0}".format( scansel(settings.calsrc) ),
        "time",
        "sort bl",
        "refile {0}-amptime-{1}.ps/cps".format( settings.myBasename, num ),
        "pl",
    ]
    for cmd in commands:
        yield cmd
    print("done auto plots on calibrator scan")

def mk_anp_time_cross_plot(settings, timesel, sbsel, num):
    """Yield the commands for the amplitude+phase versus time plot of the
    cross baselines to the reference antenna.

    settings  supplies refant and myBasename
    timesel   the argument for the "time" command
    sbsel     the subband to put in the "fq <subband>/p" selection
    num       sequence number used in the name of the output file

    A progress message is printed before the first and after the last
    command.
    """
    print("generating cross plots [anp/time] all scans")
    yield "bl {0}* -auto".format( settings.refant )
    # select inner 90% of the channels
    yield "fq {0}/p;ch 0.1*last:0.9*last".format(sbsel)
    yield "new all f bl t"
    yield "avt none;avc vector"
    yield "pt anptime"
    yield "y local"
    # no scan or source selection here, only the time range
    yield "src none"
    yield "time {0}".format(timesel)
    yield "sort bl"
    yield "ckey src src[none]=1"
    yield "ptsz 2"
    yield "refile {0}-ampphase-{1}.ps/cps".format( settings.myBasename, num )
    yield "pl"
    # report what was actually made: these are cross plots over all scans
    print("done cross plots [anp/time] all scans")

def mk_weight_plot(settings):
    """Yield the commands for the weight plot.

    settings supplies measurementset and myBasename. A progress message
    is printed before the first and after the last command.
    """
    # the weight plot is only auto baselines, max four subbands x 2 pols
    # for the whole time range of the experiment
    # no need to set the channel selection, it is ignored
    print("generating weight plot")
    commands = [
        "ms {0}".format( settings.measurementset ),
        "bl auto; fq */p",
        "src none",
        "time none",
        "ch mid",
        "pt wt",
        "new all f bl t",
        "y global",
        "sort bl",
        "refile {0}-weight.ps/cps".format( settings.myBasename ),
        "wt 0.1",
        "pl",
    ]
    for cmd in commands:
        yield cmd
    print("done weight plot")


#####################################################################
##
##                     And finally run them
##
#####################################################################
s = Settings(args[0], args[1], args[2])

todo = [open_ms(s)]

# Per selected scan we plot the auto and cross spectra
# and for the amp/phase vs. time we either use the user
# selected scan or the first/last calibrator scan
# Note: 'scans' is a list of functions that, when called
#       with an argument, return a proper scan selection string.
#       (if a ScanNo is given that argument is ignored :D)
#       Allows for more freedom producing the scan selection query.

subbandNo = find_best_subband(get_observed_subbands(s.measurementset))
print(f"Subband {subbandNo} selected for amp & phase VS time plot.")


# Compare against None: a plain truth test would silently drop an
# explicit "-scan 0" and fall back to the automatic selection
if ScanNo is not None:
    scans = ["scan_number={0}".format(ScanNo).format]
else:
    scans = [ "field ~ /{0}/i order by start asc limit 1".format,
              "field ~ /{0}/i order by end desc limit 1".format ]
for (i, cond) in enumerate(scans):
    todo.append( mk_anp_chan_cross_plot(s, cond, i) )
    todo.append( mk_amp_chan_auto_plot(s, cond, i) )

# And the anp versus time
for (i, cond) in enumerate(('none', '$start-5m to +55m')):
    todo.append( mk_anp_time_cross_plot(s, cond, subbandNo, i) )

if not NoWgt:
    todo.append( mk_weight_plot(s) )

# Just to do a 'r' to show the MS info at the very end
todo.append("r")

debug = ('-d' in opts)
try:
    jplotter.run_plotter(command.scripted(*todo), debug=debug)
except Exception as exc:
    # Do not swallow failures silently. Only 'Exception' is caught, so the
    # SystemExit of the terminating error function and a KeyboardInterrupt
    # pass through untouched and keep their meaning.
    if debug:
        raise
    print("{0}: {1}".format(type(exc).__name__, exc), file=sys.stderr)
    sys.exit( 1 )
finally:
    try:
        s.cleanup()
    except FileNotFoundError:
        # the temporary file was never created; nothing to remove
        pass
