import onnx
import pytest
import re
import argparse
import pprint
import numpy as np
import collections
from onnxsharp import (
    Model,
    Node,
    LogicalSubgraphInfo,
    create_graph_from_logical_subgraph,
    Model,
    Graph,
    Node,
    clip_subgraph_around,
    auto_cluster_pointwise_graphs,
)
from onnx import helper, defs, numpy_helper, checker, onnx_pb
from collections import OrderedDict
import pprint


# Matches one trace-event record per line. Numbers in the trace use "," as a
# thousands separator. Capture groups:
#   1: pid, 2: tid, 3: dur, 4: tensor name, 5: size in bytes.
# Sample records:
# {"ph":"X","pid":10,"tid":2,628,"ts":570,425,344,"dur":8,388,608,"name":"onnx::Transpose_5207","cname":"cq_build_attempt_passed","args":{"name":"onnx::Transpose_5207","offset":570,425,344,"size":8,388,608}},
# {"ph":"X","pid":10,"tid":2,628,"ts":8,296,333,312,"dur":16,777,216,"name":"onnx::MatMul_1397","cname":"thread_state_unknown","args":{"name":"onnx::MatMul_1397","offset":8,296,333,312,"size":16,777,216}},
# Raw string literals keep "\s"/"\S" from being treated as (invalid) string
# escapes; the pattern is compiled once rather than once per input line.
_TRACE_EVENT_RE = re.compile(
    r'[\s\S]*{"ph":"X","pid":([0-9,]+),"tid":([0-9,]+),'
    r'"ts":[-]?[0-9,]+,"dur":([0-9,]+),'
    r'"name":"([\s\S]+)","cname":"[\s\S]+",'
    r'"args":{"name":"[\s\S]+","offset":[0-9,]+,"size":([0-9,]+)}},'
    r"[\s\S]*"
)


def _resolve_elem_type(o_arg):
    """Return the element type recorded on *o_arg*, or FLOAT16 if unavailable."""
    if (
        o_arg
        and o_arg._value_info
        and hasattr(o_arg._value_info, "_type")
        and hasattr(o_arg._value_info._type, "_elem_type")
    ):
        return o_arg._value_info._type._elem_type
    # No type information on the output arg: fall back to FLOAT16.
    return onnx_pb.TensorProto.FLOAT16


def _bytes_per_element(elem_type):
    """Return the width in bytes of one element of *elem_type*.

    Raises:
        RuntimeError: if *elem_type* is not one of the handled tensor types.
    """
    tensor_proto = onnx_pb.TensorProto
    width_to_types = {
        8: (tensor_proto.UINT64, tensor_proto.INT64, tensor_proto.DOUBLE),
        4: (tensor_proto.FLOAT, tensor_proto.INT32, tensor_proto.UINT32),
        2: (tensor_proto.FLOAT16, tensor_proto.INT16, tensor_proto.UINT16),
        1: (tensor_proto.INT8, tensor_proto.UINT8, tensor_proto.BOOL),
    }
    for width, types in width_to_types.items():
        if elem_type in types:
            return width
    raise RuntimeError("did not support data type {}".format(elem_type))


def _shape_key(elem_type, shape):
    """Build the "<elem_type>-<dim,dim,...>" key used to group outputs."""
    dims = "None" if shape is None else ",".join(str(s) for s in shape)
    return str(elem_type) + "-" + dims


def main():
    """Summarize the sizes of activation tensors listed in a trace file.

    Command-line arguments:
        --path:  text file that is read line by line; every line is matched
                 against the trace-event pattern above.
        --mpath: model file passed to ``Model.load_model``.

    For each matching line whose tensor name the graph reports as an
    activation, the producing node's type is counted and the tensor is
    grouped by "<elem_type>-<shape>". Lines that do not match the pattern
    produce a warning on stdout. Finally the per-op-type counts (most
    frequent first), the total number of bytes, and the per-shape statistics
    are printed.

    Raises:
        RuntimeError: if an activation has an element type that is not handled.
    """
    parser = argparse.ArgumentParser()
    # Both inputs are mandatory: without them the script would only fail later
    # with an opaque error when trying to open ``None``.
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--mpath", type=str, required=True)
    args = parser.parse_args()

    m = Model.load_model(args.mpath)
    graph = m._graph

    op_type_count = OrderedDict()
    # op type -> {shape key -> [count, bytes of the first tensor seen,
    #                           bytes per element, percentage text, names]}
    op_type_to_shapes = OrderedDict()
    total_mem_consumption = 0
    with open(args.path) as f:
        for line in f:
            match = _TRACE_EVENT_RE.match(line)
            if match is None:
                print("warning: the line is not parsed correctly:", line)
                continue

            name = match.group(4)
            total_bytes = int(match.group(5).replace(",", ""))
            if not graph.is_activation(name):
                continue

            j, o_index = graph.get_node_with_output_arg_name(name)
            op_type_count[j.type] = op_type_count.get(j.type, 0) + 1
            shapes_for_type = op_type_to_shapes.setdefault(j.type, {})

            o_arg = j.output_arg(o_index)
            shape = o_arg.shape
            elem_type = _resolve_elem_type(o_arg)
            bytes_per_element = _bytes_per_element(elem_type)

            key = _shape_key(elem_type, shape)
            if key not in shapes_for_type:
                shapes_for_type[key] = [
                    0,
                    total_bytes,
                    bytes_per_element,
                    None,
                    [],
                ]

            stats = shapes_for_type[key]
            stats[0] += 1
            stats[4].append(name)
            total_mem_consumption += total_bytes

    sorted_tuples = sorted(
        op_type_count.items(), key=lambda item: item[1], reverse=True
    )

    print("## node summary:")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(sorted_tuples)

    print(f"## Total memory consumption: {total_mem_consumption} Bytes.")
    for shapes_for_type in op_type_to_shapes.values():
        for stats in shapes_for_type.values():
            # The share is estimated as count * size of the first tensor seen
            # with this key. Guard against a zero total (all sizes reported as
            # 0) instead of dividing by zero.
            if total_mem_consumption:
                share = (stats[0] * stats[1]) / total_mem_consumption * 100
            else:
                share = 0.0
            stats[3] = "{:.2f}".format(share) + "%"

            # Collapse the collected names into one string for compact output.
            stats[4] = ",".join(stats[4])

    pp.pprint(op_type_to_shapes)

# Run the summary only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
