onnxruntime/tools/python/util/qdq_helpers/optimize_qdq_model.py
Edward Chen 9f942e1a3e
Graph transformer to ensure unique DQ nodes for QDQ node units (#15145)
### Description

Add a required graph transformer that duplicates DQ nodes so that QDQ
node units do not share DQ nodes. QDQ node unit processing requires
each node unit to have its own unique DQ nodes.
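The duplication step can be illustrated on a toy graph. This is a minimal sketch, not ORT's graph transformer itself. `Node` and `ensure_unique_dq_nodes` are hypothetical names. It assumes nodes are listed in topological order and ignores graph outputs for simplicity. The first consumer of a DQ output keeps the original node, and every other consuming input is rewired to a fresh copy that reads the same inputs.

```python
import copy
from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical, simplified stand-in for an ONNX node.
    op_type: str
    name: str
    inputs: list[str]
    outputs: list[str]


def ensure_unique_dq_nodes(nodes: list[Node]) -> list[Node]:
    """Return a copy of `nodes` in which every consuming input of a
    DequantizeLinear output is fed by its own DequantizeLinear node."""
    nodes = copy.deepcopy(nodes)  # leave the caller's graph untouched

    # Map each tensor name to the (consumer node, input index) pairs that read it.
    consumers: dict[str, list[tuple[Node, int]]] = {}
    for node in nodes:
        for input_idx, tensor in enumerate(node.inputs):
            consumers.setdefault(tensor, []).append((node, input_idx))

    result: list[Node] = []
    for node in nodes:
        result.append(node)
        if node.op_type != "DequantizeLinear":
            continue
        dq_output = node.outputs[0]
        # The first consumer keeps the original DQ node; each additional
        # consuming input gets a duplicate placed right after the original,
        # which preserves topological order.
        extra = consumers.get(dq_output, [])[1:]
        for copy_idx, (consumer, input_idx) in enumerate(extra, start=1):
            new_output = f"{dq_output}/dup_{copy_idx}"
            result.append(
                Node("DequantizeLinear", f"{node.name}/dup_{copy_idx}", list(node.inputs), [new_output])
            )
            consumer.inputs[input_idx] = new_output
    return result


if __name__ == "__main__":
    # One DQ node whose output feeds two Conv nodes.
    graph = [
        Node("DequantizeLinear", "dq_w", ["w_q", "w_scale", "w_zp"], ["w"]),
        Node("Conv", "conv_a", ["x", "w"], ["a"]),
        Node("Conv", "conv_b", ["y", "w"], ["b"]),
    ]
    for n in ensure_unique_dq_nodes(graph):
        print(n.name, n.inputs, n.outputs)
    # dq_w ['w_q', 'w_scale', 'w_zp'] ['w']
    # dq_w/dup_1 ['w_q', 'w_scale', 'w_zp'] ['w/dup_1']
    # conv_a ['x', 'w'] ['a']
    # conv_b ['y', 'w/dup_1'] ['b']
```

In this sketch, the duplicates reuse the same quantized input, scale, and zero-point tensor names, so only the node is copied, not any tensor data.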

### Motivation and Context

There is an existing Python utility that does this: 

c7ced7a5e9/tools/python/util/qdq_helpers/qdq_model_utils.py (L77)

This PR implements it as a graph transformer, so it is integrated into
ORT and the model no longer needs a separate update step. It also adds
tests to ensure that basic-level graph optimizations do not undo its
effects.
2023-03-31 08:39:43 +10:00


```python
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import os
import pathlib

import onnx


def optimize_qdq_model():
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
    )

    parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")

    args = parser.parse_args()

    model = onnx.load(str(args.input_model.resolve(strict=True)))

    # run QDQ model optimizations here

    # Originally, the fixing up of DQ nodes with multiple consumers was implemented as one such optimization.
    # That was moved to an ORT graph transformer.
    print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")

    # There are no optimizations being run currently but we expect that there may be in the future.

    onnx.save(model, str(args.output_model.resolve()))


if __name__ == "__main__":
    optimize_qdq_model()
```
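The script takes two positional arguments, the input and output model paths. Its argument handling does not depend on the model contents, so it can be exercised in isolation. The sketch below rebuilds an equivalent parser. `build_parser` and `resolve_paths` are hypothetical helpers, not part of the script. It shows that `resolve(strict=True)` rejects a missing input model with `FileNotFoundError`, while the output path only needs to resolve to an absolute path.

```python
import argparse
import pathlib
import tempfile


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper mirroring the parser built in optimize_qdq_model().
    parser = argparse.ArgumentParser(
        "optimize_qdq_model.py",
        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
    )
    parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
    return parser


def resolve_paths(argv: list[str]) -> tuple[pathlib.Path, pathlib.Path]:
    """Parse argv and resolve both paths the same way the script does."""
    args = build_parser().parse_args(argv)
    # strict=True raises FileNotFoundError for a missing input model,
    # so a bad path fails before any model loading would happen.
    input_path = args.input_model.resolve(strict=True)
    # The output model does not have to exist yet; it is only made absolute.
    output_path = args.output_model.resolve()
    return input_path, output_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        existing = pathlib.Path(tmp, "model.onnx")
        existing.write_bytes(b"")  # empty placeholder; only its existence matters here
        print(resolve_paths([str(existing), str(pathlib.Path(tmp, "model.opt.onnx"))]))
        try:
            resolve_paths([str(pathlib.Path(tmp, "missing.onnx")), "out.onnx"])
        except FileNotFoundError:
            print("missing input model rejected")
```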