add option DefaultTensorType to specify the default tensor type to quantize (#19455)

### Description
The current quantization tool relies on shape inference to provide the
type of every intermediate tensor, then the tool knows which type it
must dequantize into (float32, float16). However, this information is
not available if shape inference fails. That happens every time the
model include an operator from a custom domain such as com.microsoft.

This PR introduces an extra option `DefaultTensorType` as a fall back
when the quantizer cannot find the type it needs.

### Motivation and Context
This fixes issue #19409.
This commit is contained in:
Xavier Dupré 2024-02-20 17:22:44 +01:00 committed by GitHub
parent e832562d70
commit 7efb0dbe12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 115 additions and 5 deletions

View file

@ -385,7 +385,7 @@ class ONNXQuantizer:
def quantize_model(self):
if self.has_QDQ_nodes():
logging.warning(
"Please check if the model is already quantized."
"Please check if the model is already quantized. "
"Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
)
@ -442,6 +442,23 @@ class ONNXQuantizer:
return False
return self.parent.is_valid_quantize_weight(weight_name)
def _get_default_tensor_type(self, tensor_name):
if "DefaultTensorType" in self.extra_options:
logging.info(
"get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
tensor_name,
self.extra_options["DefaultTensorType"],
)
return self.extra_options["DefaultTensorType"]
raise RuntimeError(
f"Unable to find data type for weight_name={tensor_name!r}. "
f"shape_inference failed to return a type probably this node is "
f"from a different domain or using an input produced by such an operator. "
f"This may happen if you quantize a model already quantized. "
f"You may use extra_options `DefaultTensorType` to indicate "
f"the default weight type, usually `onnx.TensorProto.FLOAT`."
)
def get_tensor_type(self, tensor_name, mandatory=False):
weight = find_by_name(tensor_name, self.model.initializer())
if weight is not None:
@ -450,11 +467,11 @@ class ONNXQuantizer:
vi = self.value_infos[tensor_name]
if vi.type.HasField("tensor_type"):
if mandatory and vi.type.tensor_type.elem_type == 0:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return vi.type.tensor_type.elem_type
if (not self.enable_subgraph_quantization) or (self.parent is None):
if mandatory:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return None
otype = self.parent.is_valid_quantize_weight(tensor_name)
if otype is not None:
@ -464,7 +481,7 @@ class ONNXQuantizer:
if res is not None:
return res
if mandatory:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return None
def is_float_tensor(self, tensor_name):

View file

@ -7,7 +7,7 @@
import logging
import os
import onnx # noqa: F401
import onnx
import torch
from transformers.modeling_utils import Conv1D
@ -69,6 +69,7 @@ class QuantizeHelper:
onnx_model_path,
quantized_model_path,
use_external_data_format=use_external_data_format,
extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
logger.info(f"quantized model saved to:{quantized_model_path}")
# TODO: inlcude external data in total model size.

View file

@ -0,0 +1,92 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType
class TestQuantizerShapeInference(unittest.TestCase):
def test_com_microsoft(self):
model = oh.make_model(
oh.make_graph(
[
oh.make_node("MatMul", ["X", "W1"], ["T1"]),
oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"),
oh.make_node("MatMul", ["T2", "W3"], ["T3"]),
oh.make_node("MatMul", ["T3", "W4"], ["Y"]),
],
"name",
[oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])],
[oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])],
[
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"),
],
),
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
)
model_shaped = onnx.shape_inference.infer_shapes(model)
shaped_results = set(t.name for t in model_shaped.graph.value_info)
# every result after T1 depends on T2 coming from a node com.microsoft,
# shape_inference cannot go beyond this point
self.assertEqual(shaped_results, {"T1"})
# first try: checks it raises an exception
quantizer = ONNXQuantizer(
model,
False, # per_channel
False, # reduce_range
QuantizationMode.IntegerOps, # mode
False, # static
QuantType.QInt8, # weight_type,
QuantType.QUInt8, # dynamic activation only supports uint8
None,
[], # nodes_to_quantize,
[], # nodes_to_exclude
["MatMul"], # op_types_to_quantize,
{"MatMulConstBOnly": True}, # extra_options,
# {'DefaultTensorType': 1, }
)
with self.assertRaises(RuntimeError) as e:
quantizer.quantize_model()
self.assertIn("Unable to find data type for weight_name=", str(e))
# second try: checks it works
quantizer = ONNXQuantizer(
model,
False, # per_channel
False, # reduce_range
QuantizationMode.IntegerOps, # mode
False, # static
QuantType.QInt8, # weight_type,
QuantType.QUInt8, # dynamic activation only supports uint8
None,
[], # nodes_to_quantize,
[], # nodes_to_exclude
["MatMul"], # op_types_to_quantize,
{
"MatMulConstBOnly": True,
"DefaultTensorType": 1,
},
)
model = quantizer.quantize_model()
ops = {n.op_type for n in model.graph.node}
self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"})
if __name__ == "__main__":
unittest.main(verbosity=2)