#!/usr/bin/env python3

import re
import argparse
import pprint
import collections


# Matches one "Summary" event of the trace. Numeric fields may carry thousands
# separators (commas), so every number is captured with ``[0-9,]+``.
# Example lines:
# {"ph":"X","pid":10,"tid":15117,"ts":-556704358,"dur":556704358,"name":"Summary","cname":"black","args":{"total_bytes":363857920,"used_bytes":363857920,"free_bytes":0,"used_percent":1,"free_percent":0,"bytes for pattern":363857920}},
# {"ph":"X","pid":10,"tid":860,"ts":-2,278,642,099,"dur":2,278,642,099,"name":"Summary","cname":"black","args":{"total_bytes":5,362,686,976,"used_bytes":5,362,686,976,"free_bytes":0,"used_percent":1,"free_percent":0,"bytes for pattern":5,362,686,976}}
_SUMMARY_PATTERN = re.compile(
    r'[\s\S]*{"ph":"X","pid":([0-9,]+),"tid":([0-9,]+),'
    r'"ts":[-]?[0-9,]+,"dur":([0-9,]+),"name":"Summary","cname":"black",'
    r'"args":{"total_bytes":([0-9,]+),"used_bytes":([0-9,]+),'
    r'"free_bytes":([0-9,]+),"used_percent":([0-9,]+),'
    r'"free_percent":([0-9,]+),"bytes for pattern":([0-9,]+)}}[,]?[\s\S]*'
)


def _to_int(text):
    """Convert a digit string that may contain comma separators to an int."""
    return int(text.replace(",", ""))


def _parse_summary_line(line):
    """Extract ``(pid, tid, total_bytes)`` from a "Summary" trace line.

    Returns ``None`` when the line does not match the expected format.
    """
    match = _SUMMARY_PATTERN.match(line)
    if match is None:
        return None
    return (
        _to_int(match.group(1)),
        _to_int(match.group(2)),
        _to_int(match.group(4)),
    )


def main():
    """Report the largest ``total_bytes`` found in a memory trace file.

    Command-line arguments:
        --path: trace file to scan, one event per line (required).
        --step: only consider events whose ``pid`` equals this value; the
            default ``-1`` means "consider every event".

    The script treats an event's ``pid`` as the training step and its ``tid``
    as the execution step. Lines that do not match the "Summary" event format
    are reported with a warning and skipped. If no event qualifies, the
    reported values are all ``-1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path", type=str, required=True, help="trace file to analyse"
    )
    # Default to -1 ("all steps"); without a default the value would be None
    # and the filter below would reject every record.
    parser.add_argument(
        "--step",
        type=int,
        default=-1,
        help="training step to restrict the search to (-1 means all steps)",
    )
    args = parser.parse_args()

    # Keyed by (training step, execution step) so that the same execution
    # step seen in different training steps does not overwrite earlier data.
    records = collections.OrderedDict()
    with open(args.path) as f:
        for line in f:
            parsed = _parse_summary_line(line)
            if parsed is None:
                print("warning: the line is not parsed correctly:", line)
                continue
            step, t_id, total_bytes = parsed
            records[(step, t_id)] = total_bytes

    max_mem_usage = -1
    max_mem_usage_step = -1
    execution_step = -1
    for (step, t_id), total_bytes in records.items():
        if args.step != -1 and step != args.step:
            continue
        # Strict comparison: on ties the first record in file order wins.
        if total_bytes > max_mem_usage:
            max_mem_usage = total_bytes
            max_mem_usage_step = step
            execution_step = t_id

    print(f"max_mem_usage: {max_mem_usage} at execution step {execution_step} at training step {max_mem_usage_step}")


# Run the analysis only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or file I/O.
if __name__ == "__main__":
    main()
