mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
When quantize 4bit mamtul, force upgrade onnx domain opset to 21 (#21693)
### Description When quantize MatMul to DQ + MatMul using 4bit QDQ tool chain, previously the opsets of domains are not changed. Now, when quantize MatMul to DQ + MatMul in QDQ format, force upgrade onnx domain to opset 21. ### Motivation and Context In QDQ format, DQ with int4 and blocked quantization is used. This requires DQ with opset >= 21. When quantize MatMul to DQ + MatMul, force upgrade onnx domain to opset 21.
This commit is contained in:
parent
c6a73defb8
commit
53a66f4e02
2 changed files with 16 additions and 7 deletions
|
|
@ -712,14 +712,20 @@ class MatMul4BitsQuantizer:
|
|||
if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
|
||||
# use a stack to keep track of sub-graphs
|
||||
graph_stack = [self.model.graph()]
|
||||
opset_import = self.model.opset_import()
|
||||
|
||||
has_ms_domain = False
|
||||
for opset in opset_import:
|
||||
if opset.domain == "com.microsoft":
|
||||
has_ms_domain = True
|
||||
if not has_ms_domain:
|
||||
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
|
||||
# Update domain opset
|
||||
if self.algo_config.quant_format == QuantFormat.QOperator:
|
||||
self.model.set_opset_import("com.microsoft", 1)
|
||||
else:
|
||||
opset_import = self.model.opset_import()
|
||||
for opset in opset_import:
|
||||
if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
|
||||
logger.warning(
|
||||
"The opset of the input model is under 21 and doesn't support int4 data type. "
|
||||
"Force to update it to opset 21, but the generated model may not be a valid model."
|
||||
)
|
||||
self.model.set_opset_import(opset.domain, 21)
|
||||
|
||||
self._process_subgraph(graph_stack)
|
||||
self.model.clean_initializers()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -156,6 +156,9 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
}
|
||||
)
|
||||
check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes)
|
||||
for op in quant.model.opset_import():
|
||||
if op.domain in [None, "", "ai.onnx"] and op.version < 21:
|
||||
self.fail(f"In QDQ format {op.domain} opset should be >= 21")
|
||||
|
||||
data_reader.rewind()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue