mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[Quant tool] Ensure MSFT opset for Q/DQ models (#19335)
### Description Updates qdq quantization to ensure the final model has the `com.microsoft` opset import if the model uses Q/DQ ops with the `com.microsoft` domain (e.g., for int16 quantization) ### Motivation and Context Need to ensure the MSFT domain is correctly set for all relevant cases. Otherwise, shape inferencing tools will raise an exception.
This commit is contained in:
parent
90883a366a
commit
0c38e96bb5
2 changed files with 9 additions and 0 deletions
|
|
@ -270,6 +270,8 @@ class QDQQuantizer(ONNXQuantizer):
|
|||
|
||||
self.model.model.producer_name = __producer__
|
||||
self.model.model.producer_version = __version__
|
||||
if self.qdq_op_domain == ms_domain:
|
||||
self.model.set_opset_import(ms_domain, 1)
|
||||
|
||||
return self.model.model
|
||||
|
||||
|
|
|
|||
|
|
@ -601,6 +601,13 @@ class TestQDQFormatConvRelu(TestQDQFormat):
|
|||
)
|
||||
check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next())
|
||||
|
||||
# If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support),
|
||||
# then ensure the model has the appropriate opset import.
|
||||
if extra_options and extra_options.get("UseQDQContribOps", False):
|
||||
qdq_model = onnx.load_model(model_qdq_path)
|
||||
ms_opset = next((opset for opset in qdq_model.opset_import if opset.domain == "com.microsoft"), None)
|
||||
self.assertIsNot(ms_opset, None)
|
||||
|
||||
def verify_qop(self, per_channel, is_quant_type_int8):
|
||||
np.random.seed(1)
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx")
|
||||
|
|
|
|||
Loading…
Reference in a new issue