#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed under a MIT style license - see LICENSE.rst

"""Fetch BOSS data files containing the spectra of specified observations and mirror them locally.
"""

from __future__ import division,print_function

import os.path
import multiprocessing

from astropy.utils.compat import argparse

from progressbar import ProgressBar, Percentage, Bar

import astropy.table

import bossdata.path
import bossdata.remote
import bossdata.plate

def fetch(remote_paths, response_queue):
    """Download each remote path through a local mirror and report every outcome on a queue.

    Intended to be the target of a worker subprocess. Exactly one tuple is put on
    ``response_queue`` for every entry of ``remote_paths``, so the consumer can count
    responses without risk of waiting forever:

    - ``(nbytes,)`` after a successful download, where ``nbytes`` is the size of the local
      copy (this may legitimately be 0, so consumers should distinguish outcomes by tuple
      length rather than by the first element);
    - ``(0, remote_path, message)`` after a failure.

    Parameters
    ----------
    remote_paths : iterable of str
        Remote paths to download.
    response_queue : multiprocessing.Queue
        Queue that receives one result tuple per path.
    """
    try:
        mirror = bossdata.remote.Manager()
    except Exception as e:
        # Without a mirror nothing can be fetched; report every path as failed so the
        # parent process still receives one response per path.
        for remote_path in remote_paths:
            response_queue.put((0, remote_path, str(e)))
        return
    for remote_path in remote_paths:
        try:
            local_path = mirror.get(remote_path, progress_min_size=None)
            response_queue.put((os.path.getsize(local_path),))
        except Exception as e:
            # This is the worker's top-level boundary: any uncaught error would kill the
            # subprocess silently. KeyboardInterrupt is not an Exception and still propagates.
            response_queue.put((0, remote_path, str(e)))

def _get_remote_paths(table, finder, mirror, full=False, frame=False, cframe=False):
    """Return the sorted list of unique remote paths needed for the observations in table.

    Parameters
    ----------
    table : astropy.table.Table
        Table with PLATE, MJD and FIBER columns.
    finder : bossdata.path.Finder
        Used to build remote paths.
    mirror : bossdata.remote.Manager
        Used to download plan files when frame or cframe files are requested.
    full : bool
        Request full rather than lite spectrum files (ignored for frame/cframe).
    frame, cframe : bool
        Request spFrame and/or spCFrame exposure files instead of individual spectra.

    Raises
    ------
    RuntimeError
        Propagated from the finder, mirror or plan when a path cannot be determined.
    """
    remote_paths = set()
    if not (frame or cframe):
        for row in table:
            remote_paths.add(finder.get_spec_path(
                plate=row['PLATE'], mjd=row['MJD'], fiber=row['FIBER'], lite=not full))
        # Sorted so that the --save output is reproducible.
        return sorted(remote_paths)

    # Values of the calibrated flag passed to get_exposure_name:
    # False selects spFrame files, True selects spCFrame files.
    calibrated_flags = []
    if frame:
        calibrated_flags.append(False)
    if cframe:
        calibrated_flags.append(True)

    plans = {}
    for row in table:
        tag = (row['PLATE'], row['MJD'])
        plan = plans.get(tag)
        if plan is None:
            plan_remote_path = finder.get_plate_plan_path(
                plate=row['PLATE'], mjd=row['MJD'])
            # The next line downloads the small plan file, if necessary.
            plan = bossdata.plate.Plan(mirror.get(plan_remote_path))
            plans[tag] = plan
        plate_path = finder.get_plate_path(plate=row['PLATE'])
        for i in range(plan.num_science_exposures):
            for cal in calibrated_flags:
                for band in ('blue', 'red'):
                    name = plan.get_exposure_name(i, band, row['FIBER'], calibrated=cal)
                    # Only add files that are available (a false name means unavailable).
                    if name:
                        remote_paths.add(os.path.join(plate_path, name))
    return sorted(remote_paths)


def _download_all(remote_paths, nproc, verbose):
    """Download remote_paths using up to nproc subprocesses running fetch().

    Returns a tuple (num_ok, num_bytes) giving the number of files that were successfully
    fetched and their total size in bytes. Never blocks forever: if every subprocess has
    exited and no more responses arrive, monitoring stops early.
    """
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    num_files = len(remote_paths)
    nproc = min(nproc, num_files)

    progress_bar = None
    if verbose:
        print('Fetching {:d} files...'.format(num_files))
        progress_bar = ProgressBar(widgets=[Percentage(), Bar()], maxval=num_files).start()

    # Queue that subprocesses use to signal their progress.
    response_queue = multiprocessing.Queue()

    # Interleave paths across subprocesses (stride nproc) so every subprocess has work
    # whenever nproc <= num_files, unlike contiguous chunks which can leave some empty.
    processes = []
    for i in range(nproc):
        process = multiprocessing.Process(
            target=fetch, args=(remote_paths[i::nproc], response_queue))
        processes.append(process)
        process.start()

    # Monitor subprocess progress.
    num_received = 0
    num_ok = 0
    num_bytes = 0
    try:
        while num_received < num_files:
            # Sample liveness *before* waiting: a process that has already exited has
            # flushed all of its queued responses, so an empty queue afterwards means
            # nothing more will ever arrive from it.
            any_alive = any(process.is_alive() for process in processes)
            try:
                response = response_queue.get(timeout=1)
            except Empty:
                if not any_alive:
                    print('All subprocesses exited before reporting on every file.')
                    break
                continue
            num_received += 1
            if len(response) > 1:
                # Failure report: (0, remote_path, message). Checking the tuple length
                # (not the size) keeps a successfully fetched zero-byte file a success.
                print('Download error for {file}:\n{msg}'.format(
                    file=response[1], msg=response[2]))
            else:
                num_ok += 1
                num_bytes += response[0]
            if progress_bar is not None:
                progress_bar.update(num_received)
        if progress_bar is not None and num_received == num_files:
            progress_bar.finish()
        # Give subprocesses a chance to finish normally.
        for process in processes:
            process.join(timeout=1)
    except KeyboardInterrupt:
        print('Stopping after keyboard interrupt.')

    # Ensure that all subprocesses have terminated. This should never be necessary
    # after normal completion.
    for process in processes:
        if process.is_alive():
            print('Killing subprocess {}.'.format(process.name))
            process.terminate()

    return num_ok, num_bytes


def main():
    """Command-line entry point: fetch the data files for a list of observations.

    Reads PLATE, MJD, FIBER values from the input file, builds the list of remote data
    files, optionally saves that list, and downloads the files in parallel.

    Returns
    -------
    int
        0 on success, when there is nothing to fetch, or on --dry-run; -1 for invalid
        options, setup errors, or when some files could not be fetched.
    """
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action='store_true',
        help='Provide verbose output.')
    parser.add_argument('observations', type=str, default=None, metavar='FILE',
        help='File containing PLATE,MJD,FIBER columns that specify the observations to fetch.')
    parser.add_argument('--full', action='store_true',
        help='Fetch the full version of each spectrum data file.')
    parser.add_argument('--frame', action='store_true',
        help='Fetch spFrame files for each plate instead of individual spectra.')
    parser.add_argument('--cframe', action='store_true',
        help='Fetch spCFrame files for each plate instead of individual spectra.')
    parser.add_argument('--save', type=str, default=None, metavar='FILE',
        help='Filename for saving the list of data files to download.')
    parser.add_argument('--dry-run', action='store_true',
        help='Prepare the list of files to fetch but do not perform downloads.')
    parser.add_argument('--nproc', type=int, default=2,
        help='Number of subprocesses to use to parallelize downloads (1-5).')
    args = parser.parse_args()

    if args.nproc < 1 or args.nproc > 5:
        print('nproc must be 1-5.')
        return -1

    if args.full and (args.frame or args.cframe):
        print('Option --full is not compatible with --frame or --cframe.')
        return -1

    # Read the list of observations to fetch. Plain-text extensions are read as ASCII
    # tables; otherwise astropy guesses the format.
    _, ext = os.path.splitext(args.observations)
    input_format = 'ascii' if ext in ('.dat', '.txt') else None
    table = astropy.table.Table.read(args.observations, format=input_format)

    try:
        finder = bossdata.path.Finder()
        mirror = bossdata.remote.Manager()
    except ValueError as e:
        print(e)
        return -1

    # Build a list of remote paths from the input plate-mjd-fiber values.
    try:
        remote_paths = _get_remote_paths(
            table, finder, mirror, full=args.full, frame=args.frame, cframe=args.cframe)
    except RuntimeError as e:
        print('Error while preparing paths: {}'.format(str(e)))
        return -1
    num_files = len(remote_paths)
    if num_files == 0:
        print('No files to fetch.')
        return 0

    if args.save:
        with open(args.save, 'w') as f:
            f.writelines(remote_path + '\n' for remote_path in remote_paths)
        print('Saved {0} remote file names to {1}.'.format(num_files, args.save))

    if args.dry_run:
        return 0

    num_ok, num_bytes = _download_all(remote_paths, args.nproc, args.verbose)

    if args.verbose:
        print('Fetched {:.1f} Mb for {:d} files.'.format(
            num_bytes / float(1 << 20), num_ok))

    # Failed downloads are counted here, not only files skipped by an interrupt.
    if num_ok != num_files:
        print('WARNING: {:d} of {:d} files were not fetched.'.format(
            num_files - num_ok, num_files), 'Re-run the command after any problems are fixed.')
        return -1
    return 0

if __name__ == '__main__':
    import sys
    # Propagate main()'s return value (0 or -1) as the process exit status so that
    # failures are visible to shells and calling scripts.
    sys.exit(main())
