#!python

import matplotlib.pyplot as plt
import utils_noroot      as utnr
import argparse
import pprint
import utils 
import numpy
import ROOT
import tqdm
import os

from log_store import log_store

# Module-level logger shared by all the functions of this script
log = log_store.add_logger('rx_scripts:compare_files')
#-------------------------------------
class data:
    '''
    Namespace holding the state of the script, filled by get_args and main
    '''
    # Path to the JSON file passed with -f/--files_list
    json_path = None
    # Pairs of ROOT file paths loaded from json_path, compared pairwise
    l_t_file  = None
#-------------------------------------
def get_args():
    '''
    Parse the command line and store the path to the JSON file with the list of file pairs
    in data.json_path
    '''
    parser = argparse.ArgumentParser(description='Used to compare contents of trees in ROOT files')
    parser.add_argument(
            '-f', '--files_list',
            type     = str,
            required = True,
            help     = 'Path to JSON file with list of files to compare.')

    data.json_path = parser.parse_args().files_list
#-------------------------------------
def compare(p_file_1, p_file_2):
    '''
    Compare all the trees stored in two ROOT files

    Both files must contain trees with the same names; for every tree, branches whose
    values differ get a plot saved in the directory returned by get_plot_dir

    Parameters
    -----------------
    p_file_1 (str): Path to first ROOT file
    p_file_2 (str): Path to second ROOT file
    '''
    log.info(f'{p_file_1:<80}{"<===>":<20}{p_file_2:<80}')
    utnr.check_file(p_file_1)
    utnr.check_file(p_file_2)

    ifile_1 = ROOT.TFile(p_file_1)
    ifile_2 = ROOT.TFile(p_file_2)

    # Close the files even if the comparison raises, e.g. when trees or columns differ
    try:
        d_tree_1 = utils.getTrees(ifile_1, rtype='dict')
        d_tree_2 = utils.getTrees(ifile_2, rtype='dict')
        l_key    = get_keys(d_tree_1, d_tree_2)

        plot_dir = get_plot_dir(p_file_1, p_file_2)
        for key in l_key:
            compare_trees(d_tree_1[key], d_tree_2[key], plot_dir)
    finally:
        ifile_1.Close()
        ifile_2.Close()
#-------------------------------------
def get_plot_dir(p_file_1, p_file_2):
    '''
    Create, if needed, and return the directory where plots for a pair of files are saved

    Parameters
    -----------------
    p_file_1 (str): Path to first ROOT file, its basename without the trailing .root names the directory
    p_file_2 (str): Path to second ROOT file, not used, kept for interface compatibility

    Returns
    -----------------
    String with path plots/<basename of p_file_1 without .root extension>
    '''
    name_1 = os.path.basename(p_file_1)
    # Strip only the trailing extension; str.replace would also remove '.root' from inside the name
    if name_1.endswith('.root'):
        name_1 = name_1[:-len('.root')]

    plot_dir = f'plots/{name_1}'
    os.makedirs(plot_dir, exist_ok=True)

    return plot_dir
#-------------------------------------
def get_keys(d_tree_1, d_tree_2):
    '''
    Return the set of tree names, which must be the same in both dictionaries

    Parameters
    -----------------
    d_tree_1 (dict): Maps tree name -> tree, for first file
    d_tree_2 (dict): Maps tree name -> tree, for second file

    Returns
    -----------------
    Set of keys shared by both dictionaries

    Raises
    -----------------
    RuntimeError: If the dictionaries do not have the same keys
    '''
    s_key_1 = set(d_tree_1.keys())
    s_key_2 = set(d_tree_2.keys())

    if s_key_1 == s_key_2:
        return s_key_1

    log.error('Trees are different:')
    log.info(s_key_1)
    log.info(s_key_2)

    # A bare 'raise' outside an except block only gives "No active exception to reraise"
    raise RuntimeError(f'Trees are different, only in file 1: {s_key_1 - s_key_2}; only in file 2: {s_key_2 - s_key_1}')
#-------------------------------------
def get_columns(rdf_1, rdf_2):
    '''
    Return the set of column names, which must be the same in both dataframes

    Parameters
    -----------------
    rdf_1 (ROOT.RDataFrame): Dataframe for first tree
    rdf_2 (ROOT.RDataFrame): Dataframe for second tree

    Returns
    -----------------
    Set of column names (str) shared by both dataframes

    Raises
    -----------------
    RuntimeError: If the dataframes do not have the same columns
    '''
    s_col_1 = { col.c_str() for col in rdf_1.GetColumnNames() }
    s_col_2 = { col.c_str() for col in rdf_2.GetColumnNames() }

    if s_col_1 == s_col_2:
        return s_col_1

    log.error('Columns are different')

    s_ocol_1 = s_col_1.difference(s_col_2)
    s_ocol_2 = s_col_2.difference(s_col_1)

    log.info('Only in file 1')
    pprint.pprint(s_ocol_1)
    log.info('Only in file 2')
    pprint.pprint(s_ocol_2)

    # A bare 'raise' outside an except block only gives "No active exception to reraise"
    raise RuntimeError(f'Columns are different: {len(s_ocol_1)} only in file 1, {len(s_ocol_2)} only in file 2')
#-------------------------------------
def equal_sizes(rdf_1, rdf_2):
    '''
    Check whether two dataframes have the same number of entries

    Each Count() runs an event loop; a warning with both sizes is logged if they differ

    Returns
    -----------------
    True if the sizes agree, False otherwise
    '''
    size_1, size_2 = (rdf.Count().GetValue() for rdf in (rdf_1, rdf_2))

    if size_1 == size_2:
        return True

    log.warning(f'Tree sizes differ: {size_1}/{size_2}')
    return False
#-------------------------------------
def define_events(arr_common):
    '''
    Fill the vector `v_evt`, declared in the ROOT interpreter on first use, with event numbers

    The vector is read by name from JIT-compiled RDataFrame filter expressions, so it must
    live at global scope in the interpreter

    Parameters
    -----------------
    arr_common: Iterable with event numbers; after the call v_evt holds exactly these values,
                in the same order
    '''
    if not hasattr(ROOT, 'v_evt'):
        # long long, not int: event numbers above 2^31 - 1 would not fit in an int
        ROOT.gInterpreter.Declare("""
        std::vector<long long> v_evt;
        """)

    # Drop values left by a previous call, otherwise they would accumulate across trees
    ROOT.v_evt.clear()
    for val in arr_common:
        ROOT.v_evt.push_back(int(val))
#-------------------------------------
def filter_dataframes(rdf_1, rdf_2):
    '''
    Restrict both dataframes to the entries whose eventNumber is present in both of them

    Parameters
    -----------------
    rdf_1 (ROOT.RDataFrame): Dataframe for first tree, needs an eventNumber column
    rdf_2 (ROOT.RDataFrame): Dataframe for second tree, needs an eventNumber column

    Returns
    -----------------
    Tuple with filtered dataframes. The filters are lazy, they are only evaluated when an
    action (e.g. Count, AsNumpy) runs the event loop, and they read the interpreter
    vector v_evt at that moment
    '''
    arr_evt_1 = rdf_1.AsNumpy(['eventNumber'])['eventNumber']
    arr_evt_2 = rdf_2.AsNumpy(['eventNumber'])['eventNumber']

    # intersect1d returns sorted unique values, which makes the binary search below valid
    arr_common= numpy.intersect1d(arr_evt_1, arr_evt_2)
    log.info(f'Common events: {arr_common.shape[0]} = {arr_evt_1.shape[0]} ^ {arr_evt_2.shape[0]}')

    # Reset the vector before filling it, define_events appends to it.
    # It must NOT be cleared after booking the filters: they run later, during the
    # event loop, and would then see an empty vector and reject every entry.
    if hasattr(ROOT, 'v_evt'):
        ROOT.v_evt.clear()

    define_events(arr_common)

    expr  = 'std::binary_search(v_evt.begin(), v_evt.end(), eventNumber)'
    rdf_1 = rdf_1.Filter(expr)
    rdf_2 = rdf_2.Filter(expr)

    return rdf_1, rdf_2
#-------------------------------------
def compare_trees(tree_1, tree_2, plot_dir):
    '''
    Compare, column by column, the contents of two trees

    If the trees have different sizes, only entries with eventNumber present in both are kept.
    Columns whose values are not numerically close get a plot in plot_dir and a warning.

    Parameters
    -----------------
    tree_1 (ROOT.TTree): First tree
    tree_2 (ROOT.TTree): Second tree
    plot_dir (str): Directory where plots of differing columns are saved

    Raises
    -----------------
    RuntimeError: If the sizes still differ after filtering, or the columns differ
    '''
    rdf_1 = ROOT.RDataFrame(tree_1)
    rdf_2 = ROOT.RDataFrame(tree_2)

    if not equal_sizes(rdf_1, rdf_2):
        rdf_1, rdf_2 = filter_dataframes(rdf_1, rdf_2)

        if not equal_sizes(rdf_1, rdf_2):
            log.error('Filtering failed')
            # A bare 'raise' outside an except block only gives "No active exception to reraise"
            raise RuntimeError('Trees have different sizes even after keeping only common events')

    # Sorted list gives a deterministic column order in the output
    l_col = sorted(get_columns(rdf_1, rdf_2))

    d_data_1 = rdf_1.AsNumpy(l_col)
    d_data_2 = rdf_2.AsNumpy(l_col)

    for col in tqdm.tqdm(l_col, ascii=' -'):
        arr_val_1, arr_val_2 = cast_arrays(d_data_1[col], d_data_2[col])
        if arr_val_1 is None or arr_val_2 is None:
            log.warning(f'Cannot cast as float: {col}')
            continue

        # equal_nan: a NaN in the same entry of both trees is not a difference
        if numpy.allclose(arr_val_1, arr_val_2, equal_nan=True):
            continue

        plot_path = f'{plot_dir}/{col}.png'
        plot(arr_val_1, arr_val_2, plot_path, col)
        log.warning(f'{"":<4}{col}')
#-------------------------------------
def cast_arrays(arr_val_1, arr_val_2):
    '''
    Convert two column arrays into float arrays that can be compared numerically

    Scalar columns are cast directly. If that fails (e.g. elements are containers), each
    array is replaced by the first element of every entry.

    Returns
    -----------------
    Tuple of two arrays; an entry is None when that array cannot be converted
    '''
    try:
        return arr_val_1.astype(float), arr_val_2.astype(float)
    except (TypeError, ValueError):
        pass

    # Use the original arrays for BOTH columns; previously the second result was
    # taken from the first array, so these columns always compared as equal
    return extract_first_element(arr_val_1), extract_first_element(arr_val_2)
#-------------------------------------
def plot(arr_val_1, arr_val_2, plot_path, col):
    '''
    Overlay histograms of one column from both trees and save the figure

    Parameters
    -----------------
    arr_val_1 (numpy.ndarray): Float values from first tree
    arr_val_2 (numpy.ndarray): Float values from second tree
    plot_path (str): Path where the figure is saved
    col (str): Column name, used as title
    '''
    arr_all    = numpy.concatenate([arr_val_1, arr_val_2])
    arr_finite = arr_all[numpy.isfinite(arr_all)]
    if arr_finite.size == 0:
        log.warning(f'No finite values to plot for: {col}')
        return

    # Range from finite values only; NaN/inf would make the histogram range invalid.
    # Non-finite entries fall outside this range and are not drawn.
    lval = arr_finite.min()
    hval = arr_finite.max()

    plt.hist(arr_val_1, histtype='step', range=[lval, hval], bins=30)
    plt.hist(arr_val_2, alpha=0.3      , range=[lval, hval], bins=30)
    plt.title(col)
    plt.savefig(plot_path)
    plt.close()
#-------------------------------------
def extract_first_element(arr_val):
    '''
    Build a float array with the first element of every entry in arr_val

    Parameters
    -----------------
    arr_val: Iterable whose entries are indexable (e.g. one vector per entry)

    Returns
    -----------------
    numpy array of floats, or None if any entry is empty, not indexable or not convertible to float
    '''
    try:
        l_val = [ float(val[0]) for val in arr_val ]
    except Exception:
        # Best effort on arbitrary entry types; unlike a bare except, this lets
        # KeyboardInterrupt and SystemExit propagate
        return None

    return numpy.array(l_val)
#-------------------------------------
def main():
    '''
    Parse the arguments, load the list of file pairs from JSON and compare every pair
    '''
    get_args()
    data.l_t_file = utnr.load_json(data.json_path)

    for t_file in data.l_t_file:
        p_file_1, p_file_2 = t_file
        compare(p_file_1, p_file_2)
#-------------------------------------
if __name__ == '__main__':
    # Only run the comparison when executed as a script, not on import
    main()

