Fix static quantization for QDQ and Percentile distribution (#17649)

### Description
One quantization case was not covered by the current list of unit tests.
This PR adds a unit test to cover that case with the fix. It fixes the
issue #17619.



### Motivation and Context
This commit is contained in:
Xavier Dupré 2023-09-25 19:11:58 +02:00 committed by GitHub
parent df15a3a335
commit 905faea3b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 13909 additions and 7 deletions

View file

@ -77,7 +77,8 @@ class QLinearConv : public OpKernel {
W_zero_point_value = W_zero_point_data[0];
for (int64_t i = 1; i < W_zero_point_size; i++) {
ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value,
"QLinearConv : zero point of per-channel filter must be same");
"QLinearConv : zero point of per-channel filter must be same. "
"This happens by design if the quantization is symmetric.");
}
}

View file

@ -22,7 +22,7 @@ from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distrib
class TensorData:
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"])
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
def __init__(self, **kwargs):
for k, v in kwargs.items():
@ -55,7 +55,7 @@ class TensorsData:
self.data[k] = TensorData(lowest=v[0], highest=v[1])
continue
if len(v) == 4:
self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3])
self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3])
continue
raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.")
if not isinstance(v, TensorData):

View file

@ -157,7 +157,7 @@ class QLinearConv(QuantOperatorBase):
nodes,
) = self.quantizer.quantize_activation(node, [0])
quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[1], onnx_proto.TensorProto.INT8, 0
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)
quantized_input_names.append(quant_weight_tuple[0])
zero_point_names.append(quant_weight_tuple[1])

View file

@ -47,10 +47,10 @@ class LSTMQuant(QuantOperatorBase):
R.dims[0] = R_num_dir * R_4_hidden_size
quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[1], onnx_proto.TensorProto.INT8, 0
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)
quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(
node.input[2], onnx_proto.TensorProto.INT8, 0
node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
)
W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806

View file

@ -283,7 +283,13 @@ class QDQQuantizer(ONNXQuantizer):
raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.")
q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel(
weight_name,
self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType,
# Quantization type is forced to be TensorProto.INT8.
# when the expected value would be (see below)
# self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType.
# QLinearConv expects to have a unique value for all channels.
# This code does not enforce that but it is necessarily the case when the
# quantization is symmetric (as for INT8).
onnx_proto.TensorProto.INT8,
axis,
keep_float_weight=self.add_qdq_pair_to_weight,
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,138 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import random
import tempfile
import unittest
import numpy as np
import onnx
from numpy.testing import assert_allclose
from onnx.numpy_helper import to_array
from resnet_code import create_model
from onnxruntime import InferenceSession
from onnxruntime import __version__ as ort_version
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
class FakeResnetCalibrationDataReader(CalibrationDataReader):
def __init__(self, batch_size: int = 16):
super().__init__()
self.dataset = [
(np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size)
]
self.iterator = iter(self.dataset)
def get_next(self) -> dict:
try:
return {"input": next(self.iterator)[0]}
except Exception:
return None
class TestStaticQuantizationResNet(unittest.TestCase):
def test_quantize_static_resnet(self):
kwargs = {
"activation_type": QuantType.QUInt8,
"weight_type": QuantType.QInt8,
"calibrate_method": CalibrationMethod.Percentile,
"extra_options": {
"ActivationSymmetric": False,
"EnableSubgraph": False,
"ForceQuantizeNoInputCheck": False,
"MatMulConstBOnly": False,
"WeightSymmetric": True,
"extra.Sigmoid.nnapi": False,
},
"nodes_to_exclude": None,
"nodes_to_quantize": None,
"op_types_to_quantize": None,
"per_channel": True,
"quant_format": QuantFormat.QDQ,
"reduce_range": False,
}
proto = create_model()
with tempfile.TemporaryDirectory() as temp:
model = os.path.join(temp, "resnet_first_nodes.onnx")
with open(model, "wb") as f:
f.write(proto.SerializeToString())
for per_channel in [True, False]:
kwargs["per_channel"] = per_channel
dataloader = FakeResnetCalibrationDataReader(16)
with self.subTest(per_channel=per_channel):
qdq_file = os.path.join(
temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx"
)
quantize_static(
model_input=model,
model_output=qdq_file,
calibration_data_reader=dataloader,
use_external_data_format=False,
**kwargs,
)
# With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is:
# * uint8(128) if per_channel is False
# * int8([0, 0, ....]) if per_channel is True
# With onnxruntime>1.16.0
# * uint8(128) if per_channel is False
# * uint8([128, 128, ..., 127, ...]) if per_channel is True
# QLinearConv : zero point of per-channel filter must be same.
# That's why the quantization forces a symmetric quantization into INT8.
# zero_point is guaranted to be zero whatever the channel is.
with open(qdq_file, "rb") as f:
onx = onnx.load(f)
for init in onx.graph.initializer:
arr = to_array(init)
if (
arr.dtype == np.int8
and "zero_point" not in init.name
and not init.name.endswith("quantized")
):
raise AssertionError(
f"Initializer {init.name!r} has type {arr.dtype} and "
f"shape {arr.shape} but should be {np.uint8}."
)
sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"])
shape = (1, 3, 32, 32)
size = np.prod(shape)
dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape)
got = sess.run(None, {"input": dummy})
self.assertEqual(got[0].shape, (1, 64, 8, 8))
self.assertEqual(got[0].dtype, np.float32)
if per_channel:
expected = np.array(
[
[[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]],
[[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]],
[[0.0, 0.0], [0.0, 0.0]],
[[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]],
],
dtype=np.float32,
)
assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2)
else:
expected = np.array(
[
[[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]],
[[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]],
[[0.0, 0.0], [0.0, 0.0]],
[[0.0, 0.0], [1.2182037830352783, 1.050175666809082]],
],
dtype=np.float32,
)
assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2)
if __name__ == "__main__":
unittest.main(verbosity=2)