onnxruntime/tools/python/dump_ort_model.py
Justin Chu d834ec895a
Adopt lintrunner as the linting tool - take 2 (#15085)
### Description

`lintrunner` is a linter runner successfully used by pytorch, onnx and
onnx-script. It provides a uniform experience running linters locally
and in CI. It supports all major dev systems: Windows, Linux and macOS.
The checks are enforced by the `Python format` workflow.

This PR adopts `lintrunner` in onnxruntime and fixes ~2000 flake8 errors
in Python code. `lintrunner` now runs all required python lints
including `ruff`(replacing `flake8`), `black` and `isort`. Future lints
like `clang-format` can be added.

Most errors are auto-fixed by `ruff` and the fixes should be considered
robust.

Lints that are more complicated to fix are suppressed with `# noqa` for now and
should be fixed in follow up PRs.

### Notable changes

1. This PR **removed some suboptimal patterns**:

	- `not xxx in` -> `xxx not in` membership checks
	- bare excepts (`except:` -> `except Exception`)
	- unused imports
	
	The follow up PR will remove:
	
	- `import *`
	- mutable values as default in function definitions (`def func(a=[])`)
	- more unused imports
	- unused local variables

2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than
flake8 and is more robust. We are using it successfully in onnx and
onnx-script. It also supports auto-fixing many flake8 errors.

3. Removed the legacy flake8 ci flow and updated docs.

4. The added workflow supports SARIF code scanning reports on github,
example snapshot:
	

![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png)

5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Unified linting experience in CI and local.

Replacing https://github.com/microsoft/onnxruntime/pull/14306

---------

Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 15:29:03 -07:00

151 lines
5.9 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import contextlib
import os
import sys
import typing
# the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py
from util.ort_format_model.types import FbsTypeInfo # isort:skip
import ort_flatbuffers_py.fbs as fbs # isort:skip
class OrtFormatModelDumper:
    """Class to dump an ORT format model."""

    def __init__(self, model_path: str):
        """
        Initialize ORT format model dumper
        :param model_path: Path to model
        :raises RuntimeError: If the buffer read from model_path does not carry the
                              InferenceSession identifier.
        """
        # Read via a context manager so the file handle is closed as soon as the bytes are loaded,
        # rather than whenever the anonymous file object happens to be garbage collected.
        with open(model_path, "rb") as model_file:
            self._file = model_file.read()

        self._buffer = bytearray(self._file)

        session_cls = fbs.InferenceSession.InferenceSession
        if not session_cls.InferenceSessionBufferHasIdentifier(self._buffer, 0):
            raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")

        self._inference_session = session_cls.GetRootAsInferenceSession(self._buffer, 0)
        self._model = self._inference_session.Model()

    def _dump_initializers(self, graph: fbs.Graph):
        """
        Print one line per initializer in the graph with its name, data type and dims.
        :param graph: Graph whose initializers are printed
        """
        print("Initializers:")
        for idx in range(graph.InitializersLength()):
            tensor = graph.Initializers(idx)
            dims = [tensor.Dims(dim) for dim in range(tensor.DimsLength())]
            print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}")
        print("--------")

    @staticmethod
    def _tensor_dims(value):
        """
        Convert the shape of a tensor-typed TypeInfo value to a list of printable strings.
        :param value: The union table returned by TypeInfo.Value() for a tensor_type entry
        :return: List with one string per dimension (the value, the symbolic parameter name, or "?"),
                 or None if the tensor has no shape.
        """
        tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
        tensor_type_and_shape.Init(value.Bytes, value.Pos)
        shape = tensor_type_and_shape.Shape()
        if not shape:
            return None

        dim_kind = fbs.DimensionValueType.DimensionValueType
        dims = []
        for idx in range(shape.DimLength()):
            d = shape.Dim(idx).Value()
            dim_type = d.DimType()
            if dim_type == dim_kind.VALUE:
                dims.append(str(d.DimValue()))
            elif dim_type == dim_kind.PARAM:
                dims.append(d.DimParam().decode())
            else:
                dims.append("?")

        return dims

    def _dump_nodeargs(self, graph: fbs.Graph):
        """
        Print one line per NodeArg in the graph with its name, type string and dims.
        dims is None for anything that is not a tensor type, and for tensors without a shape.
        :param graph: Graph whose NodeArgs are printed
        """
        print("NodeArgs:")
        for idx in range(graph.NodeArgsLength()):
            node_arg = graph.NodeArgs(idx)
            type_info = node_arg.Type()
            if not type_info:
                # NodeArg for optional value that does not exist
                continue

            type_str = FbsTypeInfo.typeinfo_to_str(type_info)
            dims = None
            if type_info.ValueType() == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
                dims = self._tensor_dims(type_info.Value())

            print(f"{node_arg.Name().decode()} type={type_str} dims={dims}")
        print("--------")

    def _dump_node(self, node: fbs.Node):
        """
        Print a single line describing the node: index, name, domain:optype:since_version, inputs and outputs.
        :param node: Node to print
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx
        since_version = node.SinceVersion()
        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
        print(
            f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
            f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]'
        )

    def _dump_graph(self, graph: fbs.Graph):
        """
        Process one level of the Graph, descending into any subgraphs when they are found
        :param graph: Graph to print
        """
        self._dump_initializers(graph)
        self._dump_nodeargs(graph)

        attribute_type = fbs.AttributeType.AttributeType
        print("Nodes:")
        for i in range(graph.NodesLength()):
            node = graph.Nodes(i)
            self._dump_node(node)

            # Read all the attributes
            for j in range(node.AttributesLength()):
                attr = node.Attributes(j)
                attr_type = attr.Type()
                if attr_type == attribute_type.GRAPH:
                    # the label is only built for graph attributes, as in the original per-branch formatting
                    label = f"{node.OpType().decode()}.{attr.Name().decode()}"
                    print(f"## Subgraph for {label} ##")
                    self._dump_graph(attr.G())
                    print(f"## End {label} Subgraph ##")
                elif attr_type == attribute_type.GRAPHS:
                    # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                    # so entering this 'elif' isn't currently possible
                    print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    for k in range(attr.GraphsLength()):
                        print(f"## Subgraph {k} ##")
                        self._dump_graph(attr.Graphs(k))
                        print(f"## End Subgraph {k} ##")

    def dump(self, output: typing.IO):
        """
        Write the ORT format version followed by the full graph dump to output.
        :param output: Text stream that receives everything printed while dumping
        """
        # All the _dump_* helpers use print(), so stdout is redirected for the duration of the dump.
        with contextlib.redirect_stdout(output):
            print(f"ORT format version: {self._inference_session.OrtVersion().decode()}")
            print("--------")
            graph = self._model.Graph()
            self._dump_graph(graph)
def parse_args():
    """
    Parse the command line for the dump tool.
    Exits via the parser's error handling if model_path is not an existing file.
    :return: Parsed arguments with 'stdout' (bool) and 'model_path' (str) attributes
    """
    arg_parser = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Dump an ORT format model. Output is to <model_path>.txt",
    )

    arg_parser.add_argument(
        "--stdout",
        action="store_true",
        help="Dump to stdout instead of writing to file.",
    )
    arg_parser.add_argument("model_path", help="Path to ORT format model")

    parsed = arg_parser.parse_args()

    model_exists = os.path.isfile(parsed.model_path)
    if not model_exists:
        arg_parser.error(f"{parsed.model_path} is not a file.")

    return parsed
def main():
    """
    Script entry point: dump the model named on the command line.
    Output goes to stdout when --stdout is given, otherwise to '<model_path>.txt'.
    """
    cli_args = parse_args()
    dumper = OrtFormatModelDumper(cli_args.model_path)

    if not cli_args.stdout:
        # default: write the dump next to the model
        with open(f"{cli_args.model_path}.txt", "w", encoding="utf-8") as sink:
            dumper.dump(sink)
        return

    dumper.dump(sys.stdout)


if __name__ == "__main__":
    main()