mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
add option DefaultTensorType to specify the default tensor type to quantize (#19455)
### Description The current quantization tool relies on shape inference to provide the type of every intermediate tensor, then the tool knows which type it must dequantize into (float32, float16). However, this information is not available if shape inference fails. That happens every time the model include an operator from a custom domain such as com.microsoft. This PR introduces an extra option `DefaultTensorType` as a fall back when the quantizer cannot find the type it needs. ### Motivation and Context This fixes issue #19409.
This commit is contained in:
parent
e832562d70
commit
7efb0dbe12
3 changed files with 115 additions and 5 deletions
|
|
@ -385,7 +385,7 @@ class ONNXQuantizer:
|
|||
def quantize_model(self):
|
||||
if self.has_QDQ_nodes():
|
||||
logging.warning(
|
||||
"Please check if the model is already quantized."
|
||||
"Please check if the model is already quantized. "
|
||||
"Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
|
||||
)
|
||||
|
||||
|
|
@ -442,6 +442,23 @@ class ONNXQuantizer:
|
|||
return False
|
||||
return self.parent.is_valid_quantize_weight(weight_name)
|
||||
|
||||
def _get_default_tensor_type(self, tensor_name):
|
||||
if "DefaultTensorType" in self.extra_options:
|
||||
logging.info(
|
||||
"get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
|
||||
tensor_name,
|
||||
self.extra_options["DefaultTensorType"],
|
||||
)
|
||||
return self.extra_options["DefaultTensorType"]
|
||||
raise RuntimeError(
|
||||
f"Unable to find data type for weight_name={tensor_name!r}. "
|
||||
f"shape_inference failed to return a type probably this node is "
|
||||
f"from a different domain or using an input produced by such an operator. "
|
||||
f"This may happen if you quantize a model already quantized. "
|
||||
f"You may use extra_options `DefaultTensorType` to indicate "
|
||||
f"the default weight type, usually `onnx.TensorProto.FLOAT`."
|
||||
)
|
||||
|
||||
def get_tensor_type(self, tensor_name, mandatory=False):
|
||||
weight = find_by_name(tensor_name, self.model.initializer())
|
||||
if weight is not None:
|
||||
|
|
@ -450,11 +467,11 @@ class ONNXQuantizer:
|
|||
vi = self.value_infos[tensor_name]
|
||||
if vi.type.HasField("tensor_type"):
|
||||
if mandatory and vi.type.tensor_type.elem_type == 0:
|
||||
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
|
||||
return self._get_default_tensor_type(tensor_name)
|
||||
return vi.type.tensor_type.elem_type
|
||||
if (not self.enable_subgraph_quantization) or (self.parent is None):
|
||||
if mandatory:
|
||||
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
|
||||
return self._get_default_tensor_type(tensor_name)
|
||||
return None
|
||||
otype = self.parent.is_valid_quantize_weight(tensor_name)
|
||||
if otype is not None:
|
||||
|
|
@ -464,7 +481,7 @@ class ONNXQuantizer:
|
|||
if res is not None:
|
||||
return res
|
||||
if mandatory:
|
||||
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
|
||||
return self._get_default_tensor_type(tensor_name)
|
||||
return None
|
||||
|
||||
def is_float_tensor(self, tensor_name):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import onnx # noqa: F401
|
||||
import onnx
|
||||
import torch
|
||||
from transformers.modeling_utils import Conv1D
|
||||
|
||||
|
|
@ -69,6 +69,7 @@ class QuantizeHelper:
|
|||
onnx_model_path,
|
||||
quantized_model_path,
|
||||
use_external_data_format=use_external_data_format,
|
||||
extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
|
||||
)
|
||||
logger.info(f"quantized model saved to:{quantized_model_path}")
|
||||
# TODO: inlcude external data in total model size.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env python
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnx.helper as oh
|
||||
import onnx.numpy_helper as onh
|
||||
|
||||
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
|
||||
from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType
|
||||
|
||||
|
||||
class TestQuantizerShapeInference(unittest.TestCase):
|
||||
def test_com_microsoft(self):
|
||||
model = oh.make_model(
|
||||
oh.make_graph(
|
||||
[
|
||||
oh.make_node("MatMul", ["X", "W1"], ["T1"]),
|
||||
oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"),
|
||||
oh.make_node("MatMul", ["T2", "W3"], ["T3"]),
|
||||
oh.make_node("MatMul", ["T3", "W4"], ["Y"]),
|
||||
],
|
||||
"name",
|
||||
[oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])],
|
||||
[oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])],
|
||||
[
|
||||
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"),
|
||||
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"),
|
||||
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"),
|
||||
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"),
|
||||
],
|
||||
),
|
||||
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
|
||||
)
|
||||
model_shaped = onnx.shape_inference.infer_shapes(model)
|
||||
shaped_results = set(t.name for t in model_shaped.graph.value_info)
|
||||
# every result after T1 depends on T2 coming from a node com.microsoft,
|
||||
# shape_inference cannot go beyond this point
|
||||
self.assertEqual(shaped_results, {"T1"})
|
||||
|
||||
# first try: checks it raises an exception
|
||||
quantizer = ONNXQuantizer(
|
||||
model,
|
||||
False, # per_channel
|
||||
False, # reduce_range
|
||||
QuantizationMode.IntegerOps, # mode
|
||||
False, # static
|
||||
QuantType.QInt8, # weight_type,
|
||||
QuantType.QUInt8, # dynamic activation only supports uint8
|
||||
None,
|
||||
[], # nodes_to_quantize,
|
||||
[], # nodes_to_exclude
|
||||
["MatMul"], # op_types_to_quantize,
|
||||
{"MatMulConstBOnly": True}, # extra_options,
|
||||
# {'DefaultTensorType': 1, }
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError) as e:
|
||||
quantizer.quantize_model()
|
||||
self.assertIn("Unable to find data type for weight_name=", str(e))
|
||||
|
||||
# second try: checks it works
|
||||
quantizer = ONNXQuantizer(
|
||||
model,
|
||||
False, # per_channel
|
||||
False, # reduce_range
|
||||
QuantizationMode.IntegerOps, # mode
|
||||
False, # static
|
||||
QuantType.QInt8, # weight_type,
|
||||
QuantType.QUInt8, # dynamic activation only supports uint8
|
||||
None,
|
||||
[], # nodes_to_quantize,
|
||||
[], # nodes_to_exclude
|
||||
["MatMul"], # op_types_to_quantize,
|
||||
{
|
||||
"MatMulConstBOnly": True,
|
||||
"DefaultTensorType": 1,
|
||||
},
|
||||
)
|
||||
|
||||
model = quantizer.quantize_model()
|
||||
ops = {n.op_type for n in model.graph.node}
|
||||
self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Loading…
Reference in a new issue