#!/usr/bin/env python
from __future__ import print_function, division, absolute_import


import os
import sys
import argparse
from datetime import datetime


def log(*msg):
    """Write a timestamped, '[vcf2npy]'-prefixed message to stderr.

    Every positional argument is converted with ``str`` and the results are
    joined with single spaces. The stream is flushed after each message.
    """
    timestamp = str(datetime.now())
    text = ' '.join(str(part) for part in msg)
    line = ''.join(['[vcf2npy] ', timestamp, ' :: ', text])
    print(line, file=sys.stderr)
    sys.stderr.flush()


# Report the interpreter path and version on stderr before any third-party
# import is attempted, so the log already identifies the Python in use if one
# of the imports below fails.
log(sys.executable)
log(sys.version_info)

# Third-party imports are deliberately placed after the log calls above, and
# each library's version is logged as soon as that library has been imported.
import numpy as np
log('numpy', np.__version__)
import pyfasta
import vcfnp
log('vcfnp', vcfnp.__version__)


def main(vcf_filename, fasta_filename, output_dir, array_type, chromosome,
         task_size, task_index, exclude_fields, ploidy, dtypes, arities,
         progress):

    assert vcf_filename is not None, 'missing VCF filename, try vcf2npy --help'
    assert os.path.exists(vcf_filename), 'VCF file not found'
    log('loading', array_type, 'from', vcf_filename)

    # determine region to extract
    if chromosome is None:
        log('extract whole genome')
        region = None
    elif task_size is None:
        log('extract whole chromosome', chromosome)
        region = chromosome
    else:
        log('extract a chromosome split', chromosome, task_size, task_index)
        # N.B., make sure regions will work as sortable file names by
        # left-padding with zeros
        assert fasta_filename is not None, \
            'missing FASTA filename, try vcf2npy --help'
        assert os.path.exists(fasta_filename), 'FASTA file not found'
        genome = pyfasta.Fasta(fasta_filename)
        chrom_size = len(genome[chromosome])
        n_fill = int(np.ceil(np.log10(chrom_size)))
        start = range(0, chrom_size, task_size)[task_index]
        startstr = str(start+1).zfill(n_fill)
        stop = start + task_size
        stopstr = str(stop).zfill(n_fill)
        region = str(chromosome) + ':' + startstr + '-' + stopstr
        log(region)

    if array_type == 'variants':
        arr = vcfnp.variants(vcf_filename, region=region, progress=progress,
                             exclude_fields=exclude_fields, dtypes=dtypes,
                             arities=arities, flatten_filter=True, cache=True,
                             cachedir=output_dir, skip_cached=True,
                             verbose=True)
    elif array_type == 'calldata':
        arr = vcfnp.calldata(vcf_filename, region=region, progress=progress,
                             ploidy=ploidy, exclude_fields=exclude_fields,
                             dtypes=dtypes, arities=arities, cache=True,
                             cachedir=output_dir, skip_cached=True,
                             verbose=True)
    elif array_type == 'calldata_2d':
        arr = vcfnp.calldata_2d(vcf_filename, region=region, progress=progress,
                                ploidy=ploidy, exclude_fields=exclude_fields,
                                dtypes=dtypes, arities=arities, cache=True,
                                cachedir=output_dir, skip_cached=True,
                                verbose=True)
    else:
        raise Exception('unexpected array type: %s' % array_type)
    if arr is not None:
        log('loaded', arr.size, 'items', arr.nbytes, 'bytes')
    log('all done')


if __name__ == '__main__':

    # handle command line options
    pyv = '.'.join(map(str, sys.version_info[:3]))
    epilog = "Version: vcfnp %s (Python %s)" % (vcfnp.__version__, pyv)
    parser = argparse.ArgumentParser(epilog=epilog)
    parser.add_argument('--vcf', dest='vcf_filename', metavar='FILE',
                        default=None, help='input VCF file')
    parser.add_argument('--fasta', dest='fasta_filename', metavar='FASTA',
                        default=None,
                        help='reference genome as FASTA file (only required '
                             'if --task-size is given)')
    parser.add_argument('--output-dir', dest='output_dir', metavar='DIR',
                        default=None,
                        help='output directory (omit to use default vcfnp '
                             'cache directory)')
    parser.add_argument('--array-type', dest='array_type', default='variants',
                        help='array type, one of ['
                             'variants|calldata|calldata_2d]')
    parser.add_argument('--chromosome', dest='chromosome', default=None,
                        help='chromosome to extract (omit to extract whole '
                             'genome)')
    parser.add_argument('--task-size', dest='task_size', metavar='BP',
                        default=None, type=int,
                        help='size (in base pairs) of region to extract (omit '
                             'to extract whole chromosome)')
    parser.add_argument('--task-index', dest='task_index', metavar='N',
                        default=None,
                        help='task index as integer or string to get task '
                             'from environment variable (e.g., "SGE_TASK_ID")')
    parser.add_argument('--exclude-field', dest='exclude_fields',
                        metavar='FIELD', default=None, action='append',
                        help='field to exclude')
    parser.add_argument('--ploidy', dest='ploidy', default=2, type=int,
                        help='sample ploidy')
    parser.add_argument('--dtype', metavar='FIELD:DTYPE', dest='dtypes',
                        action='append',
                        help='override default dtype for a given field, e.g., '
                             '"MQ:f4"')
    parser.add_argument('--arity', metavar='FIELD:ARITY', dest='arities',
                        action='append',
                        help='override default arity for a given field, e.g., '
                             '"AD:2"')
    parser.add_argument('--progress', dest='progress', metavar='N',
                        default=10000, type=int,
                        help='log progress every N rows')
    args = parser.parse_args()

    # determine task index
    task_index = None
    if args.chromosome is not None and args.task_size is not None:
        # assume we are running as part of a parallel job array
        # want to determine task index
        if args.task_index is not None:
            try:
                task_index = int(args.task_index)
            except ValueError:
                task_index = str(args.task_index).strip()
                assert task_index in os.environ, 'expected %r in environment' % args.task_index
                task_index = int(os.environ[task_index])
        # if task_index not given in options, assume SGE job array
        elif 'SGE_TASK_ID' in os.environ:
            task_index = int(os.environ['SGE_TASK_ID'])
        else:
            raise Exception('could not determine task index; %s', args)
        task_index -= 1  # use zero-based indices internally

    # determine dtype overrides
    if args.dtypes:
        dtypes = dict(s.split(':') for s in args.dtypes)
    else:
        dtypes = None

    # determine arity overrides
    if args.arities:
        arities = dict((s.split(':')[0], int(s.split(':')[1])) for s in args.arities)
    else:
        arities = None

    main(vcf_filename=args.vcf_filename, fasta_filename=args.fasta_filename,
         output_dir=args.output_dir, array_type=args.array_type,
         chromosome=args.chromosome, task_size=args.task_size,
         task_index=task_index, exclude_fields=args.exclude_fields,
         ploidy=args.ploidy, dtypes=dtypes, arities=arities,
         progress=args.progress)
