Add changes for strided calibration (#20949)

Context and motivation:
When quantizing large transformer models, we ran into out-of-memory (OOM)
errors as the number of calibration samples increased. To resolve this, this
PR adds support for reading quantization data in chunks, calculating the
ranges of intermediate tensors for each chunk, and then accumulating the
per-chunk results into the final ranges.
This commit is contained in:
RuomeiMS 2024-06-21 16:23:23 +01:00 committed by GitHub
parent f5625b8858
commit 7cf9263ee7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 222 additions and 9 deletions

View file

@ -128,6 +128,9 @@ class TensorsData:
def values(self):
    """Return the result of ``self.data.values()`` (the stored entries without their keys)."""
    stored = self.data
    return stored.values()
def items(self):
    """Return the result of ``self.data.items()`` (the stored key/entry pairs)."""
    stored = self.data
    return stored.items()
class CalibrationMethod(Enum):
MinMax = 0
@ -155,6 +158,12 @@ class CalibrationDataReader(metaclass=abc.ABCMeta):
raise StopIteration
return result
def __len__(self):
    """Report the number of calibration samples.

    The base implementation is a stub: readers used with strided calibration
    must override it.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    raise NotImplementedError()
def set_range(self, start_index: int, end_index: int):
    """Select which slice of the calibration data the reader should serve next.

    The base implementation is a stub: readers used with strided calibration
    must override it.

    Args:
        start_index: index of the first sample of the slice.
        end_index: end of the slice (presumably exclusive; confirm against the overriding reader).

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    raise NotImplementedError()
class CalibraterBase:
def __init__(
@ -409,13 +418,31 @@ class MinMaxCalibrater(CalibraterBase):
return new_range
for key, value in old_range.items():
if self.moving_average:
min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0])
max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1])
# Handling for structured data types with TensorData
if isinstance(value, TensorData):
old_min = value.range_value[0]
old_max = value.range_value[1]
else:
min_value = min(value[0], new_range[key][0])
max_value = max(value[1], new_range[key][1])
new_range[key] = (min_value, max_value)
old_min, old_max = value
if isinstance(new_range[key], TensorData):
new_min = new_range[key].range_value[0]
new_max = new_range[key].range_value[1]
else:
new_min, new_max = new_range[key]
if self.moving_average:
min_value = old_min + self.averaging_constant * (new_min - old_min)
max_value = old_max + self.averaging_constant * (new_max - old_max)
else:
min_value = min(old_min, new_min)
max_value = max(old_max, new_max)
# If structured as TensorData, wrap the result accordingly
if isinstance(value, TensorData) or isinstance(new_range[key], TensorData):
new_range[key] = TensorData(lowest=min_value, highest=max_value)
else:
new_range[key] = (min_value, max_value)
return new_range

View file

@ -52,6 +52,7 @@ def get_qnn_qdq_config(
activation_symmetric: bool = False,
weight_symmetric: bool | None = None,
keep_removable_activations: bool = False,
stride: int | None = None,
) -> StaticQuantConfig:
"""
Returns a static quantization configuration suitable for running QDQ models on QNN EP.
@ -171,6 +172,7 @@ def get_qnn_qdq_config(
"TensorQuantOverrides": overrides_helper.get_dict(),
"ActivationSymmetric": activation_symmetric,
"WeightSymmetric": weight_symmetric,
"CalibStridedMinMax": stride,
}
# ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain

View file

@ -381,6 +381,9 @@ def quantize_static(
CalibTensorRangeSymmetric = True/False :
Default is False. If enabled, the final range of tensor during calibration will be explicitly
set to symmetric to central point "0".
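CalibStridedMinMax = Optional[int] :
Default is None. If set to an integer (the stride), the calibration data is processed in chunks of
`stride` samples: data is collected one chunk at a time and the results of all chunks are merged
in the end. The total number of samples must be divisible by the stride, and the calibration data
reader must implement `__len__` and `set_range`.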
CalibMovingAverage = True/False :
Default is False. If enabled, the moving average of the minimum and maximum values will be
computed when the calibration method selected is MinMax.
@ -522,7 +525,19 @@ def quantize_static(
use_external_data_format=use_external_data_format,
extra_options=calib_extra_options,
)
calibrator.collect_data(calibration_data_reader)
stride = extra_options.get("CalibStridedMinMax", None)
if stride:
total_data_size = len(calibration_data_reader)
if total_data_size % stride != 0:
raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).")
for start in range(0, total_data_size, stride):
end_index = start + stride
calibration_data_reader.set_range(start_index=start, end_index=end_index)
calibrator.collect_data(calibration_data_reader)
else:
calibrator.collect_data(calibration_data_reader)
tensors_range = calibrator.compute_data()
if not isinstance(tensors_range, TensorsData):
raise TypeError(

View file

@ -217,10 +217,13 @@ class TestDataFeeds(CalibrationDataReader):
self.iter_next = iter(self.data_feeds)
def input_feeds_neg_one_zero_one(n, name2shape):
def input_feeds_neg_one_zero_one(n, name2shape, seed=None):
"""
randomize n feed according to shape, its values are from -1, 0, and 1
"""
if seed is not None:
np.random.seed(seed)
input_data_list = []
for _i in range(n):
inputs = {}
@ -231,6 +234,120 @@ def input_feeds_neg_one_zero_one(n, name2shape):
return dr
def input_feeds_neg_one_zero_one_list(n, name2shape, seed=None):
    """
    Build a plain list of ``n`` random feeds.

    Each feed is a dict mapping every name in ``name2shape`` to a float32 array of
    the associated shape whose entries are drawn from -1, 0 and 1. When ``seed`` is
    given, NumPy's global random generator is seeded with it first.
    """
    if seed is not None:
        np.random.seed(seed)

    def _random_feed():
        # One draw per input name, in the iteration order of name2shape.
        return {name: np.random.randint(-1, 2, shape).astype(np.float32) for name, shape in name2shape.items()}

    return [_random_feed() for _ in range(n)]
class GenerateCalibrationData(CalibrationDataReader):
    """Calibration data reader that serves feeds built from an in-memory dataset.

    Every dataset item is indexed by position: element ``i`` of an item is reshaped
    to ``input_shapes[i]`` and fed to ``input_nodes[i]``.
    """

    def __init__(self, data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last=None):
        """
        Args:
            data_list: the calibration dataset (a sized sequence, or an iterator of items).
            input_nodes: model input names, in the order items are indexed.
            input_shapes: target shape for each input, parallel to ``input_nodes``.
            no_tensor_num: accepted for interface compatibility; not used by this class.
            in_dtypes: accepted for interface compatibility; not used by this class.
            inputs_conv_channel_last: optional collection of input names whose data is
                converted to channel-last layout (axis 1 moved to the end).
        """
        print("Generating calibration dataset from " + str(data_list))
        print("input nodes are ", input_nodes, "input shapes are ", input_shapes)
        if inputs_conv_channel_last:
            print(f"Inputs that will be converted to channel last: {inputs_conv_channel_last}")

        self.enum_data_dicts = []
        self.input_nodes = input_nodes
        self.input_shapes = input_shapes
        self.inputs_conv_channel_last = inputs_conv_channel_last
        self.calibration_dataset = data_list
        # Created lazily by get_next(); kept separate from calibration_dataset so that
        # __len__ keeps working on the underlying sequence.
        self._dataset_iter = None

    def __len__(self):
        """Return the number of items in the calibration dataset."""
        return len(self.calibration_dataset)

    def get_next(self):
        """Return the feed dict for the next dataset item, or None when the data is exhausted."""
        if self._dataset_iter is None:
            # next() needs an iterator: calling it directly on a list raises TypeError.
            # iter() returns an iterator unchanged, so iterator datasets behave as before.
            self._dataset_iter = iter(self.calibration_dataset)

        inp = next(self._dataset_iter, None)
        if inp is None:
            return None

        channel_last = self.inputs_conv_channel_last
        feed_dict = {}
        for i, node in enumerate(self.input_nodes):
            input_data = inp[i].reshape(self.input_shapes[i])
            if channel_last is not None and node in channel_last:
                input_data = np.moveaxis(input_data, 1, -1)
            feed_dict[node] = input_data
        return feed_dict
class StridedDataReader(GenerateCalibrationData):
    """Calibration data reader that serves the dataset slice ``[start_index, end_index)``.

    Items are loaded in batches of at most ``stride`` items; the slice can be moved with
    ``set_range`` so a caller can calibrate on one chunk of the dataset at a time.
    """

    def __init__(
        self,
        data_list,
        input_nodes,
        input_shapes,
        no_tensor_num,
        in_dtypes,
        inputs_conv_channel_last=None,
        stride=1,
        start_index=0,
        end_index=None,
    ):
        """
        Args:
            data_list: indexable, sized dataset; each item is indexed by input name.
            input_nodes: model input names to feed.
            input_shapes: forwarded to the parent class; not used by this reader's own methods.
            no_tensor_num: forwarded to the parent class.
            in_dtypes: forwarded to the parent class.
            inputs_conv_channel_last: forwarded to the parent class.
            stride: number of items loaded per batch; values below 1 are clamped to 1.
            start_index: index of the first item to serve.
            end_index: index one past the last item to serve; None means the end of the dataset.
        """
        super().__init__(data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last)
        self.stride = max(1, stride)  # Ensure stride is at least 1
        self.start_index = start_index
        self.end_index = (
            end_index if end_index is not None else len(self.calibration_dataset)
        )  # Default to the end of the dataset
        self.enum_data_dicts = iter([])

    def get_next(self):
        """Return the next feed dict of the current range, or None once the range is exhausted."""
        iter_data = next(self.enum_data_dicts, None)
        if iter_data:
            return iter_data

        # Keep a valid (empty) iterator rather than None, so that calling get_next()
        # again after the range is exhausted returns None instead of raising TypeError.
        self.enum_data_dicts = iter([])
        if self.start_index >= self.end_index:
            return None

        print(f"start index is {self.start_index}")
        data = self.load_serial()
        self.start_index += self.stride
        self.enum_data_dicts = iter(data)
        return next(self.enum_data_dicts, None)

    def load_serial(self):
        """Load and process up to ``stride`` items starting at ``start_index``, bounded by ``end_index``."""
        batch_data = []
        end_loop = min(self.end_index, self.start_index + self.stride)
        for i in range(self.start_index, end_loop):
            print(f"debugging the load serial index {i}")
            data_item = self.calibration_dataset[i]
            processed_item = self.process_data_item(data_item)
            batch_data.append(processed_item)
        return batch_data

    def process_data_item(self, data_item):
        """Turn one dataset item into a feed dict keyed by input node name."""
        feed_dict = {}
        for node in self.input_nodes:
            # Use the entry stored under the node's own name, so that several inputs do
            # not all receive the same tensor. Fall back to the historical "input" key
            # for items that only provide that entry.
            key = node if node in data_item else "input"
            feed_dict[node] = data_item[key]
        return feed_dict

    def set_range(self, start_index, end_index=None):
        """Serve ``[start_index, end_index)`` next; ``end_index=None`` means the end of the dataset."""
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else len(self.calibration_dataset)
        self.enum_data_dicts = iter([])

    def rewind(self):
        """Rewind the data reader to the beginning of the dataset."""
        self.start_index = 0
        self.enum_data_dicts = iter([])
def check_op_type_order(testcase, model_to_check, ops):
if isinstance(model_to_check, str):
model = onnx.load(model_to_check)

View file

@ -13,8 +13,15 @@ from pathlib import Path
import numpy as np
import onnx
from onnx import TensorProto, helper
from op_test_utils import check_model_correctness, generate_random_initializer, input_feeds_neg_one_zero_one
from op_test_utils import (
StridedDataReader,
check_model_correctness,
generate_random_initializer,
input_feeds_neg_one_zero_one,
input_feeds_neg_one_zero_one_list,
)
import onnxruntime as ort
from onnxruntime.quantization import QuantType, StaticQuantConfig, quantize, quantize_static
@ -89,6 +96,51 @@ class TestStaticQuantization(unittest.TestCase):
check_model_correctness(self, self._model_fp32_path, quant_model_path, data_reader.get_next())
data_reader.rewind()
def run_inference(self, model_path, input_data):
    """Run ``input_data`` through the model at ``model_path``.

    Only the model's first input and first output are used; the return value is
    whatever ``InferenceSession.run`` returns for that single requested output.
    """
    sess = ort.InferenceSession(model_path)
    first_input = sess.get_inputs()[0].name
    first_output = sess.get_outputs()[0].name
    feeds = {first_input: input_data}
    return sess.run([first_output], feeds)
def test_stride_effect_on_data_collection(self):
    """Strided and non-strided calibration should yield models with matching outputs.

    Both calibrations use 10 samples generated with seed 123; the strided run
    processes them in chunks of 5.
    """
    stride = 5
    num_samples = 10
    seed = 123
    tmp_dir = Path(self._tmp_model_dir.name)
    fp32_model = str(self._model_fp32_path)

    # Calibrate in chunks of `stride` samples.
    strided_samples = input_feeds_neg_one_zero_one_list(num_samples, {"input": [1, self._channel_size, 1, 3]}, seed)
    strided_reader = StridedDataReader(
        strided_samples,
        ["input"],
        [1, self._channel_size, 1, 3],
        no_tensor_num=0,
        in_dtypes=[np.float32],
        stride=stride,
    )
    strided_model = str(tmp_dir / "quant.strided.onnx")
    strided_config = StaticQuantConfig(strided_reader, extra_options={"CalibStridedMinMax": stride})
    quantize(fp32_model, str(strided_model), strided_config)

    # Calibrate on all samples at once, generated with the same seed.
    plain_reader = input_feeds_neg_one_zero_one(num_samples, {"input": [1, self._channel_size, 1, 3]}, seed)
    plain_model = str(tmp_dir / "quant.non.strided.onnx")
    plain_config = StaticQuantConfig(plain_reader)
    quantize(fp32_model, str(plain_model), plain_config)

    # Feed the same random input to both quantized models.
    np.random.seed(seed)
    probe = np.random.choice([-1, 0, 1], size=[1, self._channel_size, 1, 3]).astype(np.float32)
    strided_output = self.run_inference(strided_model, probe)
    plain_output = self.run_inference(plain_model, probe)

    np.testing.assert_allclose(
        strided_output,
        plain_output,
        rtol=0.01,
        atol=0.01,
        err_msg="Outputs from strided and non-strided models are not close enough.",
    )
def test_static_quant_config(self):
data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]})
quant_config = StaticQuantConfig(data_reader)