diff --git a/onnxruntime/test/testdata/qdq_with_multi_consumer_dq_nodes.onnx b/onnxruntime/test/testdata/qdq_with_multi_consumer_dq_nodes.onnx new file mode 100644 index 0000000000..644b370f73 Binary files /dev/null and b/onnxruntime/test/testdata/qdq_with_multi_consumer_dq_nodes.onnx differ diff --git a/tools/python/util/qdq_helpers/__init__.py b/tools/python/util/qdq_helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/python/util/qdq_helpers/optimize_qdq_model.py b/tools/python/util/qdq_helpers/optimize_qdq_model.py new file mode 100644 index 0000000000..7fbc227b45 --- /dev/null +++ b/tools/python/util/qdq_helpers/optimize_qdq_model.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import onnx +import os +import pathlib + +from .qdq_model_utils import fix_dq_nodes_with_multiple_consumers + + +def optimize_qdq_model(): + parser = argparse.ArgumentParser(os.path.basename(__file__), + description=''' + Update a QDQ format ONNX model to ensure optimal performance when executed using + ONNX Runtime. + ''') + + parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') + parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.') + + args = parser.parse_args() + + model = onnx.load(str(args.input_model.resolve(strict=True))) + + # there's just one utility to run currently but we expect that will grow + fix_dq_nodes_with_multiple_consumers(model) + + onnx.save(model, str(args.output_model.resolve())) + + +if __name__ == '__main__': + optimize_qdq_model() diff --git a/tools/python/util/qdq_helpers/qdq_model_utils.py b/tools/python/util/qdq_helpers/qdq_model_utils.py new file mode 100644 index 0000000000..62e95f948b --- /dev/null +++ b/tools/python/util/qdq_helpers/qdq_model_utils.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +from ..onnx_model_utils import get_producer_consumer_maps, iterate_graph_per_graph_func + + +def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs): + updated_graphs = kwargs['updated_graphs'] + node_to_consumers = kwargs['node_to_consumers'] + validate_updates = kwargs['validate_updates'] + + nodes_to_update = [] + for node in filter(lambda node: node.op_type == 'DequantizeLinear', graph.node): + # node providing graph output won't have consumer nodes + consumers = node_to_consumers[node] if node in node_to_consumers else [] + if len(consumers) > 1: + if not all(consumer in graph.node for consumer in consumers): + # TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that + # value as is (no need to handle recursing into the subgraph) and update the consumers in this + # graph only + raise IndexError("DequantizeLinear node output is consumed by a subgraph. " + "This is not currently supported.") + + nodes_to_update.append(node) + + if validate_updates: + if nodes_to_update: + # internal error. we somehow missed an update in the first pass when validate_upates was false + raise ValueError('Graph still has DequantizeLinear nodes with multiple consumers.') + + return + + if nodes_to_update: + dup_idx = 0 + new_graph = onnx.GraphProto() + graph_outputs = set([output.name for output in graph.output]) + for node in graph.node: + new_graph.node.append(node) + if node in nodes_to_update: + is_graph_output = node.output[0] in graph_outputs + # create duplicate DQ nodes as needed so that there is one consumer per node. + # this allows us to cleanly create a QDQ node group with no DQ nodes shared with other QDQ node groups. + # if the node produces a graph output we need a duplicate DQ node for every consumer node. + # if not, we can leave the first consumer as is and create duplicate nodes for the other consumers. + start_idx = 0 if is_graph_output else 1 + consumers = list(node_to_consumers[node])[start_idx:] + + for idx, consumer in enumerate(consumers): + # create duplicate DQ node + duplicate = onnx.NodeProto() + duplicate.CopyFrom(node) + # update node name for debugging. use the global dup idx for node duplication + duplicate.name += f'/qdq_utils_dup_{dup_idx}' + + # update output. use the local idx for value duplication + orig_output = node.output[0] + new_output = f'{orig_output}/qdq_utils_dup_{idx}' + duplicate.output[0] = new_output + + # update input on the consumer node. + for input_idx, input_name in enumerate(consumer.input): + if input_name == orig_output: + consumer.input[input_idx] = new_output + + new_graph.node.append(duplicate) + dup_idx += 1 + + # replace nodes + del graph.node[:] + graph.node.extend(new_graph.node) + updated_graphs.append(graph) + + +def fix_dq_nodes_with_multiple_consumers(model): + ''' + Update a model if any DequantizeLinear nodes have multiple consumers. + The QDQ node unit processing is overly complicated if this is the case, as the DQ node would be in multiple units, + and the units may end up in different partitions at runtime. + :param model: QDQ model to update + ''' + node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) + + updated_graphs = [] # list of GraphProto instances that were updated_graphs + iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, + node_to_consumers=node_to_consumers, validate_updates=False, + updated_graphs=updated_graphs) + + if updated_graphs: + updated_graphs = [] + node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) + iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, + node_to_consumers=node_to_consumers, validate_updates=True, + updated_graphs=updated_graphs) + + # validate with check and by running shape inference. + onnx.checker.check_model(model) + _ = onnx.shape_inference.infer_shapes(model) diff --git a/tools/python/util/qdq_helpers/test/__init__.py b/tools/python/util/qdq_helpers/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py b/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py new file mode 100644 index 0000000000..26e2c74707 --- /dev/null +++ b/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import onnx +import pathlib +import unittest + +from ..qdq_model_utils import fix_dq_nodes_with_multiple_consumers + +script_dir = pathlib.Path(__file__).parent +ort_root = script_dir.parents[4] + +# example usage from /tools/python +# python -m unittest util/qdq_helpers/test/test_qdq_model_utils.py +# NOTE: at least on Windows you must use that as the working directory for all the imports to be happy + + +class TestQDQUtils(unittest.TestCase): + def test_fix_DQ_with_multiple_consumers(self): + ''' + ''' + model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'qdq_with_multi_consumer_dq_nodes.onnx' + model = onnx.load(str(model_path)) + + orig_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] + fix_dq_nodes_with_multiple_consumers(model) + new_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] + + # there are 3 DQ nodes with 2 consumers (an earlier Conv and later Add) + # additionally the last one also provides a graph output + # based on that there should be 3 new DQ nodes for the internal consumers and 1 new one for the graph output + self.assertEqual(len(orig_dq_nodes) + 4, len(new_dq_nodes))