mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
[QNN Quant] Ensure 16bit tensor quant overrides set MS domain (#19684)
### Description
Ensures that DQ and Q ops use the MSFT (`com.microsoft`) domain if tensor quantization overrides specify 16-bit integer types.

### Motivation and Context
ONNX does not yet support 16-bit integer types for the QuantizeLinear and DequantizeLinear ops (coming soon). For now, DQ/Q ops that use 16-bit types must use the MSFT domain. We must also check whether tensor quantization overrides force the use of 16-bit quantization types; if so, the domain for the Q/DQ ops must be set accordingly.
This commit is contained in:
parent
d2e6dd25ea
commit
c1bf7fcd2f
3 changed files with 42 additions and 6 deletions
|
|
@ -154,7 +154,7 @@ class ONNXQuantizer:
|
|||
if self.mode not in QuantizationMode:
|
||||
raise ValueError(f"unsupported quantization mode {self.mode}")
|
||||
|
||||
self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides()
|
||||
self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides()
|
||||
self.quantization_params = self.calculate_quantization_params()
|
||||
|
||||
# QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
|
||||
|
|
@ -177,8 +177,10 @@ class ONNXQuantizer:
|
|||
def _get_and_check_tensor_quant_overrides(self):
|
||||
"""
|
||||
Get tensor quantization overrides and check correctness.
|
||||
Also returns a set of quantization types (as TensorProto) specified across all overrides.
|
||||
"""
|
||||
tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {})
|
||||
tensor_quant_override_types = set()
|
||||
|
||||
# Validate that compatible/valid overrides are provided.
|
||||
if tensor_quant_overrides:
|
||||
|
|
@ -211,6 +213,8 @@ class ONNXQuantizer:
|
|||
# other channels.
|
||||
if index == 0:
|
||||
quant_type = quant_overrides.get("quant_type")
|
||||
if quant_type is not None:
|
||||
tensor_quant_override_types.add(quant_type.tensor_type)
|
||||
elif quant_type != quant_overrides.get("quant_type"):
|
||||
raise ValueError(
|
||||
"Channel quantization types for tensor '{tensor_name}' do not match at index {index}."
|
||||
|
|
@ -231,7 +235,7 @@ class ONNXQuantizer:
|
|||
f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'"
|
||||
)
|
||||
|
||||
return tensor_quant_overrides
|
||||
return tensor_quant_overrides, tensor_quant_override_types
|
||||
|
||||
def get_per_tensor_quant_overrides(self, tensor_name):
|
||||
quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}])
|
||||
|
|
@ -747,8 +751,7 @@ class ONNXQuantizer:
|
|||
raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
|
||||
scale_values = np.array([params["scale"]])
|
||||
assert scale_values.dtype != np.float64
|
||||
# zero_point_type = params["quant_type"]
|
||||
assert zero_point_type == params["quant_type"]
|
||||
zero_point_type = params["quant_type"]
|
||||
else:
|
||||
zero_point_values = np.array([use_zeropoint])
|
||||
scale_values = np.array([use_scale])
|
||||
|
|
|
|||
|
|
@ -116,7 +116,10 @@ class QDQQuantizer(ONNXQuantizer):
|
|||
# if the activation or weight types are 16-bit integers.
|
||||
# TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
|
||||
int16_types = (TensorProto.UINT16, TensorProto.INT16)
|
||||
if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
|
||||
overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types)
|
||||
if not self.qdq_op_domain and (
|
||||
self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16
|
||||
):
|
||||
logging.warning(
|
||||
"ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. "
|
||||
f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import onnx
|
|||
|
||||
from onnxruntime import quantization
|
||||
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
|
||||
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType
|
||||
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain
|
||||
|
||||
|
||||
class DummyDataReader(quantization.CalibrationDataReader):
|
||||
|
|
@ -423,6 +423,36 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
|
|||
self.assertEqual(zp, expected_zp)
|
||||
self.assertEqual(scale, np.float32(expected_scale))
|
||||
|
||||
def test_16bit_overrides_set_ms_domain(self):
|
||||
"""
|
||||
Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft'
|
||||
domain on DQ and Q ops.
|
||||
"""
|
||||
qdq_model_name = "model_quant_overrides_to_16bit.onnx"
|
||||
inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization(
|
||||
qdq_model_name,
|
||||
activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations
|
||||
extra_options={
|
||||
"TensorQuantOverrides": {
|
||||
"INP": [{"quant_type": quantization.QuantType.QUInt16}],
|
||||
"SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Input and Sigmoid's output should be overridden to 16bit
|
||||
self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16)
|
||||
self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
|
||||
|
||||
# Output should be the default uint8 type
|
||||
self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8)
|
||||
|
||||
# Q/DQ ops should all have the 'com.microsoft' domain
|
||||
qdq_model = onnx.load_model(qdq_model_name)
|
||||
for node in qdq_model.graph.node:
|
||||
if node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
|
||||
self.assertEqual(node.domain, ms_domain)
|
||||
|
||||
def test_override_validation_nonexisting_tensor(self):
|
||||
"""
|
||||
Test that specifying a non-existing tensor should fail.
|
||||
|
|
|
|||
Loading…
Reference in a new issue