mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Fix static quantization for QDQ and Percentile distribution (#17649)
### Description
One quantization case was not covered by the existing unit tests. This PR fixes that case and adds a unit test to cover it. It fixes issue #17619.
### Motivation and Context
This commit is contained in:
parent
df15a3a335
commit
905faea3b2
7 changed files with 13909 additions and 7 deletions
|
|
@ -77,7 +77,8 @@ class QLinearConv : public OpKernel {
|
|||
W_zero_point_value = W_zero_point_data[0];
|
||||
for (int64_t i = 1; i < W_zero_point_size; i++) {
|
||||
ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value,
|
||||
"QLinearConv : zero point of per-channel filter must be same");
|
||||
"QLinearConv : zero point of per-channel filter must be same. "
|
||||
"This happens by design if the quantization is symmetric.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distrib
|
|||
|
||||
|
||||
class TensorData:
|
||||
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"])
|
||||
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
|
|
@ -55,7 +55,7 @@ class TensorsData:
|
|||
self.data[k] = TensorData(lowest=v[0], highest=v[1])
|
||||
continue
|
||||
if len(v) == 4:
|
||||
self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3])
|
||||
self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3])
|
||||
continue
|
||||
raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.")
|
||||
if not isinstance(v, TensorData):
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ class QLinearConv(QuantOperatorBase):
|
|||
nodes,
|
||||
) = self.quantizer.quantize_activation(node, [0])
|
||||
quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
|
||||
node.input[1], onnx_proto.TensorProto.INT8, 0
|
||||
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
|
||||
)
|
||||
quantized_input_names.append(quant_weight_tuple[0])
|
||||
zero_point_names.append(quant_weight_tuple[1])
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ class LSTMQuant(QuantOperatorBase):
|
|||
R.dims[0] = R_num_dir * R_4_hidden_size
|
||||
|
||||
quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(
|
||||
node.input[1], onnx_proto.TensorProto.INT8, 0
|
||||
node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
|
||||
)
|
||||
quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(
|
||||
node.input[2], onnx_proto.TensorProto.INT8, 0
|
||||
node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType?
|
||||
)
|
||||
|
||||
W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806
|
||||
|
|
|
|||
|
|
@ -283,7 +283,13 @@ class QDQQuantizer(ONNXQuantizer):
|
|||
raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.")
|
||||
q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel(
|
||||
weight_name,
|
||||
self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType,
|
||||
# Quantization type is forced to be TensorProto.INT8.
|
||||
# when the expected value would be (see below)
|
||||
# self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType.
|
||||
# QLinearConv expects to have a unique value for all channels.
|
||||
# This code does not enforce that but it is necessarily the case when the
|
||||
# quantization is symmetric (as for INT8).
|
||||
onnx_proto.TensorProto.INT8,
|
||||
axis,
|
||||
keep_float_weight=self.add_qdq_pair_to_weight,
|
||||
)
|
||||
|
|
|
|||
13757
onnxruntime/test/python/quantization/resnet_code.py
Normal file
13757
onnxruntime/test/python/quantization/resnet_code.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,138 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from numpy.testing import assert_allclose
|
||||
from onnx.numpy_helper import to_array
|
||||
from resnet_code import create_model
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from onnxruntime import __version__ as ort_version
|
||||
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
|
||||
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
|
||||
|
||||
|
||||
class FakeResnetCalibrationDataReader(CalibrationDataReader):
    """Calibration data reader that serves randomly generated inputs.

    Every sample is a ``(1, 3, 32, 32)`` float32 array paired with a random
    integer in ``[0, 9]``. Only the array is handed out by :meth:`get_next`.
    """

    def __init__(self, batch_size: int = 16):
        """Create ``batch_size`` random samples and an iterator over them.

        :param batch_size: number of samples to generate
        """
        super().__init__()
        self.dataset = [
            (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size)
        ]
        self.iterator = iter(self.dataset)

    def get_next(self) -> dict:
        """Return the next feed dictionary, or ``None`` once the data is exhausted."""
        # Exhaustion is the only expected end condition. Any other error must
        # surface instead of being silently reported as "no more data".
        sample = next(self.iterator, None)
        if sample is None:
            return None
        return {"input": sample[0]}
|
||||
|
||||
|
||||
class TestStaticQuantizationResNet(unittest.TestCase):
    """Checks static QDQ quantization with percentile calibration on a small model."""

    def test_quantize_static_resnet(self):
        """Quantize with and without per-channel, then validate initializers and outputs."""
        extra_options = {
            "ActivationSymmetric": False,
            "EnableSubgraph": False,
            "ForceQuantizeNoInputCheck": False,
            "MatMulConstBOnly": False,
            "WeightSymmetric": True,
            "extra.Sigmoid.nnapi": False,
        }
        kwargs = {
            "activation_type": QuantType.QUInt8,
            "weight_type": QuantType.QInt8,
            "calibrate_method": CalibrationMethod.Percentile,
            "extra_options": extra_options,
            "nodes_to_exclude": None,
            "nodes_to_quantize": None,
            "op_types_to_quantize": None,
            "per_channel": True,
            "quant_format": QuantFormat.QDQ,
            "reduce_range": False,
        }

        # Reference values for got[0][0, :4, :2, :2], keyed by the per_channel flag.
        expected_values = {
            True: [
                [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]],
                [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]],
                [[0.0, 0.0], [0.0, 0.0]],
                [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]],
            ],
            False: [
                [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]],
                [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]],
                [[0.0, 0.0], [0.0, 0.0]],
                [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]],
            ],
        }

        # Deterministic input fed to every quantized model.
        shape = (1, 3, 32, 32)
        size = np.prod(shape)
        dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape)

        proto = create_model()

        with tempfile.TemporaryDirectory() as temp:
            model = os.path.join(temp, "resnet_first_nodes.onnx")
            with open(model, "wb") as f:
                f.write(proto.SerializeToString())

            for per_channel in [True, False]:
                kwargs["per_channel"] = per_channel
                dataloader = FakeResnetCalibrationDataReader(16)
                with self.subTest(per_channel=per_channel):
                    qdq_file = os.path.join(
                        temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx"
                    )
                    quantize_static(
                        model_input=model,
                        model_output=qdq_file,
                        calibration_data_reader=dataloader,
                        use_external_data_format=False,
                        **kwargs,
                    )

                    # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is:
                    #   * uint8(128) if per_channel is False
                    #   * int8([0, 0, ....]) if per_channel is True
                    # With onnxruntime>1.16.0
                    #   * uint8(128) if per_channel is False
                    #   * uint8([128, 128, ..., 127, ...]) if per_channel is True
                    # QLinearConv : zero point of per-channel filter must be same.
                    # That's why the quantization forces a symmetric quantization into INT8.
                    # zero_point is guaranteed to be zero whatever the channel is.

                    with open(qdq_file, "rb") as f:
                        onx = onnx.load(f)

                    for init in onx.graph.initializer:
                        arr = to_array(init)
                        # int8 is only accepted for zero points and quantized tensors.
                        int8_allowed = "zero_point" in init.name or init.name.endswith("quantized")
                        if arr.dtype == np.int8 and not int8_allowed:
                            raise AssertionError(
                                f"Initializer {init.name!r} has type {arr.dtype} and "
                                f"shape {arr.shape} but should be {np.uint8}."
                            )

                    sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"])
                    got = sess.run(None, {"input": dummy})
                    self.assertEqual(got[0].shape, (1, 64, 8, 8))
                    self.assertEqual(got[0].dtype, np.float32)

                    expected = np.array(expected_values[per_channel], dtype=np.float32)
                    assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2)
|
||||
|
||||
|
||||
# Run the tests with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
Loading…
Reference in a new issue