#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

""" Plot a single BOSS spectrum.
"""

from __future__ import division,print_function

from astropy.utils.compat import argparse

import os.path

import numpy as np
import numpy.ma
import matplotlib.pyplot as plt

import bossdata.path
import bossdata.remote
import bossdata.spec
import bossdata.bits
import bossdata.plate

def print_mask_summary(label, mask_values):
    """Print a per-bit summary of SPPIXMASK pixel mask values.

    Args:
        label(str): Prefix used in the summary heading, e.g. 'Blue' or
            'Coadd (AND)'.
        mask_values: Pixel mask values to summarize. When no value is
            non-zero, a single 'No pixels masked.' line is printed instead.
    """
    if np.any(mask_values):
        print('{0} pixel mask summary:'.format(label))
        bit_summary = bossdata.bits.summarize_bitmask_values(
            bossdata.bits.SPPIXMASK,mask_values)
        # items() works on both Python 2 and 3; iteritems() is Python 2 only.
        for bit_name,bit_count in bit_summary.items():
            print('{0:5d} {1}'.format(bit_count,bit_name))
    else:
        print('No pixels masked.')

def main():
    """Parse command-line options, load one BOSS spectrum and plot it.

    The spectrum comes from a spec/spec-lite file (coadd or a single
    exposure) or, with --frame/--cframe, from the spFrame/spCFrame files
    named in the plate plan. The figure is optionally saved (--save-plot)
    and/or displayed on screen.

    Returns:
        -1 when the request cannot be satisfied (coadd requested from frame
        files, Finder/Manager construction raised ValueError, or the requested
        camera's frame is not listed in the plan); otherwise None.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter,
        description = 'Plot a single BOSS spectrum.')
    parser.add_argument('--verbose', action = 'store_true',
        help = 'Provide verbose output.')
    parser.add_argument('--plate',type = int, default = 6641, metavar = 'PLATE',
        help = 'Plate number of spectrum to plot.')
    parser.add_argument('--mjd',type = int, default = 56383, metavar = 'MJD',
        help = 'Modified Julian date of plate observation to use.')
    parser.add_argument('--fiber',type = int,default = 30, metavar = 'FIBER',
        help = 'Fiber number identifying the spectrum of the requested PLATE-MJD to plot.')
    parser.add_argument('--exposure',type = int,default = None, metavar = 'EXP',
        help = 'Exposure sequence number starting from 0, or plot the coadd if not set.')
    parser.add_argument('--camera',type = str, choices = ['blue','red','both'], default = 'both',
        help = 'Camera to use when plotting a single exposure.')
    parser.add_argument('--allow-mask', type = str, default = None,
        help = 'SPPIXMASK bit names to allow in valid data. Separate multiple names with |.')
    parser.add_argument('--frame', action='store_true',
        help = 'Plot the spectrum from an uncalibrated spFrame file.')
    parser.add_argument('--cframe', action='store_true',
        help = 'Plot the spectrum from a calibrated spCFrame file.')
    parser.add_argument('--save-plot', type = str, default = None, metavar = 'FILE',
        help = 'File name to save the generated plot to.')
    parser.add_argument('--no-display', action = 'store_true',
        help = 'Do not display the image on screen (useful for batch processing).')
    parser.add_argument('--scatter', action = 'store_true',
        help = 'Show scatter of flux instead of a flux error band.')
    parser.add_argument('--show-mask', action = 'store_true',
        help = 'Indicate pixels with invalid data using vertical lines.')
    parser.add_argument('--show-dispersion', action = 'store_true',
        help = 'Show the wavelength dispersion using the right-hand axis.')
    parser.add_argument('--show-sky', action = 'store_true',
        help = 'Show the subtracted sky flux instead of the object flux.')
    parser.add_argument('--add-sky', action = 'store_true',
        help = 'Add the subtracted sky to the object flux (overrides show-sky).')
    args = parser.parse_args()

    if args.exposure is None:
        if args.frame or args.cframe:
            print('Coadds not available from frame and cframe files.')
            return -1
        # Compare by value: "is not" tests object identity, so a 'both' string
        # parsed from the command line could wrongly trigger this warning.
        if args.camera != 'both':
            print('Ignoring camera = "{0}" for coadded spectrum.'.format(args.camera))
            args.camera = 'both'

    # None means no SPPIXMASK bits are explicitly allowed in valid data.
    if args.allow_mask is None:
        pixel_quality_mask = None
    else:
        pixel_quality_mask = bossdata.bits.bitmask_from_text(
            bossdata.bits.SPPIXMASK,args.allow_mask)

    try:
        finder = bossdata.path.Finder()
        mirror = bossdata.remote.Manager()
    except ValueError as e:
        print(e)
        return -1

    if args.frame or args.cframe:
        frames = {}
        frame_path = finder.get_plate_path(plate=args.plate)
        plan_path = finder.get_plate_plan_path(plate=args.plate, mjd=args.mjd)
        plan = bossdata.plate.Plan(mirror.get(plan_path))
        # Fibers 1-500 map to index 1 and 501-1000 to index 2
        # (presumably the spectrograph number).
        frame_index = 1 + (args.fiber-1)//500
        if args.camera in ('red','both'):
            red_name = plan.get_exposure_name(
                args.exposure, 'red', args.fiber, calibrated=args.cframe)
            if red_name is None:
                print('Red camera data not available.')
                return -1
            frames['red'] = bossdata.plate.FrameFile(
                mirror.get(os.path.join(frame_path, red_name)),
                index=frame_index, calibrated=args.cframe)
        if args.camera in ('blue','both'):
            blue_name = plan.get_exposure_name(
                args.exposure, 'blue', args.fiber, calibrated=args.cframe)
            if blue_name is None:
                print('Blue camera data not available.')
                return -1
            frames['blue'] = bossdata.plate.FrameFile(
                mirror.get(os.path.join(frame_path, blue_name)),
                index=frame_index, calibrated=args.cframe)
    else:
        # The lite spec file is sufficient for a coadd; individual exposures
        # need the full spec file.
        remote_path = finder.get_spec_path(plate=args.plate, mjd=args.mjd, fiber=args.fiber,
            lite=(args.exposure is None))
        local_path = mirror.get(remote_path)
        specfile = bossdata.spec.SpecFile(local_path)
        if args.verbose:
            print('Exposure summary:')
            print(specfile.exposure_table)

    # Initialize the plot.
    figure = plt.figure(figsize=(12,8))
    left_axis = plt.gca()
    figure.set_facecolor('white')
    if args.frame:
        left_axis.set_xlabel('Wavelength index')
        left_axis.set_ylabel('Flux (electrons)')
    else:
        left_axis.set_xlabel('Wavelength (Angstrom)')
        left_axis.set_ylabel('Flux (1e-17 erg/s/cm**2)')
    if args.show_dispersion:
        right_axis = left_axis.twinx()
        right_axis.set_ylabel('Dispersion (Angstrom)')

    # We will potentially plot two spectra.
    spectra = [ ]
    plot_colors = [ ]
    data_args = dict(include_wdisp=args.show_dispersion, include_sky=args.show_sky or args.add_sky)
    if args.exposure is None:
        spectra.append(specfile.get_valid_data(pixel_quality_mask=pixel_quality_mask, **data_args))
        plot_colors.append('black')
        if args.verbose:
            print('Showing coadd of {0:d} exposures:'.format(specfile.num_exposures))
            print_mask_summary('Coadd (AND)',specfile.get_pixel_mask())
    elif args.frame or args.cframe:
        fibers = np.array([args.fiber],dtype=int)
        if args.verbose:
            print('Showing exposure {0}.'.format(plan.exposures['science'][args.exposure]['EXPID']))
        if args.camera in ('blue','both'):
            spectra.append(frames['blue'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',frames['blue'].get_pixel_masks(fibers)[0])
        if args.camera in ('red','both'):
            spectra.append(frames['red'].get_valid_data(fibers,
                pixel_quality_mask=pixel_quality_mask, **data_args)[0])
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',frames['red'].get_pixel_masks(fibers)[0])
    else:
        if args.verbose:
            print('Showing exposure {0}.'.format(specfile.exposure_table[args.exposure]['exp']))
        if args.camera in ('blue','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'blue',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('blue')
            if args.verbose:
                print_mask_summary('Blue',specfile.get_pixel_mask(args.exposure,'blue'))
        if args.camera in ('red','both'):
            spectra.append(specfile.get_valid_data(args.exposure,'red',
                pixel_quality_mask=pixel_quality_mask, **data_args))
            plot_colors.append('red')
            if args.verbose:
                print_mask_summary('Red',specfile.get_pixel_mask(args.exposure,'red'))

    wlen_min,wlen_max = +1e6,-1e6
    for data,plot_color in zip(spectra,plot_colors):

        wlen,dflux = data['wavelength'][:],data['dflux'][:]
        if args.add_sky:
            flux = data['sky'][:] + data['flux'][:]
        elif args.show_sky:
            flux = data['sky'][:]
        else:
            flux = data['flux'][:]

        if args.scatter:
            left_axis.scatter(wlen,flux,color=plot_color,marker='.',s=0.1)
        else:
            left_axis.fill_between(wlen,flux-dflux,flux+dflux,color=plot_color,alpha=0.5)

        if args.show_mask:
            # getmaskarray() always returns a full-size mask (even when the
            # array has no mask), so np.where() finds exactly the masked pixels.
            bad_pixels = np.where(np.ma.getmaskarray(data))
            bad_wlen = data.data['wavelength'][bad_pixels]
            if len(bad_wlen) > 0:
                ymin,ymax = left_axis.get_ylim()
                x_mask = [ ]
                y_mask = [ ]
                # One vertical segment per bad pixel; None breaks the line.
                for x in bad_wlen:
                    x_mask.extend([x,x,None])
                    y_mask.extend([ymin,ymax,None])
                # Draw on the flux axis explicitly: after twinx() the pyplot
                # "current" axes is the dispersion axis.
                left_axis.plot(x_mask,y_mask,'-',color=plot_color,alpha=0.2)

        if args.show_dispersion:
            right_axis.plot(wlen,data['wdisp'][:],ls='-',color=plot_color)

        # Update the plot wavelength limits to include this data.
        wlen_min = min(wlen_min,np.ma.min(wlen))
        wlen_max = max(wlen_max,np.ma.max(wlen))

    # The x-axis limits are reset by the twinx() function so we set them here.
    left_axis.set_xlim(wlen_min,wlen_max)

    if args.save_plot:
        figure.savefig(args.save_plot)
    if not args.no_display:
        plt.show()
    plt.close()

if __name__ == '__main__':
    import sys
    # Propagate main()'s return value (-1 on failure, None on success) as the
    # process exit status so shell scripts can detect failures.
    sys.exit(main())
