When quantize 4bit mamtul, force upgrade onnx domain opset to 21 (#21693)

### Description
When quantize MatMul to DQ + MatMul using 4bit QDQ tool chain,
previously the opsets of domains are not changed.
Now, when quantize MatMul to DQ + MatMul in QDQ format, force upgrade
onnx domain to opset 21.

### Motivation and Context
In QDQ format, DQ with int4 and blocked quantization is used. This
requires DQ with opset >= 21.
When quantize MatMul to DQ + MatMul, force upgrade onnx domain to opset
21.
This commit is contained in:
Jing Fang 2024-08-09 13:50:12 -07:00 committed by GitHub
parent c6a73defb8
commit 53a66f4e02
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 7 deletions

View file

@ -712,14 +712,20 @@ class MatMul4BitsQuantizer:
if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()
has_ms_domain = False
for opset in opset_import:
if opset.domain == "com.microsoft":
has_ms_domain = True
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
# Update domain opset
if self.algo_config.quant_format == QuantFormat.QOperator:
self.model.set_opset_import("com.microsoft", 1)
else:
opset_import = self.model.opset_import()
for opset in opset_import:
if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
logger.warning(
"The opset of the input model is under 21 and doesn't support int4 data type. "
"Force to update it to opset 21, but the generated model may not be a valid model."
)
self.model.set_opset_import(opset.domain, 21)
self._process_subgraph(graph_stack)
self.model.clean_initializers()
else:

View file

@ -156,6 +156,9 @@ class TestOpMatMul4Bits(unittest.TestCase):
}
)
check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes)
for op in quant.model.opset_import():
if op.domain in [None, "", "ai.onnx"] and op.version < 21:
self.fail(f"In QDQ format {op.domain} opset should be >= 21")
data_reader.rewind()