#!/usr/bin/python
""" This program converts single detector xml files into hdf files 
in preparation for coincidence
"""
import numpy, argparse, h5py, os, logging
import pycbc.version

def read_files(trigger_files, columns, column_types, attribute_columns):
    """ Read sngl_inspiral trigger columns from ligolw xml files

    Parameters
    ----------
    trigger_files : iterable of str
        Files to load, in order. At least one is required.
    columns : iterable of str
        Names of the sngl_inspiral columns to read.
    column_types : iterable of numpy dtypes
        dtype to store each entry of `columns` as (same length and order).
    attribute_columns : iterable of str
        sngl_inspiral columns whose first-row value, taken from the first
        file with a non-empty table, is stored as a scalar attribute.

    Returns
    -------
    data : dict
        Column name -> numpy array holding that column from every file,
        concatenated in file order. Also holds 'template_hash', the hash of
        each row's (mass1, mass2, spin1z, spin2z) as float32 values.
    attrs : dict
        'ifo' (from the first file's search summary) plus one entry per
        attribute column.
    other : dict
        'search/start_time' and 'search/end_time' as one-element float64
        arrays, taken from the first file's search summary.

    Raises
    ------
    ValueError
        If no trigger files are given, or `columns` and `column_types`
        differ in length.
    """
    trigger_files = list(trigger_files)
    columns = list(columns)
    column_types = list(column_types)

    # Validate before the (slow, optional) glue import so bad calls fail fast
    if not trigger_files:
        raise ValueError('at least one trigger file is required')
    if len(columns) != len(column_types):
        raise ValueError('got %d columns but %d column types'
                         % (len(columns), len(column_types)))

    from glue.ligolw import ligolw, table, lsctables, utils as ligolw_utils
    # dummy class needed for loading LIGOLW files
    class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
        pass
    lsctables.use_in(LIGOLWContentHandler)

    attrs = {}
    other = {}
    # Per-file pieces are gathered here and joined once at the end; joining
    # after every file would recopy all previously read data each time.
    column_chunks = dict((column, []) for column in columns)
    hash_chunks = []

    retrieved_attrs = False
    for i, filename in enumerate(trigger_files):
        indoc = ligolw_utils.load_filename(filename, False,
                                           contenthandler=LIGOLWContentHandler)
        sngl_table = table.get_table(indoc, 'sngl_inspiral')

        for column, ctype in zip(columns, column_types):
            col_data = numpy.array(sngl_table.get_column(column), dtype=ctype)
            column_chunks[column].append(col_data)

        # The search summary of the first file stands in for the whole set
        if i == 0:
            search_sum = table.get_table(indoc,
                                         lsctables.SearchSummaryTable.tableName)
            other['search/start_time'] = numpy.array(
                [search_sum[0].out_start_time], dtype=numpy.float64)
            other['search/end_time'] = numpy.array(
                [search_sum[0].out_end_time], dtype=numpy.float64)
            attrs['ifo'] = str(search_sum[0].ifos)

        # Attributes come from the first table that actually has a row
        if len(sngl_table) > 0 and not retrieved_attrs:
            retrieved_attrs = True
            for attr in attribute_columns:
                attrs[attr] = sngl_table.get_column(attr)[0]

        # Template id hack: identify each trigger's template by hashing its
        # float32 mass and spin parameters
        m1 = numpy.array(sngl_table.get_column('mass1'), dtype=numpy.float32)
        m2 = numpy.array(sngl_table.get_column('mass2'), dtype=numpy.float32)
        s1 = numpy.array(sngl_table.get_column('spin1z'), dtype=numpy.float32)
        s2 = numpy.array(sngl_table.get_column('spin2z'), dtype=numpy.float32)
        hash_chunks.append(numpy.array([hash(v) for v in zip(m1, m2, s1, s2)],
                                       dtype=int))

    data = {}
    for column, ctype in zip(columns, column_types):
        data[column] = numpy.concatenate(column_chunks[column]).astype(ctype)
    data['template_hash'] = numpy.concatenate(hash_chunks)

    return data, attrs, other

def hash_to_row(template_hashes):
    """ Build a lookup from each template hash to its row index

    If a hash occurs more than once, the index of its last occurrence wins.
    """
    return {value: row for row, value in enumerate(template_hashes)}

def find_ids(template_hash, trig_hash):
    """ Return, for every entry of trig_hash, the row index at which that
    hash appears in template_hash, as a uint32 array

    A KeyError is raised if an entry of trig_hash is not in template_hash.
    """
    row_of = hash_to_row(template_hash)
    rows = [row_of[value] for value in trig_hash]
    return numpy.array(rows, dtype=numpy.uint32)

# Command line interface. The three file options are all needed for the
# conversion, so they are marked required to fail with a usage message
# instead of an unrelated error further down.
parser = argparse.ArgumentParser()
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='ligolw xml files holding the sngl_inspiral triggers')
parser.add_argument('--bank-file', required=True,
                    help='hdf file containing the template_hash dataset')
parser.add_argument('--output-file', required=True,
                    help='hdf file to write the triggers to')
parser.add_argument('--verbose', '-v', action='count', default=0,
                    help='-v for info messages, -vv (or more) for debug')
args = parser.parse_args()

# Map the -v count to a logging level. Counts above two stay at DEBUG rather
# than dropping back to the quiet default.
if args.verbose >= 2:
    log_level = logging.DEBUG
elif args.verbose == 1:
    log_level = logging.INFO
else:
    log_level = logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

logging.info('reading template bank')
# Only the hash column is needed, so copy it out and close the bank right away
with h5py.File(args.bank_file, "r") as bank:
    template_hashes = bank['template_hash'][:]

logging.info('reading triggers')
# Each column is listed next to the dtype it is stored as, so the two lists
# handed to read_files cannot drift out of step.
trigger_columns = [
    ('snr', numpy.float32),
    ('chisq', numpy.float32),
    ('bank_chisq', numpy.float32),
    ('cont_chisq', numpy.float32),
    ('end_time', numpy.float64),
    ('end_time_ns', numpy.float64),
    ('sigmasq', numpy.float32),
    ('coa_phase', numpy.float32),
    ('mass1', numpy.float32),
    ('mass2', numpy.float32),
    ('chisq_dof', numpy.uint32),
    ('bank_chisq_dof', numpy.uint32),
    ('cont_chisq_dof', numpy.uint32),
    ('template_duration', numpy.float32),
]
data, attrs, other = read_files(
    args.trigger_files,
    columns=[name for name, _ in trigger_columns],
    column_types=[ctype for _, ctype in trigger_columns],
    attribute_columns=[])
ifo = attrs['ifo']

logging.info('group triggers by hash')
# find_ids raises KeyError if a trigger's template hash is not in the bank.
# NOTE(review): the ids themselves are not written to the output; presumably
# they were meant to be stored as well -- confirm before relying on this.
tids = find_ids(template_hashes, data['template_hash'])

# Fold the nanosecond part into end_time; the separate column is then dropped
data['end_time'] = data['end_time'] + 1e-9 * data['end_time_ns']
data.pop('end_time_ns', None)

logging.info('saving to hdf')
# The context manager guarantees the output is flushed and closed, even if
# writing one of the datasets fails.
with h5py.File(args.output_file, "w") as f:
    # Every dataset lives in a group named after the detector
    for col in other:
        dcol = '%s/%s' % (ifo, col)
        f.create_dataset(dcol, data=other[col], compression='gzip')

    if len(data['snr']) > 0:
        for col in data:
            dcol = '%s/%s' % (ifo, col)
            f.create_dataset(dcol, data=data[col], compression='gzip')
    else:
        logging.info('There were no triggers in the sngl_inspiral table')
