#!/usr/bin/env python3

import re
import argparse
import pprint
import collections


# Matches a log line of the form "<prefix>:Input <index> Name: <tensor name>".
# Group 1 is the index (digits and commas); group 2 is everything after
# "Name: ", including the line ending when one is present.
_NAME_LINE_RE = re.compile(r'[\s\S]*:Input ([0-9,]+) Name: ([\s\S]+)')

# Rewrites the characters ".", ":", "/" and "-" to "_" to form the name that
# is handed to ort_scan_tensor_from_dump.
_DUMP_NAME_TABLE = str.maketrans(".:/-", "____")


def main():
    """Extract tensor names from a log file and scan each one in a dump dir.

    Command-line arguments:
        --path      Log file to read (required). The extracted names are
                    written, one per line, to ``<path>.output``.
        --dump_dir  Passed unchanged as the first argument of
                    ``onnxsharp.ort_scan_tensor_from_dump``.

    Lines that do not match the expected pattern are reported on stdout with
    a warning and otherwise ignored.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the output path cannot be derived.
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--dump_dir", type=str)
    args = parser.parse_args()

    # Sample log lines:
    # [1,0]<stdout>:Output 0 Name: onnx::ReduceMean_460_grad
    # [1,0]<stdout>:Output 1 Name: _original_module.roberta.embeddings.LayerNorm.weight_grad
    # [1,0]<stdout>:Output 2 Name: _original_module.roberta.embeddings.LayerNorm.bias_grad
    # [1,0]<stdout>:Output 0 Name: _original_module.roberta.embeddings.position_embeddings.weight_grad
    # [1,0]<stdout>:Output 0 Name: _original_module.roberta.embeddings.token_type_embeddings.weight_grad
    # [1,0]<stdout>:Output 0 Name: _original_module.roberta.embeddings.word_embeddings.weight_grad
    #
    # NOTE(review): the samples above use "Output", while the pattern matches
    # "Input" lines only -- confirm which form this script is meant to parse.

    output_path = args.path + ".output"
    names = []
    with open(args.path) as log_file, open(output_path, "w") as names_file:
        for line in log_file:
            match = _NAME_LINE_RE.match(line)
            if match is None:
                # Strip the line ending so the warning is not followed by a
                # blank line.
                print("warning: the line is not parsed correctly:",
                      line.rstrip("\r\n"))
                continue
            name = match.group(2).replace("-", "_").rstrip("\r\n")
            if not name:
                continue
            # Always newline-terminate so a final log line without a trailing
            # newline still yields a well-formed entry.
            names_file.write(name + "\n")
            names.append(name)

    # Imported only after the name list has been written, so the ".output"
    # file is produced even when onnxsharp is not installed.
    from onnxsharp import ort_scan_tensor_from_dump

    for name in names:
        ort_scan_tensor_from_dump(args.dump_dir,
                                  name.translate(_DUMP_NAME_TABLE))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
