#!/usr/bin/env python
from __future__ import print_function, division, absolute_import


import os
import sys
import argparse
from glob import glob
from datetime import datetime
import operator


# True when running under a Python 2 interpreter
PY2 = sys.version_info[0] == 2
if not PY2:
    # `reduce` is a builtin on Python 2 but has to be imported from
    # functools on Python 3
    from functools import reduce


def log(*msg):
    """Write a timestamped, script-prefixed message to standard error.

    All positional arguments are converted with ``str`` and joined by single
    spaces. The stream is flushed after each message.
    """
    timestamp = str(datetime.now())
    text = ' '.join(str(item) for item in msg)
    line = '[vcfnpy2hdf5] ' + timestamp + ' :: ' + text
    print(line, file=sys.stderr)
    sys.stderr.flush()


# record which interpreter is running this script
log(sys.executable)
log(sys.version_info)


# third-party imports happen after `log` is defined so that each library's
# version can be reported as soon as it has been imported
import numpy as np
log('numpy', np.__version__)
import h5py
log('h5py', h5py.__version__)
import vcfnp
log('vcfnp', vcfnp.__version__)
import vcfnp.vcflib as vcflib
# NOTE(review): `_b` is applied to the VCF sample names in load_hdf5;
# presumably it converts text to bytes - confirm against vcfnp.compat
from vcfnp.compat import b as _b


def load_compound_arrays_into_dataset(input_filenames, dataset_name,
                                      parent_group, chunk_size, chunk_width,
                                      compression, compression_opts, shuffle,
                                      fletcher32, scaleoffset):
    """Load a series of .npy files into one resizable HDF5 dataset.

    The first file is used as a template for the dtype and the trailing
    dimensions of the dataset. Every file (including the first) is then
    appended along the first axis, in the order given. Files holding an
    empty array are skipped.

    Parameters
    ----------
    input_filenames : sequence of str
        Paths of the .npy files to load. Must not be empty.
    dataset_name : str
        Name of the dataset to create in `parent_group`. An existing member
        with this name is deleted first.
    parent_group : h5py group
        Group in which the dataset is created.
    chunk_size : int
        Target size of one chunk, in bytes.
    chunk_width : int
        Maximum chunk extent along the second axis; only used when the
        arrays have more than one dimension.
    compression, compression_opts, shuffle, fletcher32, scaleoffset
        Passed unchanged to ``parent_group.create_dataset``.

    Returns
    -------
    The newly created dataset.

    Raises
    ------
    IndexError
        If `input_filenames` is empty.
    Exception
        If the template array has zero dimensions.
    """

    # fail with a clear message instead of a bare "list index out of range";
    # the exception type stays IndexError for backwards compatibility
    if len(input_filenames) == 0:
        raise IndexError('%s: no input files to load' % (dataset_name,))

    # the first file is the template for dtype and trailing dimensions
    a = np.load(input_filenames[0])

    if dataset_name in parent_group:
        log(dataset_name, 'delete existing dataset')
        del parent_group[dataset_name]

    log(dataset_name, 'setup dataset')

    # determine shape
    shape = (0,) + a.shape[1:]  # initially empty
    maxshape = (None,) + a.shape[1:]  # resizable along first dimension

    # determine chunk shape; the extent along the first axis is clamped to
    # at least 1 because a chunk_size smaller than one item (or one row)
    # would otherwise floor-divide to 0, which is not a valid chunk extent
    if a.ndim == 1:
        chunks = (max(1, chunk_size // a.dtype.itemsize),)
    elif a.ndim > 1:
        width = min(chunk_width, a.shape[1])
        # bytes taken by one step along the first axis within a chunk
        row_size = reduce(
            operator.mul,
            (a.dtype.itemsize, width) + a.shape[2:]
        )
        height = max(1, chunk_size // row_size)
        chunks = (height, width) + a.shape[2:]
    else:
        raise Exception('unexpected number of dimensions: %r' % (a,))

    log(
        dataset_name,
        'dtype', a.dtype,
        'shape', shape,
        'maxshape', maxshape,
        'chunks', chunks,
        'chunk size', a.dtype.itemsize * reduce(operator.mul, chunks),
        'compression', compression,
        'compression_opts', compression_opts,
        'shuffle', shuffle,
        'fletcher32', fletcher32,
        'scaleoffset', scaleoffset
    )
    ds = parent_group.create_dataset(dataset_name, shape=shape, dtype=a.dtype,
                                     chunks=chunks, maxshape=maxshape,
                                     compression=compression,
                                     compression_opts=compression_opts,
                                     shuffle=shuffle, fletcher32=fletcher32,
                                     scaleoffset=scaleoffset)

    log(dataset_name, 'load data')
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:
            # grow the dataset by the number of rows in this file, then
            # write the rows into the newly added region
            n = ds.shape[0]
            n_new = n + a.shape[0]
            log(dataset_name, 'loading', n, n_new)
            ds.resize(n_new, axis=0)
            ds[n:n_new] = a

    return ds


def load_compound_arrays_into_group(input_filenames, group_name, parent_group,
                                    chunk_size, chunk_width, compression,
                                    compression_opts, shuffle, fletcher32,
                                    scaleoffset):
    """Load a series of structured .npy files into a group of datasets.

    One resizable dataset is created per field of the structured dtype,
    using the first file as a template for each field's dtype and trailing
    dimensions. Every file (including the first) is then appended, field by
    field, along the first axis, in the order given. Files holding an empty
    array are skipped.

    Parameters
    ----------
    input_filenames : sequence of str
        Paths of the .npy files to load. Must not be empty.
    group_name : str
        Name of the group to use within `parent_group`; it is created if it
        does not exist yet.
    parent_group : h5py group
        Group holding the target group.
    chunk_size : int
        Target size of one chunk, in bytes.
    chunk_width : int
        Maximum chunk extent along the second axis; only used for fields
        with more than one dimension.
    compression, compression_opts, shuffle, fletcher32, scaleoffset
        Passed unchanged to ``create_dataset`` for every field.

    Returns
    -------
    The group holding the per-field datasets.

    Raises
    ------
    IndexError
        If `input_filenames` is empty.
    Exception
        If a field of the template array has zero dimensions.
    """

    # fail with a clear message instead of a bare "list index out of range";
    # the exception type stays IndexError for backwards compatibility
    if len(input_filenames) == 0:
        raise IndexError('%s: no input files to load' % (group_name,))

    # the first file is the template for field names, dtypes and shapes
    a = np.load(input_filenames[0])
    grp = parent_group.require_group(group_name)

    log(group_name, 'setup datasets')
    for f in a.dtype.names:

        if f in grp:
            log(group_name, f, 'deleting existing dataset')
            del grp[f]

        col = a[f]

        # determine shape
        shape = (0,) + col.shape[1:]  # initially empty
        maxshape = (None,) + col.shape[1:]  # resizable along first dimension

        # determine chunk shape; the extent along the first axis is clamped
        # to at least 1 because a chunk_size smaller than one item (or one
        # row) would otherwise floor-divide to 0, which is not a valid chunk
        # extent
        if col.ndim == 1:
            chunks = (max(1, chunk_size // col.dtype.itemsize),)
        elif col.ndim > 1:
            # use a per-field local here: overwriting `chunk_width` would
            # let one narrow field shrink the chunk width of every field
            # processed after it
            width = min(chunk_width, col.shape[1])
            # bytes taken by one step along the first axis within a chunk
            row_size = reduce(
                operator.mul,
                (col.dtype.itemsize, width) + col.shape[2:]
            )
            height = max(1, chunk_size // row_size)
            chunks = (height, width) + col.shape[2:]
        else:
            raise Exception('unexpected number of dimensions: %r' % (col,))

        log(
            group_name, f,
            'dtype', col.dtype,
            'shape', shape,
            'maxshape', maxshape,
            'chunks', chunks,
            'chunk size', col.dtype.itemsize * reduce(operator.mul, chunks),
            'compression', compression,
            'compression_opts', compression_opts,
            'shuffle', shuffle,
            'fletcher32', fletcher32,
            'scaleoffset', scaleoffset
        )
        grp.create_dataset(f, shape=shape, dtype=col.dtype, chunks=chunks,
                           maxshape=maxshape, compression=compression,
                           compression_opts=compression_opts, shuffle=shuffle,
                           fletcher32=fletcher32, scaleoffset=scaleoffset)

    log(group_name, 'load data')
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:
            n_add = a.shape[0]
            for f in a.dtype.names:
                # grow this field's dataset by the number of rows in the
                # file, then write the rows into the newly added region
                ds = grp[f]
                n = ds.shape[0]
                n_new = n + n_add
                log(group_name, f, 'loading', n, n_new)
                ds.resize(n_new, axis=0)
                ds[n:n_new, ...] = a[f]

    return grp


def load_hdf5(input_dir, output_filename,
              input_filename_template='{array_type}*.npy', vcf_filename=None,
              variants_only=False, tabulate_variants=False,
              parent_group_name='/', chunk_size=100000, chunk_width=10,
              compression=None, compression_opts=None, shuffle=False,
              fletcher32=False, scaleoffset=None):
    """Load variants (and optionally calldata) .npy files into an HDF5 file.

    Input files are found by substituting ``variants`` and ``calldata_2d``
    for ``{array_type}`` in `input_filename_template`, globbing inside
    `input_dir` and sorting the matches. The output file is opened in append
    mode and everything is stored under `parent_group_name`.

    Parameters
    ----------
    input_dir : str
        Existing directory containing the .npy files.
    output_filename : str
        Path of the HDF5 file to write to.
    input_filename_template : str
        Glob template containing an ``{array_type}`` placeholder.
    vcf_filename : str, optional
        If given, the sample names from this VCF are stored as a dataset
        named ``samples``, replacing any existing one.
    variants_only : bool
        If true, calldata files are not looked for.
    tabulate_variants : bool
        If true, variants are stored as a single dataset; otherwise as a
        group with one dataset per field.
    parent_group_name : str
        Name of the destination group within the HDF5 file.
    chunk_size, chunk_width, compression, compression_opts, shuffle,
    fletcher32, scaleoffset
        Storage options forwarded to the array loading functions.

    Raises
    ------
    AssertionError
        If `input_dir` or `vcf_filename` does not exist, or if the number of
        calldata files differs from the number of variants files.
    """

    # guard conditions
    assert input_dir is not None and os.path.exists(input_dir)
    if vcf_filename is not None:
        assert os.path.exists(vcf_filename)

    # glob pattern for the input arrays, still holding the {array_type}
    # placeholder
    pattern = os.path.join(input_dir, input_filename_template)

    # storage options shared by every loader call below
    storage = dict(
        chunk_size=chunk_size,
        chunk_width=chunk_width,
        compression=compression,
        compression_opts=compression_opts,
        shuffle=shuffle,
        fletcher32=fletcher32,
        scaleoffset=scaleoffset,
    )

    log('open hdf5 file', output_filename)
    with h5py.File(output_filename, 'a') as h5f:

        log('setup parent group', parent_group_name)
        root = h5f.require_group(parent_group_name)

        if vcf_filename is None:
            log('no VCF file provided, skipping metadata storage')
        else:
            log('store metadata from VCF', vcf_filename)
            vcf = vcflib.PyVariantCallFile(vcf_filename)
            if 'samples' in root:
                del root['samples']
            sample_names = np.array(_b(vcf.sample_names))
            root.create_dataset('samples', data=sample_names)

        log('process variants')
        variants_files = sorted(glob(pattern.format(array_type='variants')))
        log('variants found', len(variants_files), 'files to load')

        # pick the loader matching the requested layout
        if tabulate_variants:
            log('variants use dataset layout')
            load_variants = load_compound_arrays_into_dataset
        else:
            log('variants use group layout')
            load_variants = load_compound_arrays_into_group
        load_variants(variants_files, 'variants', root, **storage)

        if not variants_only:

            log('process calldata')
            calldata_files = sorted(
                glob(pattern.format(array_type='calldata_2d'))
            )
            log('calldata found', len(calldata_files), 'files to load')

            log('calldata check number of input files')
            n_variants_files = len(variants_files)
            n_calldata_files = len(calldata_files)
            assert n_variants_files == n_calldata_files, (
                'bad number of calldata files, expected %s, found %s'
                % (n_variants_files, n_calldata_files)
            )

            load_compound_arrays_into_group(
                calldata_files, 'calldata', root, **storage
            )

        log('all done')


def main():
    """Parse the command line and run ``load_hdf5`` with the given options.

    ``--input-dir`` and ``--output`` are mandatory: leaving either out makes
    argparse print a usage error and exit, rather than failing later on a
    ``None`` path. The compression level is only forwarded when the
    compression is ``gzip``.
    """

    # handle command line args
    pyv = '.'.join(map(str, sys.version_info[:3]))
    epilog = "Version: vcfnp %s (Python %s)" % (vcfnp.__version__, pyv)
    parser = argparse.ArgumentParser(epilog=epilog)
    parser.add_argument('--vcf',
                        dest='vcf_filename', metavar='VCF', default=None,
                        help='VCF file to extract metadata from')
    parser.add_argument('--input-dir',
                        dest='input_dir', metavar='DIR', default=None,
                        required=True,
                        help='input directory containing npy files')
    parser.add_argument('--input-filename-template',
                        dest='input_filename_template', metavar='TEMPLATE',
                        default='{array_type}*.npy',
                        help='template for input file names, defaults to '
                             '"{array_type}*.npy"')
    parser.add_argument('--output',
                        dest='output_filename', metavar='HDF5', default=None,
                        required=True,
                        help='name of output HDF5 file')
    parser.add_argument('--group',
                        metavar='GROUP', dest='group', default='/',
                        help='destination group in HDF5 file, defaults to '
                             'root group')
    parser.add_argument('--chunk-size',
                        dest='chunk_size', type=int, metavar='NBYTES',
                        default=2**17,
                        help='chunk size (defaults to 128kb)')
    parser.add_argument('--chunk-width',
                        dest='chunk_width', type=int, metavar='N',
                        default=10,
                        help='chunk width for 2 dimensional datasets ('
                             'defaults to 10)')
    parser.add_argument('--compression',
                        dest='compression', metavar='NAME', default='gzip',
                        help='compression, default is gzip')
    parser.add_argument('--compression-opts',
                        dest='compression_opts', metavar='LEVEL', type=int,
                        default=3,
                        help='compression level, applies only to gzip ('
                             'defaults to 3)')
    parser.add_argument('--shuffle',
                        dest='shuffle', action='store_true',
                        default=False,
                        help='use the shuffle filter (defaults to False)')
    parser.add_argument('--fletcher32',
                        dest='fletcher32', action='store_true',
                        default=False,
                        help='use the Fletcher32 checksum filter (defaults to '
                             'False)')
    parser.add_argument('--scaleoffset',
                        dest='scaleoffset', metavar='N', type=int,
                        default=None,
                        help='use the scale-offset filter (defaults to None)')
    parser.add_argument('--variants-only',
                        dest='variants_only', action='store_true',
                        default=False,
                        help="load variants only, don't look for calldata")
    parser.add_argument('--tabulate-variants',
                        dest='tabulate_variants', action='store_true',
                        default=False,
                        help="organise the variants as a dataset with a "
                             "compound dtype (defaults to False - organise "
                             "variants as a group, with each column as a "
                             "separate dataset)")
    args = parser.parse_args()

    # the integer level is only passed on for gzip; for any other
    # compression name no options are forwarded
    compression = args.compression
    compression_opts = args.compression_opts if compression == 'gzip' else None

    load_hdf5(input_dir=args.input_dir, output_filename=args.output_filename,
              input_filename_template=args.input_filename_template,
              vcf_filename=args.vcf_filename, parent_group_name=args.group,
              chunk_size=args.chunk_size, chunk_width=args.chunk_width,
              compression=compression, compression_opts=compression_opts,
              shuffle=args.shuffle, fletcher32=args.fletcher32,
              scaleoffset=args.scaleoffset, variants_only=args.variants_only,
              tabulate_variants=args.tabulate_variants)


# run the command-line interface only when executed as a script
if __name__ == '__main__':
    main()
