#!/usr/bin/env python3

import argparse
import logging
import shutil
import ROOT
import tqdm
import os

import utils_noroot as utnr
import utils

#--------------------------------
class data:
    # Script-wide configuration, kept as class attributes so every function
    # can read it without passing it around. Filled in by get_args().
    list_path = None  # text file with the list of ntuple paths to copy
    root_dir  = None  # directory where the source tree structure begins
    targ_dir  = None  # destination directory that mirrors root_dir
    fraction  = None  # fraction of entries per tree to keep; check_args() enforces [0, 1]
    dry_run   = None  # 0 copies the files, 1 is a dry run
    debug     = None  # 1 turns on DEBUG-level logging

    # Logger shared by all functions of this script.
    log = utnr.getLogger(__name__)
#--------------------------------
def get_args():
    """Parse the command line and store every option on ``data``."""
    parser = argparse.ArgumentParser(description='Used to copy ntuple tree structure')
    parser.add_argument('-l', '--list', type=str, required=True,
                        help='Path to text file with list of files to copy')
    parser.add_argument('-d', '--directory', type=str, required=True,
                        help='Directory path where tree structure begins')
    parser.add_argument('-f', '--fraction', type=float, default=1,
                        help='Fraction of events to copy')
    parser.add_argument('-t', '--target', type=str, required=True,
                        help='Destination of ntuple tree')
    parser.add_argument('-x', '--dry_run', type=int, choices=[0, 1], default=1,
                        help='Will copy the files if set to 0')
    parser.add_argument('-m', '--debug', type=int, choices=[0, 1], default=0,
                        help='1 will turn on debugging messages')
    args = parser.parse_args()

    # Map each parsed option onto its attribute in the shared configuration.
    for attr_name, value in (
        ('list_path', args.list),
        ('root_dir', args.directory),
        ('fraction', args.fraction),
        ('targ_dir', args.target),
        ('dry_run', args.dry_run),
        ('debug', args.debug),
    ):
        setattr(data, attr_name, value)
#--------------------------------
def copy_tuples():
    """Copy every ntuple listed in ``data.list_path`` (one path per line).

    Leading and trailing whitespace is stripped from each line, and blank
    lines (e.g. stray empty lines in the list file) are dropped. A blank
    line is not a path, so it is never handed to ``copy_tuple``.
    """
    with open(data.list_path) as ifile:
        l_path = [line.strip() for line in ifile]

    l_path = [path for path in l_path if path]

    data.log.info(f'Copying {len(l_path)} ntuples')
    for path in tqdm.tqdm(l_path, ascii=' -'):
        copy_tuple(path)
#--------------------------------
def skip_copy(org_path):
    """Return True if ``org_path`` already has an up-to-date copy in the target tree.

    The target path is ``org_path`` with ``data.root_dir`` replaced by
    ``data.targ_dir``. The existing copy is considered up to date when:

    * it exists as a regular file,
    * its ctime is not older than the source's ctime, and
    * for full copies (``data.fraction >= 1``), its size equals the source's.

    When only a fraction of the entries is copied, the target is a reduced
    file whose size never matches the source. The size comparison is then
    skipped; otherwise every file would be re-filtered on every run. The
    catch is that a copy made earlier with a different fraction is also
    kept if it is newer than the source.

    If ``org_path`` does not contain ``data.root_dir``, the target path
    equals ``org_path``. The file then compares equal to itself and is
    reported as up to date (skipped).
    """
    new_path = org_path.replace(data.root_dir, data.targ_dir)
    if not os.path.isfile(new_path):
        return False

    # A copy created before the source last changed is stale.
    if os.path.getctime(new_path) < os.path.getctime(org_path):
        return False

    # Sizes are only comparable when the whole file was copied.
    if data.fraction >= 1 and os.path.getsize(org_path) != os.path.getsize(new_path):
        return False

    return True
#--------------------------------
def copy_tuple(org_path):
    """Copy a single ntuple into the mirrored target tree.

    The destination is ``org_path`` with ``data.root_dir`` replaced by
    ``data.targ_dir``. Files for which ``skip_copy`` reports an up-to-date
    copy are left alone.

    When ``data.fraction < 1`` and this is not a dry run, ``filter_file``
    first writes a reduced temporary file. That file is what gets copied,
    and it is deleted afterwards.

    In a dry run (``data.dry_run`` non-zero) the planned copy is only
    logged: nothing is filtered, removed, created or copied.

    Raises:
        ValueError: the destination would be the source file itself
            (``org_path`` does not contain ``data.root_dir``).
    """
    if skip_copy(org_path):
        data.log.debug(f'Skipping: {org_path}')
        return

    data.log.debug(f'Not skipping: {org_path}')

    # Derive the destination from the ORIGINAL path. Reverse-mapping the
    # temporary path by replacing '/tmp' would also rewrite any other '/tmp'
    # substring in the path (e.g. a '/tmpdir' sub-directory).
    new_path = org_path.replace(data.root_dir, data.targ_dir)
    if new_path == org_path:
        # Removing/overwriting new_path would destroy the source file.
        raise ValueError(f'Path is not under {data.root_dir}: {org_path}')

    filtered = data.fraction < 1 and data.dry_run == 0
    src_path = filter_file(org_path) if filtered else org_path

    data.log.debug('-' * 30)
    data.log.debug(src_path)
    data.log.debug('--->')
    data.log.debug(new_path)
    data.log.debug('-' * 30)

    # A dry run must not touch the filesystem.
    if data.dry_run:
        return

    try:
        # Drop the previous (out-of-date) copy before writing the new one.
        if os.path.exists(new_path):
            os.remove(new_path)

        new_dir = os.path.dirname(new_path)
        if new_dir:  # a bare file name has no directory to create
            os.makedirs(new_dir, exist_ok=True)

        shutil.copy(src_path, new_path)
    finally:
        # The reduced file is only an intermediate product; do not leave it behind.
        if filtered:
            os.remove(src_path)
#--------------------------------
def filter_file(org_path):
    """Write a reduced copy of a ROOT file, keeping a fraction of each tree.

    For every tree returned by ``utils.getTrees``, the first
    ``int(GetEntries() * data.fraction)`` entries are written under the
    tree's name to the output file. The output path mirrors ``org_path``
    with ``data.root_dir`` replaced by ``/tmp``. Deleting that file is the
    caller's responsibility.

    Args:
        org_path: path of the input ROOT file.

    Returns:
        Path of the reduced file under ``/tmp``.

    Raises:
        ValueError: ``org_path`` does not contain ``data.root_dir``, so the
            output would overwrite the input.
        OSError: the input or output ROOT file could not be opened.
    """
    new_path = org_path.replace(data.root_dir, '/tmp')
    if new_path == org_path:
        # 'recreate' below would otherwise truncate the input file itself.
        raise ValueError(f'Path is not under {data.root_dir}: {org_path}')

    ifile = ROOT.TFile(org_path)
    if ifile.IsZombie():
        raise OSError(f'Cannot open ROOT file: {org_path}')

    try:
        # NOTE(review): trees are written at the top level of the output by
        # name; if getTrees returns trees from sub-directories, that layout
        # is lost — confirm.
        l_org_tree = utils.getTrees(ifile)
        new_dir    = os.path.dirname(new_path)
        os.makedirs(new_dir, exist_ok=True)

        # Create (or truncate) the output once; each tree is then appended.
        ofile = ROOT.TFile(new_path, 'recreate')
        if ofile.IsZombie():
            raise OSError(f'Cannot create ROOT file: {new_path}')
        ofile.Close()

        opts = ROOT.RDF.RSnapshotOptions()
        opts.fMode = 'update'

        for org_tree in l_org_tree:
            nentries = int(org_tree.GetEntries() * data.fraction)

            if nentries == 0:
                # RDataFrame's Range(0) means "no upper limit" and would copy
                # every entry; write an empty clone of the tree instead.
                ofile = ROOT.TFile(new_path, 'update')
                org_tree.CloneTree(0).Write()
                ofile.Close()
                continue

            rdf = ROOT.RDataFrame(org_tree)
            rdf = rdf.Range(nentries)

            rdf.Snapshot(org_tree.GetName(), new_path, '', opts)
    finally:
        ifile.Close()

    return new_path
#--------------------------------
def check_args():
    """Validate the options stored on ``data`` and create the target directory.

    Must run after ``get_args``. All inputs are validated before anything
    is created on disk, so bad arguments leave the filesystem untouched.

    Raises:
        ValueError: ``data.fraction`` is not within [0, 1] (NaN included).
        FileNotFoundError: ``data.root_dir`` is not a directory, or
            ``data.list_path`` is not a file.
        OSError: the target directory cannot be created.
    """
    # Apply the log level first so every message below honours --debug.
    if data.debug == 1:
        data.log.setLevel(logging.DEBUG)

    if data.dry_run:
        data.log.warning('Running a dry run')

    # The chained comparison also rejects NaN, which argparse's float() accepts.
    if not 0 <= data.fraction <= 1:
        msg = f'Invalid value of fraction: {data.fraction}'
        data.log.error(msg)
        raise ValueError(msg)

    if not os.path.isdir(data.root_dir):
        msg = f'Root of directory tree not found: {data.root_dir}'
        data.log.error(msg)
        raise FileNotFoundError(msg)

    if not os.path.isfile(data.list_path):
        msg = f'Cannot find list of paths to ntuples to copy: {data.list_path}'
        data.log.error(msg)
        raise FileNotFoundError(msg)

    try:
        os.makedirs(data.targ_dir, exist_ok=True)
    except OSError:
        data.log.error(f'Cannot make target directory: {data.targ_dir}')
        raise
#--------------------------------
def main():
    """Script entry point: parse the options, validate them, then copy the ntuples."""
    for step in (get_args, check_args, copy_tuples):
        step()
#--------------------------------
if __name__ == '__main__':
    # Only run the copy when executed as a script, not when imported.
    main()
#--------------------------------

