[QNN Quant] Ensure 16bit tensor quant overrides set MS domain (#19684)

### Description
Ensures that DQ and Q ops use the msft domain if tensor quantization
overrides specify 16-bit integer types.

### Motivation and Context
ONNX does not yet support 16bit integer types for QuantizeLinear and
DequantizeLinear ops (coming soon). For now, DQ/Q ops must use the MSFT
domain.

We have to also check if tensor quantization overrides force the use of
16-bit quantization types. If so, we must correctly set the domain for
Q/DQ ops.
This commit is contained in:
Adrian Lizarraga 2024-02-29 01:19:25 -08:00 committed by GitHub
parent d2e6dd25ea
commit c1bf7fcd2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 42 additions and 6 deletions

View file

@ -154,7 +154,7 @@ class ONNXQuantizer:
if self.mode not in QuantizationMode:
raise ValueError(f"unsupported quantization mode {self.mode}")
self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides()
self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides()
self.quantization_params = self.calculate_quantization_params()
# QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
@ -177,8 +177,10 @@ class ONNXQuantizer:
def _get_and_check_tensor_quant_overrides(self):
"""
Get tensor quantization overrides and check correctness.
Also returns a set of quantization types (as TensorProto) specified across all overrides.
"""
tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {})
tensor_quant_override_types = set()
# Validate that compatible/valid overrides are provided.
if tensor_quant_overrides:
@ -211,6 +213,8 @@ class ONNXQuantizer:
# other channels.
if index == 0:
quant_type = quant_overrides.get("quant_type")
if quant_type is not None:
tensor_quant_override_types.add(quant_type.tensor_type)
elif quant_type != quant_overrides.get("quant_type"):
raise ValueError(
"Channel quantization types for tensor '{tensor_name}' do not match at index {index}."
@ -231,7 +235,7 @@ class ONNXQuantizer:
f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'"
)
return tensor_quant_overrides
return tensor_quant_overrides, tensor_quant_override_types
def get_per_tensor_quant_overrides(self, tensor_name):
quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}])
@ -747,8 +751,7 @@ class ONNXQuantizer:
raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
scale_values = np.array([params["scale"]])
assert scale_values.dtype != np.float64
# zero_point_type = params["quant_type"]
assert zero_point_type == params["quant_type"]
zero_point_type = params["quant_type"]
else:
zero_point_values = np.array([use_zeropoint])
scale_values = np.array([use_scale])

View file

@ -116,7 +116,10 @@ class QDQQuantizer(ONNXQuantizer):
# if the activation or weight types are 16-bit integers.
# TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
int16_types = (TensorProto.UINT16, TensorProto.INT16)
if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types)
if not self.qdq_op_domain and (
self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16
):
logging.warning(
"ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. "
f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "

View file

@ -13,7 +13,7 @@ import onnx
from onnxruntime import quantization
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType
from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain
class DummyDataReader(quantization.CalibrationDataReader):
@ -423,6 +423,36 @@ class TestTensorQuantOverridesOption(unittest.TestCase):
self.assertEqual(zp, expected_zp)
self.assertEqual(scale, np.float32(expected_scale))
def test_16bit_overrides_set_ms_domain(self):
"""
Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft'
domain on DQ and Q ops.
"""
qdq_model_name = "model_quant_overrides_to_16bit.onnx"
inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization(
qdq_model_name,
activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations
extra_options={
"TensorQuantOverrides": {
"INP": [{"quant_type": quantization.QuantType.QUInt16}],
"SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}],
}
},
)
# Input and Sigmoid's output should be overridden to 16bit
self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16)
self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
# Output should the default uint8 type
self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8)
# Q/DQ ops should all have the 'com.microsoft' domain
qdq_model = onnx.load_model(qdq_model_name)
for node in qdq_model.graph.node:
if node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
self.assertEqual(node.domain, ms_domain)
def test_override_validation_nonexisting_tensor(self):
"""
Test that specifying a non-existing tensor should fail.