mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
### Description
<!-- Describe your changes. -->
Add a required graph transformer that duplicates DQ nodes to ensure that QDQ
node units have unique DQ nodes. This condition is necessary for QDQ
node unit processing.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
There is an existing Python utility that does this:
`tools/python/util/qdq_helpers/qdq_model_utils.py` (line 77, at commit c7ced7a5e9).
This PR implements it as a graph transformer so it is integrated into
ORT and does not require a separate step to update the model. There are
also tests to ensure that its effects are not undone by basic level
graph optimizations.
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
import os
|
|
import pathlib
|
|
|
|
import onnx
|
|
|
|
|
|
def optimize_qdq_model():
    """Command-line entry point: load a QDQ format ONNX model, run any
    QDQ-specific optimizations, and write the updated model out.

    No optimizations are currently performed here. The fix-up of DQ nodes
    with multiple consumers, which this script originally handled, is now
    done by an ORT graph transformer.
    """
    arg_parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
    )
    arg_parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    arg_parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
    parsed = arg_parser.parse_args()

    # strict=True makes resolve() raise immediately if the input file is missing.
    input_path = parsed.input_model.resolve(strict=True)
    model = onnx.load(str(input_path))

    # QDQ model optimizations would run here.
    # Originally, the fixing up of DQ nodes with multiple consumers was
    # implemented as one such optimization; it was moved to an ORT graph
    # transformer.
    print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")

    # There are no optimizations being run currently, but we expect that
    # there may be in the future, so the load/save round trip is kept.
    onnx.save(model, str(parsed.output_model.resolve()))
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    optimize_qdq_model()