#!/usr/bin/env python
# encoding: utf-8
"""
run_genmod.py

Script for annotating genetic models in variant files.

Created by Måns Magnusson on 2014-01-21.
Copyright (c) 2013 __MyCompanyName__. All rights reserved.
"""

from __future__ import print_function
from __future__ import unicode_literals

import sys
import os
import click
import inspect

from multiprocessing import JoinableQueue, Manager, cpu_count
from codecs import open, getwriter
from datetime import datetime
from tempfile import mkdtemp, TemporaryFile
# from configobj import ConfigObj
from pprint import pprint as pp

# import vcf

import shutil
import pkg_resources
import pkgutil
import genmod

try:
    # cPickle is the faster C implementation, only present on Python 2.
    import cPickle as pickle
except ImportError:
    import pickle


from pysam import tabix_index, tabix_compress

from ped_parser import parser as ped_parser
from vcf_parser import parser as vcf_parser

from genmod import variant_consumer, variant_sorter, annotation_parser, variant_printer, variant_annotator, warning

# Python 2 only: wrap stdout in a UTF-8 writer so unicode output is encoded.
if sys.version_info[0] < 3:
    sys.stdout = getwriter('UTF-8')(sys.stdout)

# Installed genmod version; shown by --version and written to the header's
# version-tracking line.
version = pkg_resources.require("genmod")[0].version

def get_family(family_file, family_type):
    """Return a single family parsed from *family_file*.

    Args:
        family_file (str): Path to the pedigree file.
        family_type (str): Pedigree format handed to ``FamilyParser``
            (the ``--family_type`` choices of ``annotate``).
    Returns:
        The family object popped from the parser's ``families`` mapping.
    """
    families = ped_parser.FamilyParser(family_file, family_type).families
    # Only one family is analysed for now, so take whichever entry
    # popitem() hands back and discard its key.
    _, family = families.popitem()
    return family

def add_metadata(head, vep=False, cadd_file=None, cadd_1000g=None, thousand_g=None, command_line_string=''):
    """Declare the INFO fields written by this script and log the genmod run.

    Args:
        head: Header object of the VCF parser; must provide ``add_info`` and
            ``add_version_tracking``.
        vep (bool): If true, no ``Annotation`` INFO field is declared.
        cadd_file: If given (or *cadd_1000g*), a ``CADD`` field is declared.
        cadd_1000g: If given (or *cadd_file*), a ``CADD`` field is declared.
        thousand_g: If given, a ``1000GMAF`` field is declared.
        command_line_string (str): Stored with the version-tracking entry.
    Returns:
        None
    """
    # Collect (id, number, type, description) for every INFO header line.
    info_fields = []
    if not vep:
        info_fields.append(('Annotation', '.', 'String', 'Annotates what feature(s) this variant belongs to.'))
    info_fields.extend([
        ('Compounds', '.', 'String', "':'-separated list of compound pairs for this variant."),
        ('GeneticModels', '.', 'String', "':'-separated list of genetic models for this variant."),
        ('ModelScore', '1', 'Integer', "PHRED score for genotype models."),
    ])
    if cadd_file or cadd_1000g:
        info_fields.append(('CADD', '1', 'Float', "The CADD relative score for this alternative."))
    if thousand_g:
        info_fields.append(('1000GMAF', '1', 'Float', "Frequency in the 1000G database."))

    for info_field in info_fields:
        head.add_info(*info_field)

    # Record which genmod version produced the file, and how it was invoked.
    head.add_version_tracking('genmod', version, str(datetime.now()), command_line_string)

def print_headers(head, outfile, silent=False):
    """Emit the header lines produced by ``head.print_header()``.

    With *outfile* set, the file is (re)created as UTF-8 and every header line
    is written followed by a newline; *silent* is ignored in that case.
    Without *outfile*, the lines are printed to stdout unless *silent* is true.

    Returns:
        None
    """
    if not outfile:
        if silent:
            return
        for header_line in head.print_header():
            print(header_line)
        return

    with open(outfile, 'w', encoding='utf-8') as handle:
        for header_line in head.print_header():
            handle.write(header_line + '\n')

def check_tabix_index(compressed_file, file_type='cadd', verbose=False):
    """Check if a compressed file has a tabix index. If not, build one.

    Indexing is best effort: an ``IOError`` from ``tabix_index`` (presumably
    raised when the index already exists or cannot be written) does not stop
    the analysis. With *verbose* set, the reason is reported on stderr
    instead of being dropped silently.

    Args:
        compressed_file (str): Path to the bgzipped file.
        file_type (str): 'cadd' for CADD score files (sequence in column 0,
            position in column 1, '#' header lines) or 'vcf'. Any other
            value is ignored.
        verbose (bool): Report indexing problems on stderr.
    Returns:
        None
    """
    try:
        if file_type == 'cadd':
            tabix_index(compressed_file, seq_col=0, start_col=1, end_col=1, meta_char='#')
        elif file_type == 'vcf':
            tabix_index(compressed_file, preset='vcf')
    except IOError as e:
        # Keep going without a fresh index, but don't hide why.
        if verbose:
            click.echo('Tabix index not (re)built for %s: %s' % (compressed_file, e), err=True)
    return

def print_version(ctx, param, value):
    """Callback function for printing version and exiting.

    Used as the eager ``--version`` callback of ``run_genmod``.

    Args:
        ctx (object) : Current click context
        param (object) : Click parameter(s)
        value (boolean) : Click parameter was supplied or not
    Returns:
        None: only returns when the flag was not given or the context is
        parsing resiliently; otherwise the program exits via ``ctx.exit()``.
    """
    if not value or ctx.resilient_parsing:
        return
    click.echo('genmod version: ' + version)
    ctx.exit()


class Config(object):
    """Settings shared between the genmod command group and its subcommands."""

    def __init__(self):
        super(Config, self).__init__()
        # Quiet by default; run_genmod overwrites it from the --verbose flag.
        self.verbose = False
    
# Decorator that hands the shared Config object to run_genmod and to the
# subcommands below as their first argument (with ensure=True, click is
# expected to create the Config when the context has none yet — see click docs).
pass_config = click.make_pass_decorator(Config, ensure=True)


###         This is the main script         ###
@click.group()
@click.option('-v', '--verbose',
                is_flag=True,
                help='Increase output verbosity.'
)
@click.option('--version',
                is_flag=True,
                callback=print_version,
                expose_value=False,
                is_eager=True,
                help='Show the genmod version and exit.'
)
@pass_config
def run_genmod(config, verbose):
    """Annotate genetic models in variant files."""
    # The docstring above is the group's --help text. Store the global
    # --verbose flag on the shared Config so the subcommands can read it.
    config.verbose = verbose

###        This is for building new annotations     ###

@click.command()
@click.argument('annotation_file',
                nargs=1,
                type=click.Path(exists=True),
)
@click.option('-t' ,'--type',
                type=click.Choice(['bed', 'ccds', 'gtf', 'gene_pred']),
                default='gene_pred',
                help='Specify the format of the annotation file.'
)
@click.option('-o', '--outdir',
                    type=click.Path(exists=True, file_okay=False),
                    help=("""Specify the path to a folder where the annotation files should be stored.
                            Default is the annotations dir of the distribution.""")
)
@click.option('--splice_padding',
                    type=int, nargs=1, default=2,
                    help='Specify the number of bases that the exons should be padded with. Default is 2 bases.'
)
@pass_config
def build_annotation(config, annotation_file, type, outdir, splice_padding):
    """Build a new annotation database."""
    # The docstring above is the command's --help text. ``type`` shadows the
    # builtin, but it is the name click derives from the --type option.
    if outdir:
        gene_db = os.path.join(outdir, 'genes.db')
        exon_db = os.path.join(outdir, 'exons.db')
    else:
        # Default: overwrite the databases shipped inside the genmod package.
        gene_db = pkg_resources.resource_filename('genmod', 'annotations/genes.db')
        exon_db = pkg_resources.resource_filename('genmod', 'annotations/exons.db')

    if config.verbose:
        # Report the real destination; outdir is None when -o is omitted.
        click.echo('Building new annotation databases from %s into %s.' % (annotation_file, os.path.dirname(gene_db)))

    anno_parser = annotation_parser.AnnotationParser(annotation_file, type,
                            splice_padding = splice_padding, verbosity=config.verbose)

    with open(gene_db, 'wb') as f:
        pickle.dump(anno_parser.gene_trees, f)

    with open(exon_db, 'wb') as g:
        pickle.dump(anno_parser.exon_trees, g)
    


###           This is for annotating the variants       ###


@click.command()
@click.argument('variant_file',
                    nargs=1,
                    type=click.Path(),
                    metavar='<vcf_file> or -'
)
@click.option('--family_file', '-fam',
                    nargs=1,
                    type=click.Path(exists=True),
                    metavar='<ped_file>'
)
@click.option('-f' ,'--family_type',
                type=click.Choice(['ped', 'alt', 'cmms', 'mip']),
                default='ped',
                help='If the analysis use one of the known setups, please specify which one.'
)
@click.option('--vep',
                    is_flag=True,
                    help='If variants are annotated with the Variant Effect Predictor.'
)
@click.option('--chr_prefix',
                    is_flag=True,
                    help='If chr prefix is used in vcf file.'
)
@click.option('-p' ,'--phased',
                    is_flag=True,
                    help='If data is phased use this flag.'
)
@click.option('-strict' ,'--strict',
                    is_flag=True,
                    help='If strict model annotations should be used (see documentation).'
)
@click.option('-s' ,'--silent',
                    is_flag=True,
                    help='Do not print the variants.'
)
@click.option('-g' ,'--whole_gene',
                    is_flag=True,
                    help="""If compounds should be checked in the whole gene regions.
                    Not only exonic/splice sites."""
)
@click.option('-a' ,'--annotation_dir',
                    type=click.Path(exists=True),
                    help="""Specify the path to the directory where the annotation
                    databases are.
                    Default is the gene pred files that comes with the distribution."""
)
@click.option('-o', '--outfile',
                    type=click.Path(exists=False),
                    help='Specify the path to a file where results should be stored.'
)
@click.option('--cadd_file',
                    type=click.Path(exists=True),
                    help="""Specify the path to a bgzipped cadd file with variant scores.
                            If no index is present it will be created."""
)
@click.option('--cadd_1000g',
                    type=click.Path(exists=True),
                    help="""Specify the path to a bgzipped cadd file with variant scores
                            for all 1000g variants. If no index is present a new index
                            will be created."""
)
@click.option('--thousand_g',
                    type=click.Path(exists=True),
                    help="""Specify the path to a bgzipped vcf file frequency info of all
                            1000g variants. If no index is present a new index
                            will be created."""
)
@pass_config
def annotate(config, family_file, variant_file, family_type, vep, silent, phased, strict,
             whole_gene, annotation_dir, cadd_file, cadd_1000g, thousand_g, outfile, chr_prefix):
    """Annotate variants in a VCF file.
        Annotations can only be added from the sources that are given as options.
        If a ped file is provided then the genetic inheritance patterns for all individuals are followed.
        The ped file must contain exactly the individuals of the VCF file, otherwise the analysis is aborted.
    """
    # NOTE: the docstring above is the command's --help text.
    verbosity = config.verbose

    # Log the non-empty command line arguments in the VCF header. Only the
    # declared parameters are collected, so locals such as ``verbosity`` never
    # end up in the header.
    frame = inspect.currentframe()
    arg_names, _, _, values = inspect.getargvalues(frame)
    argument_list = [name + '=' + str(values[name]) for name in arg_names
                     if values[name] and name != 'config']
    # Drop the frame references right away to avoid a reference cycle.
    del frame, values

    if annotation_dir:
        gene_db = os.path.join(annotation_dir, 'genes.db')
        exon_db = os.path.join(annotation_dir, 'exons.db')
    else:
        gene_db = pkg_resources.resource_filename('genmod', 'annotations/genes.db')
        exon_db = pkg_resources.resource_filename('genmod', 'annotations/exons.db')

    if variant_file == '-':
        variant_parser = vcf_parser.VCFParser(fsock = sys.stdin)
    else:
        variant_parser = vcf_parser.VCFParser(infile = variant_file)

    head = variant_parser.metadata

    if verbosity:
        start_time_analysis = datetime.now()

    # The analysis may continue without annotation databases, so start from
    # empty trees instead of leaving the names unbound (which raised a
    # NameError when the VariantAnnotator below was created).
    # NOTE(review): assumes VariantAnnotator accepts empty dicts in place of
    # the pickled trees written by build_annotation — confirm.
    gene_trees = {}
    exon_trees = {}
    try:
        # pickle.load can execute arbitrary code: only use trusted annotation dirs.
        with open(gene_db, 'rb') as f:
            gene_trees = pickle.load(f)
        with open(exon_db, 'rb') as g:
            exon_trees = pickle.load(g)
    except IOError:
        # Always warn: the output will silently lack feature annotations otherwise.
        warning('You need to build annotations! See documentation.')

    if family_file:
        family = get_family(family_file, family_type)
        if set(family.individuals.keys()) != set(variant_parser.individuals):
            warning('There must be same individuals in ped file and vcf file! Aborting...')
            warning('Individuals in PED file: %s' % '\t'.join(list(family.individuals.keys())))
            warning('Individuals in VCF file: %s' % '\t'.join(list(variant_parser.individuals)))
            # Non-zero status so calling pipelines notice the abort.
            sys.exit(1)
    else:
        family = None

    if cadd_file:
        if verbosity:
            click.echo('Cadd file! %s' % cadd_file)
        check_tabix_index(cadd_file, 'cadd', verbosity)
    if cadd_1000g:
        if verbosity:
            click.echo('Cadd 1000G file! %s' % cadd_1000g)
        check_tabix_index(cadd_1000g, 'cadd', verbosity)
    if thousand_g:
        if verbosity:
            click.echo('1000G frequency file! %s' % thousand_g)
        check_tabix_index(thousand_g, 'vcf', verbosity)

    ###################################################################
    ### The task queue is where all jobs(in this case batches that  ###
    ### represents variants in a region) is put. The consumers will ###
    ### then pick their jobs from this queue.                       ###
    ###################################################################

    variant_queue = JoinableQueue(maxsize=1000)
    # The consumers will put their results in the results queue
    results = Manager().Queue()

    # Directory for the temporary result files; removed in the finally below.
    temp_dir = mkdtemp()
    try:
        # Adapt the number of processes to the machine that runs the analysis
        num_model_checkers = (cpu_count()*2-1)

        if verbosity:
            print('Number of CPU:s %s' % cpu_count())

        # These are the workers that do the analysis
        model_checkers = [variant_consumer.VariantConsumer(variant_queue, results, family,
                            phased, vep, cadd_file, cadd_1000g, thousand_g, chr_prefix, strict, verbosity)
                                for _ in range(num_model_checkers)]

        for w in model_checkers:
            w.start()

        # This process prints the variants to temporary files
        var_printer = variant_printer.VariantPrinter(results, temp_dir, head, verbosity)
        var_printer.start()

        if verbosity:
            print('Start parsing the variants ...')
            print('')

        # Parse the vcf and feed the variant batches to the workers.
        var_annotator = variant_annotator.VariantAnnotator(variant_parser, variant_queue,
                            gene_trees, exon_trees, phased, vep, whole_gene, verbosity)
        var_annotator.annotate()

        # One None per worker; presumably each consumer's stop signal.
        for _ in range(num_model_checkers):
            variant_queue.put(None)

        variant_queue.join()
        results.put(None)
        var_printer.join()

        chromosome_list = var_annotator.chromosomes

        if verbosity:
            print('Chromosomes found in variant file: %s' % ','.join(chromosome_list))
            print('Models checked!')
            print('Start sorting the variants:')
            print('')
            start_time_variant_sorting = datetime.now()

        # Add the new metadata to the headers:
        add_metadata(head, vep, cadd_file, cadd_1000g, thousand_g, ' '.join(argument_list))
        print_headers(head, outfile, silent)

        # Output chromosome by chromosome, in the order of chromosome_list; a
        # temp file belongs to a chromosome when its first '_'-field matches.
        for chromosome in chromosome_list:
            for temp_file in os.listdir(temp_dir):
                if temp_file.split('_')[0] == chromosome:
                    var_sorter = variant_sorter.FileSort(os.path.join(temp_dir, temp_file), outfile, silent=silent)
                    var_sorter.sort()

        if verbosity:
            print('Sorting done!')
            print('Time for sorting: %s' % str(datetime.now()-start_time_variant_sorting))
            print('')
            print('Time for whole analysis: %s' % str(datetime.now() - start_time_analysis))
    finally:
        # Remove all temp files, also when the analysis fails part way.
        shutil.rmtree(temp_dir)

###           This is for analyzing the variants       ###


def make_models(list_of_models):
    """Make a dictionary of the preferred inheritance models.

    Every entry of *list_of_models* that contains 'AR', 'AD' or 'X' as a
    substring selects all patterns of that group (so 'AR_comp' selects every
    AR pattern). An empty list, or None, selects all groups.

    Args:
        list_of_models (iterable of str or None): Requested models/groups.
    Returns:
        dict: Pattern name -> '' for every selected pattern; only the keys
        carry information.
    """
    # Each model group expands to all of its concrete inheritance patterns.
    model_groups = (
        ('AR', ('AR_hom', 'AR_hom_dn', 'AR_comp', 'AR_comp_dn')),
        ('AD', ('AD', 'AD_dn')),
        ('X', ('XR', 'XR_dn', 'XD', 'XD_dn')),
    )
    # If no models are specified (empty or None) we allow all models
    if not list_of_models:
        list_of_models = ['AR', 'AD', 'X']

    model_dict = {}
    for model in list_of_models:
        for group, patterns in model_groups:
            if group in model:
                model_dict.update(dict.fromkeys(patterns, ''))
    return model_dict

# @click.command()
# @click.argument('variant_file',
#                     nargs=1,
#                     type=click.Path(exists=True),
#                     metavar='<vcf_file> or "-"'
# )
# @click.option('-c', '--config_file',
#                     type=click.Path(exists=True),
#                     help="""Specify the path to a config file."""
# )
# @click.option('--freq',
#                     type=float, nargs=1,
#                     help='Specify the treshold for variants to be considered.'
# )
# @click.option('-p', '--patterns',
#                     type=click.Choice(['AR_comp', 'AR_comp_dn', 'AR_hom', 'AR_hom_dn',
#                                         'AD', 'AD_dn', 'XD', 'XD_dn', 'XR', 'XR_dn']),
#                     nargs=1, multiple=True,
#                     help='Specify the inheritance patterns.'
# )
#
# def analyze(variant_file, freq, patterns, config_file):
#     """Analyze the annotated variants in a VCF file."""
#     configs = ConfigObj(config_file)
#     pp(configs)
#
#     freq_treshold = float(configs.get('frequency', {}).get('uncommon', 0.1))
#     freq_keyword = configs.get('frequency', {}).get('keyword', '1000G_freq')
#
#     inheritance_patterns = [pattern for pattern in configs.get('inheritance', {}).get('patterns',[])]
#     inheritance_keyword = configs.get('inheritance', {}).get('keyword','GM')
#     prefered_models = make_models(inheritance_patterns)
#
#     prediction_keyword = configs.get('prediction', {}).get('keyword','CADD')
#
#     variant_parser = vcf_parser.VCFParser(variant_file)
#     print(freq_keyword, freq_treshold)
#     print(inheritance_keyword, inheritance_patterns)
#     print(prediction_keyword)
#
#     interesting_variants = TemporaryFile(mode='w+t')
#
#     for variant in variant_parser:
#         if float(variant['info_dict'].get(freq_keyword, '0')) <= freq_treshold:
#             for model in variant['info_dict'].get(inheritance_keyword, '').split(';'):
#                 if model in prefered_models:
#                     print_line = [variant.get(entry, '-') for entry in variant_parser.header]
#                     interesting_variants.write('\t'.join(print_line)+'\n')
#     interesting_variants.seek(0)
#     for line in interesting_variants:
#         print(line.rstrip())
#     interesting_variants.close()
#     # if freq:
#     #     freq_treshold = freq
#     # if patterns:
#     #     inheritance_patterns = {pattern:0 for pattern in patterns}
#     # print('Frequency tres: %s' % freq_treshold)
#     # print('Inheritance models: %s' % str(inheritance_patterns))
#     # vcf_reader = vcf.Reader(fsock=variant_file)
#     # for record in vcf_reader:
#     #     print(record.INFO.get('1000GMAF', 0))


# Attach the subcommands to the top-level group. The ``analyze`` command is
# currently disabled (its definition above is commented out).
run_genmod.add_command(build_annotation)
run_genmod.add_command(annotate)


if __name__ == '__main__':
    # Hand control to click, which parses sys.argv and dispatches.
    run_genmod()
