#!/usr/bin/env python3

from rk.ds_getter import ds_getter as dsg
from logzero      import logger    as log

import os
import argparse
import pandas       as pnd
import utils_noroot as utnr
#----------------------------------------
class data:
    '''
    Namespace holding the command line arguments, filled by get_args()
    '''
    dset  = None # Dataset, one of '2016', '2017', '2018'
    trig  = None # Trigger, 'TOS' or 'TIS'
    vers  = None # Version of ntuples
    ipart = None # First integer given to -p/--part (partition index, presumably)
    npart = None # Second integer given to -p/--part (number of partitions, presumably)
#----------------------------------------
class cache_data:
    '''
    Builds the combinatorial ('cmb') and 'data' samples with the dataset getter
    and caches them as JSON files under:

    $CASDIR/cb_calculator/{sample}_{chan}_{dset}_{trig}/data_{ipart}_{npart}.json
    $CASDIR/cb_calculator/{sample}_{chan}_{dset}_{trig}/ctfl_{ipart}_{npart}.json

    where chan is 'mm' for the MTOS trigger and 'ee' otherwise.
    If the data file already exists, it is loaded instead of being rebuilt.
    '''
    def __init__(self, dset=None, trig=None, ipart=None, npart=None, dvers=None):
        '''
        dset : Dataset, e.g. '2018'
        trig : Trigger, 'TOS' (process ETOS and MTOS) or anything else (process GTIS)
        ipart: Passed, together with npart, as the partition to the dataset getter
        npart: See ipart; presumably the total number of partitions, confirm against ds_getter
        dvers: Version of ntuples, passed to the dataset getter
        '''
        self._dset        = dset
        self._trig        = trig
        self._ipart       = ipart
        self._npart       = npart

        #Dataset getter needs this, so hardcoded, but q2 cut is removed later
        self._q2bin       = 'high'
        self._dvers       = dvers
        self._selection   = 'all_gorder'

        # Set from $CASDIR in _initialize
        self._cache_dir   = None

        self._initialized = False
    #----------------------------------------
    def _initialize(self):
        '''
        Reads the caching directory from $CASDIR, only once

        Raises RuntimeError if $CASDIR is not defined
        '''
        if self._initialized:
            return

        if 'CASDIR' not in os.environ:
            log.error('Caching directory, $CASDIR, not found in environment')
            # A bare `raise` here has no active exception and would only produce
            # "No active exception to reraise"; raise an explicit, informative error
            raise RuntimeError('Environment variable CASDIR is not set')

        self._cache_dir   = os.environ['CASDIR']
        self._initialized = True
    #----------------------------------------
    def _get_cache_dir(self, sample, chan):
        '''
        Returns directory holding the cached files for a given sample and channel
        '''
        return f'{self._cache_dir}/cb_calculator/{sample}_{chan}_{self._dset}_{self._trig}'
    #----------------------------------------
    def _get_data_path(self, sample, chan):
        '''
        Returns path to the JSON file with the cached dataframe, shared by reader and writer
        '''
        return f'{self._get_cache_dir(sample, chan)}/data_{self._ipart}_{self._npart}.json'
    #----------------------------------------
    def _cache_path(self, sample, chan):
        '''
        Returns path to the cached dataframe if it exists, otherwise None
        '''
        path = self._get_data_path(sample, chan)
        if not os.path.isfile(path):
            return None

        log.info(f'Loading cached data: {path}')

        return path
    #----------------------------------------
    def _rdf_to_df(self, rdf):
        '''
        Returns pandas dataframe with the branches needed downstream, extracted from ROOT dataframe
        '''
        # First entry of the constrained-mass arrays; presumably the J/psi and psi(2S)
        # constrained B masses, as the new column names suggest
        rdf      = rdf.Define('B_jpsi_M',       'B_const_mass_M[0]')
        rdf      = rdf.Define('B_psi2_M', 'B_const_mass_psi2S_M[0]')

        l_column = [
                'B_M',
                'BDT_cmb',
                'BDT_prc',
                'Jpsi_M',
                'B_jpsi_M',
                'B_psi2_M',
                'L1_P',
                'L2_P',
                'L1_PT',
                'L2_PT',
                'H_PT',
                'L1_PE',
                'L2_PE',
                'H_PE',
                'L1_PX',
                'L2_PX',
                'H_PX',
                'L1_PY',
                'L2_PY',
                'H_PY',
                'L1_PZ',
                'L2_PZ',
                'H_PZ',
                'L1_ID',
                'L2_ID',
                'H_ID']

        d_data   = rdf.AsNumpy(l_column)
        df       = pnd.DataFrame(d_data)

        return df
    #----------------------------------------
    def _cache(self, rdf, sample, chan):
        '''
        Writes the cutflow and the dataframe to JSON and returns the pandas dataframe
        '''
        path_dir  = self._get_cache_dir(sample, chan)
        os.makedirs(path_dir, exist_ok=True)
        mass_path = self._get_data_path(sample, chan)
        ctfl_path = f'{path_dir}/ctfl_{self._ipart}_{self._npart}.json'

        log.info(f'Caching to: {ctfl_path}')
        # `cf` is an attribute attached to the dataframe by ds_getter, presumably the cutflow; confirm
        ctf = rdf.cf
        ctf.to_json(ctfl_path)

        log.info(f'Caching to: {mass_path}')
        df = self._rdf_to_df(rdf)
        df.to_json(mass_path, indent=4)

        return df
    #----------------------------------------
    def _get_data(self, sample=None, trig=None):
        '''
        Returns pandas dataframe for sample and trigger, from the cache if present,
        otherwise built with the dataset getter and then cached

        sample: Sample name, e.g. 'cmb' or 'data'
        trig  : Trigger, e.g. 'ETOS', 'MTOS', 'GTIS'
        '''
        # Needed before the cache lookup, since it is part of the cache path
        chan    = 'mm' if trig == 'MTOS' else 'ee'

        df_path = self._cache_path(sample, chan)
        if df_path is not None:
            return pnd.read_json(df_path)

        part = (self._ipart, self._npart)
        obj  = dsg(self._q2bin, trig, self._dset, self._dvers, part, sample, self._selection)

        # Cuts redefined as '(1)' (always true) are disabled, this is where the q2 cut is removed
        d_def = {'bdt' : '(1)', 'mass' : '(1)', 'q2' : '(1)', 'pid' : '(1)'}
        if trig == 'MTOS':
            log.warning('Skipping acceptance for MTOS')
            d_def['acceptance'] = '(1)'

        rdf = obj.get_df(d_redefine=d_def)
        df  = self._cache(rdf, sample, chan)

        return df
    #----------------------------------------
    def save(self):
        '''
        Caches 'cmb' and 'data' samples; for trig == 'TOS' both ETOS and MTOS are
        processed, otherwise only GTIS

        Raises RuntimeError if $CASDIR is not defined
        '''
        self._initialize()

        l_trig = ['ETOS', 'MTOS'] if self._trig == 'TOS' else ['GTIS']
        for sample in ['cmb', 'data']:
            for trig in l_trig:
                self._get_data(sample=sample, trig=trig)
#----------------------------------------
def get_args(argv=None):
    '''
    Parses command line arguments and stores them in the `data` class

    argv: List of arguments to parse; if None (default), sys.argv[1:] is used

    The partition is passed as `-p IPART NPART`; argparse itself rejects any other
    number of values or non-integer values, instead of failing later on unpacking
    '''
    parser = argparse.ArgumentParser(description='Used to get datasets from OS and SS data, after full selection, but BDT, used for combinatorial PDF studies')
    parser.add_argument('-d', '--dset' , type= str, help='Dataset'  , choices=['2016', '2017', '2018'], required=True)
    parser.add_argument('-t', '--trig' , type= str, help='Trigger'  , choices=['TOS', 'TIS'], required=True)
    parser.add_argument('-v', '--vers' , type= str, help='Version of ntuples', required=True)
    parser.add_argument('-p', '--part' , type= int, nargs=2, metavar=('IPART', 'NPART'), help='partition', required=True)
    args = parser.parse_args(argv)

    data.dset = args.dset
    data.trig = args.trig
    data.vers = args.vers

    data.ipart, data.npart = args.part
#----------------------------------------
if __name__ == '__main__':
    # Fill the `data` namespace from the command line, then cache the samples
    get_args()

    cache = cache_data(
            dset =data.dset,
            trig =data.trig,
            ipart=data.ipart,
            npart=data.npart,
            dvers=data.vers)

    cache.save()

