#!/usr/bin/env python3

from builder import builder 

import os
import zfit
import numpy
import mplhep
import argparse
import matplotlib.pyplot as plt

from logzero import logger as log

#-----------------------------------------------
class data:
    """Namespace shared by the functions of this script.

    It is never instantiated: the class attributes hold the values accepted
    on the command line and the configuration of the current run.
    """
    # Values accepted by the --dset, --trig and --q2bin options
    l_dset  = ['r1', 'r2p1', '2017', '2018']
    l_trig  = ['ETOS', 'GTIS', 'MTOS']
    l_q2bin = ['jpsi', 'psi2', 'high']

    # Copied from the parsed command line in get_args()
    dset = trig = q2bin = vers = None

    # Mass-constraint flag and the tag ('yconst'/'nconst') used in file names
    mass_const = const_pref = None

    # Objects derived in get_args(): fit observable and output directory
    obs = plots_dir = None
#----------------------------
def plot_pdf(pdf):
    """Draw ``pdf`` over the range of ``data.obs`` with ``plt.plot``.

    Parameters
    ----------
    pdf : object exposing a ``pdf(x)`` method, evaluated on 2000 evenly
          spaced points between the lower and upper limit of ``data.obs``.
    """
    # limits is a (lower, upper) pair, each wrapping exactly one entry
    lower, upper = data.obs.limits
    [minx] = lower
    [maxx] = upper

    arr_mass = numpy.linspace(minx, maxx, 2000)
    arr_dens = pdf.pdf(arr_mass)

    plt.plot(arr_mass, arr_dens)
#----------------------------
def plot_comp():
    """Overlay the PDFs obtained for each uncertainty setting and save the figure.

    A single ``builder`` is configured from the ``data`` namespace and asked
    for six PDFs: ``unc=0``, ``unc=None`` and one per unit shift (+1/-1) of
    the 'lm' and 'mu' entries of the ``unc`` dictionary. Every PDF is
    requested before any of them is drawn. The figure is saved as
    ``<plots_dir>/<q2bin>_<const_pref>_comp.png`` and all figures are closed.
    """
    # One row per curve: (unc argument, name prefix, legend entry).
    # NOTE(review): the legend suggests unc=0 is the SS shape and unc=None the
    # nominal OS one -- confirm against builder.get_pdf.
    l_curve = [
        (0                    , 'comp_z', 'SS'),
        (None                 , 'comp_o', r'$\bar{OS}$'),
        ({'mu' :  0, 'lm' : +1}, 'comp_1', r'$\bar{OS} + \sigma_{\lambda}$'),
        ({'mu' :  0, 'lm' : -1}, 'comp_2', r'$\bar{OS} - \sigma_{\lambda}$'),
        ({'mu' : +1, 'lm' :  0}, 'comp_3', r'$\bar{OS} + \sigma_{\mu}$'),
        ({'mu' : -1, 'lm' :  0}, 'comp_4', r'$\bar{OS} - \sigma_{\mu}$'),
    ]

    obj = builder(dset=data.dset, trigger=data.trig, vers=data.vers, q2bin=data.q2bin, const=data.mass_const)

    # Build everything first, draw afterwards
    l_pdf = []
    for unc, name, _label in l_curve:
        pdf, _ = obj.get_pdf(obs=data.obs, unc=unc, preffix=name)
        l_pdf.append(pdf)

    for pdf in l_pdf:
        plot_pdf(pdf)

    # Legend entries follow the drawing order
    plt.legend([label for _unc, _name, label in l_curve])

    plt.xlabel(r'$B\to K \ell^{\pm}\ell^{\pm}$ Mass[MeV]')
    plt.ylabel('Normalized')
    plt.grid()

    plot_path = f'{data.plots_dir}/{data.q2bin}_{data.const_pref}_comp.png'
    log.info(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    plt.close('all')
#----------------------------
def get_plots_dir():
    """Return the directory where the plots are saved, creating it if needed.

    The path is ``$CASDIR/cb_calculator/<vers>/uncertainty_<dset>_<trig>``,
    with ``vers``, ``dset`` and ``trig`` read from the ``data`` namespace.

    Returns
    -------
    str : path of the (now existing) plots directory.

    Raises
    ------
    RuntimeError : if the CASDIR environment variable is not set.
    """
    if 'CASDIR' not in os.environ:
        msg = 'Caching directory, $CASDIR, not found in environment'
        log.error(msg)
        # A bare `raise` with no active exception only fails by accident
        # (RuntimeError: No active exception to reraise); raise the same
        # exception type explicitly, with a meaningful message.
        raise RuntimeError(msg)

    cache_dir = f'{os.environ["CASDIR"]}/cb_calculator/{data.vers}'
    plots_dir = f'{cache_dir}/uncertainty_{data.dset}_{data.trig}'

    os.makedirs(plots_dir, exist_ok=True)

    return plots_dir
#----------------------------
def get_args():
    """Parse the command line and fill the ``data`` namespace.

    The parsed options are copied to ``data``; afterwards the plots directory
    and the observable are derived from them through ``get_plots_dir`` and
    ``get_obs``, and the LHCb2 matplotlib style is enabled. Nothing is returned.
    """
    parser = argparse.ArgumentParser(description='Plots nominal and systematically fluctuated combinatorial PDF')

    # Mandatory options restricted to a fixed set of values
    l_choice_opt = [
        ('-d', '--dset' , 'Year of data to process', data.l_dset ),
        ('-t', '--trig' , 'Trigger'                , data.l_trig ),
        ('-q', '--q2bin', 'q2 bin'                 , data.l_q2bin),
    ]
    for short_flag, long_flag, text, l_allowed in l_choice_opt:
        parser.add_argument(short_flag, long_flag, type=str, help=text, required=True, choices=l_allowed)

    parser.add_argument('-v', '--vers' , type=str, help='Version of output', required=True)
    parser.add_argument('-c', '--const', help='Use mass constrain', action='store_true')

    cfg = parser.parse_args()

    data.dset       = cfg.dset
    data.trig       = cfg.trig
    data.q2bin      = cfg.q2bin
    data.vers       = cfg.vers
    data.mass_const = cfg.const

    if cfg.const:
        data.const_pref = 'yconst'
    else:
        data.const_pref = 'nconst'

    # These read the attributes assigned above, so they must come last
    data.plots_dir = get_plots_dir()
    data.obs       = get_obs()

    plt.style.use(mplhep.style.LHCb2)
#----------------------------
def get_obs():
    """Return the 'mass' observable for the q2 bin stored in ``data.q2bin``.

    The upper limit is always 6500; the lower limit is 4480 for the 'high'
    bin and 4000 for 'jpsi' and 'psi2'.

    Returns
    -------
    zfit.Space : one-dimensional space named 'mass'.

    Raises
    ------
    RuntimeError : if ``data.q2bin`` is not one of 'jpsi', 'psi2' or 'high'.
    """
    d_min_mass = {'high' : 4480, 'jpsi' : 4000, 'psi2' : 4000}

    if data.q2bin not in d_min_mass:
        msg = f'Invalid q2bin: {data.q2bin}'
        log.error(msg)
        # A bare `raise` with no active exception only fails by accident
        # (RuntimeError: No active exception to reraise); raise the same
        # exception type explicitly, with a meaningful message.
        raise RuntimeError(msg)

    obs = zfit.Space('mass', limits=(d_min_mass[data.q2bin], 6500))

    log.info(f'Using {data.q2bin}')

    return obs
#----------------------------
def main():
    """Script entry point: read the command line, then make the comparison plot."""
    # get_args() fills the `data` namespace that plot_comp() reads
    get_args()
    plot_comp()
#----------------------------
# Run only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()

