# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import argparse import contextlib import os import sys import typing # the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py from util.ort_format_model.types import FbsTypeInfo # isort:skip import ort_flatbuffers_py.fbs as fbs # isort:skip class OrtFormatModelDumper: "Class to dump an ORT format model." def __init__(self, model_path: str): """ Initialize ORT format model dumper :param model_path: Path to model """ self._file = open(model_path, "rb").read() # noqa: SIM115 self._buffer = bytearray(self._file) if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0): raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'") self._inference_session = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0) self._model = self._inference_session.Model() def _dump_initializers(self, graph: fbs.Graph): print("Initializers:") for idx in range(graph.InitializersLength()): tensor = graph.Initializers(idx) dims = [] for dim in range(tensor.DimsLength()): dims.append(tensor.Dims(dim)) print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}") print("--------") def _dump_nodeargs(self, graph: fbs.Graph): print("NodeArgs:") for idx in range(graph.NodeArgsLength()): node_arg = graph.NodeArgs(idx) type = node_arg.Type() if not type: # NodeArg for optional value that does not exist continue type_str = FbsTypeInfo.typeinfo_to_str(type) value_type = type.ValueType() value = type.Value() dims = None if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type: tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape() tensor_type_and_shape.Init(value.Bytes, value.Pos) shape = tensor_type_and_shape.Shape() if shape: dims = [] for dim in range(shape.DimLength()): d = shape.Dim(dim).Value() if d.DimType() == fbs.DimensionValueType.DimensionValueType.VALUE: dims.append(str(d.DimValue())) elif d.DimType() == fbs.DimensionValueType.DimensionValueType.PARAM: dims.append(d.DimParam().decode()) else: dims.append("?") else: dims = None print(f"{node_arg.Name().decode()} type={type_str} dims={dims}") print("--------") def _dump_node(self, node: fbs.Node): optype = node.OpType().decode() domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx since_version = node.SinceVersion() inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())] outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())] print( f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) " f"inputs=[{','.join(inputs)}] outputs=[{','.join(outputs)}]" ) def _dump_graph(self, graph: fbs.Graph): """ Process one level of the Graph, descending into any subgraphs when they are found """ self._dump_initializers(graph) self._dump_nodeargs(graph) print("Nodes:") for i in range(graph.NodesLength()): node = graph.Nodes(i) self._dump_node(node) # Read all the attributes for j in range(node.AttributesLength()): attr = node.Attributes(j) attr_type = attr.Type() if attr_type == fbs.AttributeType.AttributeType.GRAPH: print(f"## Subgraph for {node.OpType().decode()}.{attr.Name().decode()} ##") self._dump_graph(attr.G()) print(f"## End {node.OpType().decode()}.{attr.Name().decode()} Subgraph ##") elif attr_type == fbs.AttributeType.AttributeType.GRAPHS: # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute # so entering this 'elif' isn't currently possible print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##") for k in range(attr.GraphsLength()): print(f"## Subgraph {k} ##") self._dump_graph(attr.Graphs(k)) print(f"## End Subgraph {k} ##") def dump(self, output: typing.IO): with contextlib.redirect_stdout(output): print(f"ORT format version: {self._inference_session.OrtVersion().decode()}") print("--------") graph = self._model.Graph() self._dump_graph(graph) def parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), description="Dump an ORT format model. Output is to .txt" ) parser.add_argument("--stdout", action="store_true", help="Dump to stdout instead of writing to file.") parser.add_argument("model_path", help="Path to ORT format model") args = parser.parse_args() if not os.path.isfile(args.model_path): parser.error(f"{args.model_path} is not a file.") return args def main(): args = parse_args() d = OrtFormatModelDumper(args.model_path) if args.stdout: d.dump(sys.stdout) else: output_filename = args.model_path + ".txt" with open(output_filename, "w", encoding="utf-8") as ofile: d.dump(ofile) if __name__ == "__main__": main()