diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index cc8bd622df..c0cc4f038c 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -712,14 +712,20 @@ class MatMul4BitsQuantizer: if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] - opset_import = self.model.opset_import() - has_ms_domain = False - for opset in opset_import: - if opset.domain == "com.microsoft": - has_ms_domain = True - if not has_ms_domain: - opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + # Update domain opset + if self.algo_config.quant_format == QuantFormat.QOperator: + self.model.set_opset_import("com.microsoft", 1) + else: + opset_import = self.model.opset_import() + for opset in opset_import: + if opset.domain in [None, "ai.onnx", ""] and opset.version < 21: + logger.warning( + "The opset of the input model is under 21 and doesn't support int4 data type. " + "Force to update it to opset 21, but the generated model may not be a valid model." + ) + self.model.set_opset_import(opset.domain, 21) + self._process_subgraph(graph_stack) self.model.clean_initializers() else: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 4cc8a0c151..0438d93227 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -156,6 +156,9 @@ class TestOpMatMul4Bits(unittest.TestCase): } ) check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes) + for op in quant.model.opset_import(): + if op.domain in [None, "", "ai.onnx"] and op.version < 21: + self.fail(f"In QDQ format {op.domain} opset should be >= 21") data_reader.rewind()