#!/usr/bin/env python3

import os
import glob
import numpy
import mplhep
import argparse

import pandas            as pnd
import matplotlib.pyplot as plt

from logzero import logger as log

#-----------------------------------------------
class data:
    '''
    Module-wide configuration shared by all functions in this script.

    The `l_*` lists hold the allowed values of the command-line options.
    The remaining attributes start as None and are filled in by get_args().
    '''
    # Allowed values of the command-line options
    l_dset    = ['r1', 'r2p1', '2017', '2018']
    l_trig    = ['ETOS', 'GTIS']
    l_kind    = ['SS', 'OS']

    # Values selected on the command line
    dset      = None   # dataset, one of l_dset
    trig      = None   # trigger, one of l_trig
    vers      = None   # version string, used as a directory name
    kind      = None   # one of l_kind; 'SS' selects the 'cmb' sample in get_data()

    # Output directory for plots, set from get_plots_dir()
    plots_dir = None
#----------------------------
def get_data():
    '''
    Load the cached JSON files for the current configuration into one dataframe.

    The configuration is read from the `data` class (kind, vers, dset, trig).
    Files are looked up in:

        $CASDIR/cb_calculator/<vers>/<sample>_<dset>_<trig>/*.json

    where <sample> is 'cmb' when data.kind == 'SS' and 'data' otherwise.

    Returns
    -------
    pandas.DataFrame with the rows of all JSON files concatenated along axis 0.

    Raises
    ------
    KeyError          : if the CASDIR environment variable is not set
    FileNotFoundError : if no JSON file is found in the expected directory
    '''
    try:
        cas_dir = os.environ['CASDIR']
    except KeyError:
        log.error('$CASDIR variable not found')
        raise

    sample   = 'cmb' if data.kind == 'SS' else 'data'
    json_dir = f'{cas_dir}/cb_calculator/{data.vers}/{sample}_{data.dset}_{data.trig}'

    # Escape the directory so characters such as '[' in it are not taken as glob
    # wildcards. Sort because glob's order is filesystem-dependent, which would
    # make the concatenation order (and the resulting index) non-reproducible.
    l_json = sorted(glob.glob(f'{glob.escape(json_dir)}/*.json'))
    if not l_json:
        log.error(f'No JSON file found in: {json_dir}')
        # A bare `raise` here has no active exception and would only produce
        # "RuntimeError: No active exception to reraise".
        raise FileNotFoundError(f'No JSON file found in: {json_dir}')

    l_df = [ pnd.read_json(json_path) for json_path in l_json ]

    return pnd.concat(l_df, axis=0)
#----------------------------
def plot_charges():
    '''
    Plot the number of candidates for each charge combination of the two
    leptons (L1, L2) and the hadron (H).

    The data comes from get_data(). The bar chart is saved as
    <data.plots_dir>/<data.kind>.png.
    '''
    df = get_data()

    # (bar label, selection) pairs, in the order the bars are drawn.
    # The labels encode the ID convention: -11 -> e^+, +11 -> e^-, +321 -> K^+, -321 -> K^-
    l_label_query = [
            # same-sign lepton pairs
            ('$e^+e^+K^+$', 'L1_ID == -11 & L2_ID == -11 & H_ID ==+321'),
            ('$e^+e^+K^-$', 'L1_ID == -11 & L2_ID == -11 & H_ID ==-321'),
            ('$e^-e^-K^+$', 'L1_ID == +11 & L2_ID == +11 & H_ID ==+321'),
            ('$e^-e^-K^-$', 'L1_ID == +11 & L2_ID == +11 & H_ID ==-321'),
            # opposite-sign lepton pairs
            ('$e^-e^+K^+$', 'L1_ID == +11 & L2_ID == -11 & H_ID ==+321'),
            ('$e^+e^-K^+$', 'L1_ID == -11 & L2_ID == +11 & H_ID ==+321'),
            ('$e^-e^+K^-$', 'L1_ID == +11 & L2_ID == -11 & H_ID ==-321'),
            ('$e^+e^-K^-$', 'L1_ID == -11 & L2_ID == +11 & H_ID ==-321'),
            ]

    l_bin_cont = [ len(df.query(selection)) for _, selection in l_label_query ]
    l_proc     = [ label for label, _ in l_label_query ]

    ax = plt.gca()
    ax.bar(l_proc, l_bin_cont)

    plot_path = f'{data.plots_dir}/{data.kind}.png'
    log.info(f'Saving to: {plot_path}')

    plt.xticks(rotation=60)
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close('all')
#----------------------------
def get_plots_dir():
    '''
    Build the plots directory for the current configuration and create it if missing.

    The directory is $CASDIR/cb_calculator/<vers>/plots_<dset>_<trig>, using
    the values stored in the `data` class.

    Returns
    -------
    str, path to the (now existing) plots directory

    Raises
    ------
    KeyError : if the CASDIR environment variable is not set
    '''
    cas_dir = os.environ.get('CASDIR')
    if cas_dir is None:
        log.error('Caching directory, $CASDIR, not found in environment')
        # A bare `raise` outside an except block has nothing to re-raise and
        # would only surface as "RuntimeError: No active exception to reraise".
        raise KeyError('CASDIR')

    cache_dir = f'{cas_dir}/cb_calculator/{data.vers}'
    plots_dir = f'{cache_dir}/plots_{data.dset}_{data.trig}'

    os.makedirs(plots_dir, exist_ok=True)

    return plots_dir
#----------------------------
def get_args():
    '''
    Parse the command line and configure the script.

    Stores the options in the `data` class, creates the plots directory
    through get_plots_dir() and sets the LHCb2 matplotlib style from mplhep.

    -d/-t/-k are required: each run processes exactly one dataset/trigger/kind,
    and the values are interpolated into paths. The previous defaults were the
    full lists of choices. argparse does not check defaults against `choices`,
    so omitting a flag silently produced list-valued settings and broken paths
    such as "plots_['r1', 'r2p1', ...]_[...]".
    '''
    parser = argparse.ArgumentParser(description='Plots quantities from data, used for checks')
    parser.add_argument('-d', '--dset' , type=str, help='Year of data to process', choices=data.l_dset, required=True)
    parser.add_argument('-t', '--trig' , type=str, help='Trigger'                , choices=data.l_trig, required=True)
    parser.add_argument('-k', '--kind' , type=str, help='Kind of data'           , choices=data.l_kind, required=True)
    parser.add_argument('-v', '--vers' , type=str, help='Version of output'      , required=True)
    args = parser.parse_args()

    data.dset      = args.dset
    data.trig      = args.trig
    data.vers      = args.vers
    data.kind      = args.kind
    # Needs dset, trig and vers to be set first
    data.plots_dir = get_plots_dir()

    plt.style.use(mplhep.style.LHCb2)
#----------------------------
def main():
    '''
    Script entry point.

    get_args() parses the command line, fills the `data` configuration and sets
    the plot style. plot_charges() then makes and saves the charge-combination plot.
    '''
    get_args()
    plot_charges()
#----------------------------
if __name__ == '__main__':
    # Only run when executed as a script, not when imported
    main()

