mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Use ruff as the code formatter in place of black and isort, since it is much faster and projects like PyTorch and ONNX have adopted ruff format as well. This PR includes only auto-fixed formatting changes.
151 lines
5.9 KiB
Python
151 lines
5.9 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
import contextlib
|
|
import os
|
|
import sys
|
|
import typing
|
|
|
|
# the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py
|
|
from util.ort_format_model.types import FbsTypeInfo # isort:skip
|
|
import ort_flatbuffers_py.fbs as fbs # isort:skip
|
|
|
|
|
|
class OrtFormatModelDumper:
    """Class to dump an ORT format model.

    The model file is loaded into memory once at construction time. `dump` then prints the
    ORT format version followed by the initializers, NodeArgs and nodes of the main graph,
    descending into any subgraphs held in node attributes.
    """

    def __init__(self, model_path: str):
        """
        Initialize ORT format model dumper

        :param model_path: Path to model
        :raises RuntimeError: If the buffer does not carry the InferenceSession file identifier.
        """
        # Use a context manager so the file handle is closed deterministically instead of
        # relying on garbage collection of the anonymous file object.
        with open(model_path, "rb") as model_file:
            self._file = model_file.read()

        self._buffer = bytearray(self._file)
        if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
            raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")

        self._inference_session = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0)
        self._model = self._inference_session.Model()

    def _dump_initializers(self, graph: fbs.Graph):
        """Print the name, data type and dims of every initializer in `graph`."""
        print("Initializers:")
        for idx in range(graph.InitializersLength()):
            tensor = graph.Initializers(idx)
            dims = [tensor.Dims(dim) for dim in range(tensor.DimsLength())]
            print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}")
        print("--------")

    @staticmethod
    def _dim_to_str(dim_value) -> str:
        """Return the text for one shape dimension: its value, its symbolic name, or '?'."""
        dim_type = dim_value.DimType()
        if dim_type == fbs.DimensionValueType.DimensionValueType.VALUE:
            return str(dim_value.DimValue())
        if dim_type == fbs.DimensionValueType.DimensionValueType.PARAM:
            return dim_value.DimParam().decode()
        return "?"

    def _dump_nodeargs(self, graph: fbs.Graph):
        """Print the name, type string and dims of every NodeArg in `graph` that has type info.

        `dims` is printed as None when the NodeArg is not a tensor or the tensor has no shape.
        """
        print("NodeArgs:")
        for idx in range(graph.NodeArgsLength()):
            node_arg = graph.NodeArgs(idx)
            # named type_info rather than `type` so the builtin is not shadowed
            type_info = node_arg.Type()
            if not type_info:
                # NodeArg for optional value that does not exist
                continue

            type_str = FbsTypeInfo.typeinfo_to_str(type_info)
            value_type = type_info.ValueType()
            value = type_info.Value()
            dims = None
            if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
                # re-interpret the union value as a TensorTypeAndShape table
                tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
                tensor_type_and_shape.Init(value.Bytes, value.Pos)
                shape = tensor_type_and_shape.Shape()
                if shape:
                    dims = [self._dim_to_str(shape.Dim(dim).Value()) for dim in range(shape.DimLength())]

            print(f"{node_arg.Name().decode()} type={type_str} dims={dims}")
        print("--------")

    def _dump_node(self, node: fbs.Node):
        """Print one line with the node's index, name, domain:optype:version, inputs and outputs."""
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx
        since_version = node.SinceVersion()

        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
        print(
            f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
            f"inputs=[{','.join(inputs)}] outputs=[{','.join(outputs)}]"
        )

    def _dump_graph(self, graph: fbs.Graph):
        """
        Process one level of the Graph, descending into any subgraphs when they are found
        """

        self._dump_initializers(graph)
        self._dump_nodeargs(graph)
        print("Nodes:")
        for i in range(graph.NodesLength()):
            node = graph.Nodes(i)
            self._dump_node(node)

            # Read all the attributes
            for j in range(node.AttributesLength()):
                attr = node.Attributes(j)
                attr_type = attr.Type()
                if attr_type == fbs.AttributeType.AttributeType.GRAPH:
                    print(f"## Subgraph for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    self._dump_graph(attr.G())
                    print(f"## End {node.OpType().decode()}.{attr.Name().decode()} Subgraph ##")
                elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
                    # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                    # so entering this 'elif' isn't currently possible
                    print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    for k in range(attr.GraphsLength()):
                        print(f"## Subgraph {k} ##")
                        self._dump_graph(attr.Graphs(k))
                        print(f"## End Subgraph {k} ##")

    def dump(self, output: typing.IO):
        """Write the textual dump of the model to `output`.

        All helpers use `print`, so stdout is redirected to `output` for the duration of the call.

        :param output: Writable text stream that receives the dump.
        """
        with contextlib.redirect_stdout(output):
            print(f"ORT format version: {self._inference_session.OrtVersion().decode()}")
            print("--------")

            graph = self._model.Graph()
            self._dump_graph(graph)
|
|
|
|
|
|
def parse_args():
    """Parse the command line of the dump script.

    Accepts a positional `model_path` and an optional `--stdout` flag. If `model_path`
    is not an existing file, the parser reports the error and exits.

    :return: The parsed argparse namespace.
    """
    script_name = os.path.basename(__file__)
    arg_parser = argparse.ArgumentParser(
        prog=script_name,
        description="Dump an ORT format model. Output is to <model_path>.txt",
    )
    arg_parser.add_argument(
        "--stdout",
        action="store_true",
        help="Dump to stdout instead of writing to file.",
    )
    arg_parser.add_argument("model_path", help="Path to ORT format model")

    parsed = arg_parser.parse_args()

    # reject directories and non-existent paths before any model loading is attempted
    model_exists = os.path.isfile(parsed.model_path)
    if not model_exists:
        arg_parser.error(f"{parsed.model_path} is not a file.")

    return parsed
|
|
|
|
|
|
def main():
    """Script entry point: load the model named on the command line and dump it.

    The dump goes to stdout when `--stdout` was given, otherwise to `<model_path>.txt`.
    """
    cli_args = parse_args()
    dumper = OrtFormatModelDumper(cli_args.model_path)

    if not cli_args.stdout:
        # default: write the dump next to the model file
        target_path = f"{cli_args.model_path}.txt"
        with open(target_path, "w", encoding="utf-8") as sink:
            dumper.dump(sink)
        return

    dumper.dump(sys.stdout)
|
|
|
|
|
|
# Run the dumper only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|