onnxruntime/tools/python/util/qdq_helpers/qdq_model_utils.py
Justin Chu d834ec895a
Adopt lintrunner as the linting tool - take 2 (#15085)
### Description

`lintrunner` is a linter runner successfully used by pytorch, onnx and
onnx-script. It provides a uniform experience running linters locally
and in CI. It supports all major dev systems: Windows, Linux and macOS.
The checks are enforced by the `Python format` workflow.

This PR adopts `lintrunner` in onnxruntime and fixes ~2000 flake8 errors
in Python code. `lintrunner` now runs all required Python lints,
including `ruff` (replacing `flake8`), `black` and `isort`. Future lints
like `clang-format` can be added.

Most errors are auto-fixed by `ruff` and the fixes should be considered
robust.

Lints that are more complicated to fix are suppressed with `# noqa` for
now and should be fixed in follow-up PRs.

### Notable changes

1. This PR **removed some suboptimal patterns**:

	- `not xxx in` -> `xxx not in` membership checks
	- bare excepts (`except:` -> `except Exception`)
	- unused imports
	
	The follow-up PR will remove:
	
	- `import *`
	- mutable values as default in function definitions (`def func(a=[])`)
	- more unused imports
	- unused local variables

2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than
flake8 and is more robust. We are using it successfully in onnx and
onnx-script. It also supports auto-fixing many flake8 errors.

3. Removed the legacy flake8 CI flow and updated docs.

4. The added workflow supports SARIF code scanning reports on GitHub.
Example snapshot:
	

![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png)

5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant
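
As a minimal sketch of the patterns named in item 1, the snippet below shows the preferred form of each: `not in` membership checks, `except Exception` instead of a bare `except:`, and `None` instead of a mutable default argument. The function names are hypothetical and are not taken from the onnxruntime code base.

```python
def is_missing(key, table):
    # Preferred: `key not in table` instead of `not key in table`.
    return key not in table


def parse_int(text, default=0):
    # Preferred: `except Exception` instead of a bare `except:`, which would
    # also swallow KeyboardInterrupt and SystemExit.
    try:
        return int(text)
    except Exception:
        return default


def append_item(item, items=None):
    # Preferred: `None` as the default instead of `items=[]`, because a list
    # default is created once and shared across calls.
    if items is None:
        items = []
    items.append(item)
    return items
```

With the `None` default, two consecutive calls such as `append_item(1)` and `append_item(2)` each return a fresh one-element list.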

### Motivation and Context

Provides a unified linting experience in CI and locally.

Replacing https://github.com/microsoft/onnxruntime/pull/14306

---------

Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 15:29:03 -07:00


# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx

from ..onnx_model_utils import get_producer_consumer_maps, iterate_graph_per_graph_func


def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs):
    """
    Per-graph callback: duplicate each DequantizeLinear node that has multiple consumers, or, when
    validate_updates is set, raise if any such node remains.
    """
    updated_graphs = kwargs["updated_graphs"]
    node_to_consumers = kwargs["node_to_consumers"]
    validate_updates = kwargs["validate_updates"]

    nodes_to_update = []
    for node in filter(lambda node: node.op_type == "DequantizeLinear", graph.node):
        # node providing graph output won't have consumer nodes
        consumers = node_to_consumers[node] if node in node_to_consumers else []
        if len(consumers) > 1:
            if not all(consumer in graph.node for consumer in consumers):
                # TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that
                # value as is (no need to handle recursing into the subgraph) and update the consumers in this
                # graph only
                raise IndexError(
                    "DequantizeLinear node output is consumed by a subgraph. This is not currently supported."
                )

            nodes_to_update.append(node)

    if validate_updates:
        if nodes_to_update:
            # internal error. we somehow missed an update in the first pass when validate_updates was false
            raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.")

        return

    if nodes_to_update:
        dup_idx = 0
        new_graph = onnx.GraphProto()
        graph_outputs = {output.name for output in graph.output}
        for node in graph.node:
            new_graph.node.append(node)
            if node in nodes_to_update:
                is_graph_output = node.output[0] in graph_outputs
                # create duplicate DQ nodes as needed so that there is one consumer per node.
                # this allows us to cleanly create a QDQ node group with no DQ nodes shared with other QDQ node groups.
                # if the node produces a graph output we need a duplicate DQ node for every consumer node.
                # if not, we can leave the first consumer as is and create duplicate nodes for the other consumers.
                start_idx = 0 if is_graph_output else 1
                consumers = list(node_to_consumers[node])[start_idx:]

                for idx, consumer in enumerate(consumers):
                    # create duplicate DQ node
                    duplicate = onnx.NodeProto()
                    duplicate.CopyFrom(node)
                    # update node name for debugging. use the global dup idx for node duplication
                    duplicate.name += f"/qdq_utils_dup_{dup_idx}"

                    # update output. use the local idx for value duplication
                    orig_output = node.output[0]
                    new_output = f"{orig_output}/qdq_utils_dup_{idx}"
                    duplicate.output[0] = new_output

                    # update input on the consumer node.
                    for input_idx, input_name in enumerate(consumer.input):
                        if input_name == orig_output:
                            consumer.input[input_idx] = new_output

                    new_graph.node.append(duplicate)
                    dup_idx += 1

        # replace nodes
        del graph.node[:]
        graph.node.extend(new_graph.node)
        updated_graphs.append(graph)


def fix_dq_nodes_with_multiple_consumers(model):
    """
    Update a model if any DequantizeLinear nodes have multiple consumers.
    The QDQ node unit processing is overly complicated if this is the case, as the DQ node would be in multiple units,
    and the units may end up in different partitions at runtime.
    :param model: QDQ model to update
    """
    node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph)

    updated_graphs = []  # list of GraphProto instances that were updated
    iterate_graph_per_graph_func(
        model.graph,
        _duplicate_dq_nodes_with_multiple_consumers,
        node_to_consumers=node_to_consumers,
        validate_updates=False,
        updated_graphs=updated_graphs,
    )

    if updated_graphs:
        # second pass: rebuild the maps and confirm no DQ node still has multiple consumers
        updated_graphs = []
        node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph)
        iterate_graph_per_graph_func(
            model.graph,
            _duplicate_dq_nodes_with_multiple_consumers,
            node_to_consumers=node_to_consumers,
            validate_updates=True,
            updated_graphs=updated_graphs,
        )

        # validate with check and by running shape inference.
        onnx.checker.check_model(model)
        _ = onnx.shape_inference.infer_shapes(model)
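
The duplication step can be hard to picture from the protobuf manipulation alone, so here is a rough, dependency-free sketch of the same idea. It assumes a hypothetical `Node` dataclass as a stand-in for `onnx.NodeProto` and a hypothetical helper `duplicate_shared_dq`; neither is part of onnx or onnxruntime, and consumers are found by a simple scan rather than via `get_producer_consumer_maps`.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto (not the real ONNX type)."""

    name: str
    op_type: str
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)


def duplicate_shared_dq(nodes, graph_outputs=()):
    """Return a node list in which each DequantizeLinear output feeds one consumer."""
    result = []
    dup_idx = 0  # global counter, only used to keep duplicate node names unique
    for node in nodes:
        result.append(node)
        if node.op_type != "DequantizeLinear":
            continue
        orig_output = node.output[0]
        consumers = [n for n in nodes if orig_output in n.input]
        if len(consumers) < 2:
            continue
        # Keep the first consumer on the original node unless the value is
        # also a graph output, in which case every consumer gets a duplicate.
        start = 0 if orig_output in graph_outputs else 1
        for idx, consumer in enumerate(consumers[start:]):
            new_output = f"{orig_output}/qdq_utils_dup_{idx}"
            result.append(
                Node(f"{node.name}/qdq_utils_dup_{dup_idx}", node.op_type, list(node.input), [new_output])
            )
            # Rewire the consumer to read from the duplicate's output.
            consumer.input = [new_output if name == orig_output else name for name in consumer.input]
            dup_idx += 1
    return result


# One DQ node feeding two consumers.
dq = Node("dq", "DequantizeLinear", ["x_q", "scale", "zp"], ["x"])
relu = Node("relu", "Relu", ["x"], ["y1"])
sigmoid = Node("sigmoid", "Sigmoid", ["x"], ["y2"])
fixed = duplicate_shared_dq([dq, relu, sigmoid])
```

In this toy graph the first consumer (`relu`) keeps reading `x`, while `sigmoid` is rewired to the output of the inserted duplicate, which sits directly after the original DQ node so the list stays topologically ordered.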