#!python

import argparse
import ROOT
import sys
import os

from logzero import logger as log

#---------------------------------
class data(object):
    '''
    Module-level holder for state shared between get_args() and main().
    '''
    # Input file paths given with -p/--paths; filled in by get_args(), None until then.
    l_file_path = None
#---------------------------------
def GetTrees(directory):
    '''
    Recursively collect every TTree stored in a ROOT directory (or file).

    Each tree found has its title overwritten with "<directory name>/<tree name>"
    so callers can later report where it was found (via GetTitle()).

    Parameters
    ----------
    directory : ROOT TDirectory / TFile
        Directory whose keys are scanned; sub-directories are descended into.

    Returns
    -------
    list
        Tree objects, in key order; trees from a sub-directory are inserted at
        the position where that sub-directory's key appears.
    '''
    l_tree=list()
    s_seen=set()

    for key in directory.GetListOfKeys():
        name=key.GetName()
        # The key list holds one entry per cycle of an object (e.g. "tree;1",
        # "tree;2"); ROOT lists the highest cycle first, so keeping only the first
        # key of each name avoids reading and reporting the same tree repeatedly.
        if name in s_seen:
            continue
        s_seen.add(name)

        obj=key.ReadObj()
        if obj.InheritsFrom("TTree"):
            title="{}/{}".format(directory.GetName(), obj.GetName())
            obj.SetTitle(title)
            l_tree.append(obj)
        elif obj.InheritsFrom("TDirectory"):
            l_tree+=GetTrees(obj)

    return l_tree
#---------------------------------
def saveBranchNames(tree, ofile):
    '''
    Write one line per branch of `tree` to `ofile`: the branch name and its leaf type.

    Each branch is enabled with SetBranchStatus. The type is taken from the leaf
    that has the same name as the branch. If that leaf does not exist and the
    branch has exactly one leaf, that leaf is used instead. Otherwise the leaves
    are printed and the branch is skipped (nothing is written for it).

    Parameters
    ----------
    tree : ROOT TTree
        Tree whose branches are listed.
    ofile : writable text file
        Destination; each line is "    <name padded to 40><type padded to 40>\\n".
    '''
    for branch in tree.GetListOfBranches():
        branchname=branch.GetName()
        tree.SetBranchStatus(branchname, 1)
        leaf=branch.GetLeaf(branchname)

        # A missing leaf comes back as a null pointer (falsy in PyROOT) / None, so
        # test for it directly rather than catching every exception with a bare except.
        if not leaf:
            print("Cannot retrieve leaf for " + branchname)
            l_leave = branch.GetListOfLeaves()
            if l_leave.GetEntries() != 1:
                print("Found {} leaves in branch".format(l_leave.GetEntries()) )
                for leave in l_leave:
                    leave.Print()
                continue

            leaf=l_leave[0]
            print("Found instead: " + leaf.GetName())

        typename=leaf.GetTypeName()
        ofile.write("{0:4}{1:40}{2:40}\n".format("    ", branchname, typename))
#---------------------------------
def get_args():
    '''
    Parse the command line and store the requested input paths in data.l_file_path.

    The -p/--paths option is mandatory and takes one or more paths.
    '''
    ap = argparse.ArgumentParser(description='Used to dump list of branches in trees')
    ap.add_argument(
        '-p', '--paths',
        nargs='+',
        required=True,
        help='Paths to files',
    )
    data.l_file_path = ap.parse_args().paths
#---------------------------------
def get_out_path(ipath):
    '''
    Return the path of the text file that will hold the list of branches for `ipath`.

    A trailing ".root" is replaced by ".txt". Any other name gets ".txt" appended,
    so the result can never equal the input path (which would be truncated
    when opened for writing). For grid inputs ("root://..."), the file is placed
    in the current directory using only the input's base name.
    '''
    if ipath.startswith('root://'):
        ipath = os.path.basename(ipath)

    suffix = '.root'
    # Strip only the trailing suffix: str.replace would also rewrite ".root"
    # occurring inside directory names.
    stem = ipath[:-len(suffix)] if ipath.endswith(suffix) else ipath

    return stem + '.txt'
#---------------------------------
def main():
    '''
    For each path in data.l_file_path, write a text file listing every tree in the
    ROOT file followed by that tree's branches (see saveBranchNames).

    Inputs that cannot be opened, or whose output path would coincide with the
    input, are logged and skipped; no output file is created for them. Both the
    output text file and the ROOT file are closed after each input.
    '''
    for ipath in data.l_file_path:
        opath = get_out_path(ipath)
        # open(..., "w") truncates; never let it clobber the input file.
        if os.path.abspath(opath) == os.path.abspath(ipath):
            log.error(f'Output path would overwrite input, skipping: {ipath}')
            continue

        ifile=ROOT.TFile.Open(ipath)
        # TFile.Open returns a null pointer (falsy) when the file cannot be opened.
        if not ifile:
            log.error(f'Cannot open file: {ipath}')
            continue

        try:
            if ifile.IsZombie():
                log.error(f'Cannot open file: {ipath}')
                continue

            with open(opath, "w") as ofile:
                for tree in GetTrees(ifile):
                    path=tree.GetTitle()
                    log.info(f'Found tree in path: {path}')
                    ofile.write("\n")
                    ofile.write(f"{path}\n")
                    saveBranchNames(tree, ofile)
        finally:
            ifile.Close()
#---------------------------------
if __name__ == "__main__":
    # Script entry point: fill data.l_file_path from the command line, then dump branches.
    get_args()
    main()

