mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add changes for strided calibration (#20949)
Context and motivation: When quantizing large transformer models, we faced OOM issue when the number of calibration samples goes up. To resolve this, in the PR we want to add support for reading quantization data in chunck, calculating ranges for intermediate tensors, then accumulating results for the final ranges.
This commit is contained in:
parent
f5625b8858
commit
7cf9263ee7
5 changed files with 222 additions and 9 deletions
|
|
@ -128,6 +128,9 @@ class TensorsData:
|
|||
def values(self):
|
||||
return self.data.values()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
|
||||
class CalibrationMethod(Enum):
|
||||
MinMax = 0
|
||||
|
|
@ -155,6 +158,12 @@ class CalibrationDataReader(metaclass=abc.ABCMeta):
|
|||
raise StopIteration
|
||||
return result
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_range(self, start_index: int, end_index: int):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CalibraterBase:
|
||||
def __init__(
|
||||
|
|
@ -409,13 +418,31 @@ class MinMaxCalibrater(CalibraterBase):
|
|||
return new_range
|
||||
|
||||
for key, value in old_range.items():
|
||||
if self.moving_average:
|
||||
min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0])
|
||||
max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1])
|
||||
# Handling for structured data types with TensorData
|
||||
if isinstance(value, TensorData):
|
||||
old_min = value.range_value[0]
|
||||
old_max = value.range_value[1]
|
||||
else:
|
||||
min_value = min(value[0], new_range[key][0])
|
||||
max_value = max(value[1], new_range[key][1])
|
||||
new_range[key] = (min_value, max_value)
|
||||
old_min, old_max = value
|
||||
|
||||
if isinstance(new_range[key], TensorData):
|
||||
new_min = new_range[key].range_value[0]
|
||||
new_max = new_range[key].range_value[1]
|
||||
else:
|
||||
new_min, new_max = new_range[key]
|
||||
|
||||
if self.moving_average:
|
||||
min_value = old_min + self.averaging_constant * (new_min - old_min)
|
||||
max_value = old_max + self.averaging_constant * (new_max - old_max)
|
||||
else:
|
||||
min_value = min(old_min, new_min)
|
||||
max_value = max(old_max, new_max)
|
||||
|
||||
# If structured as TensorData, wrap the result accordingly
|
||||
if isinstance(value, TensorData) or isinstance(new_range[key], TensorData):
|
||||
new_range[key] = TensorData(lowest=min_value, highest=max_value)
|
||||
else:
|
||||
new_range[key] = (min_value, max_value)
|
||||
|
||||
return new_range
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ def get_qnn_qdq_config(
|
|||
activation_symmetric: bool = False,
|
||||
weight_symmetric: bool | None = None,
|
||||
keep_removable_activations: bool = False,
|
||||
stride: int | None = None,
|
||||
) -> StaticQuantConfig:
|
||||
"""
|
||||
Returns a static quantization configuration suitable for running QDQ models on QNN EP.
|
||||
|
|
@ -171,6 +172,7 @@ def get_qnn_qdq_config(
|
|||
"TensorQuantOverrides": overrides_helper.get_dict(),
|
||||
"ActivationSymmetric": activation_symmetric,
|
||||
"WeightSymmetric": weight_symmetric,
|
||||
"CalibStridedMinMax": stride,
|
||||
}
|
||||
|
||||
# ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
|
||||
|
|
|
|||
|
|
@ -381,6 +381,9 @@ def quantize_static(
|
|||
CalibTensorRangeSymmetric = True/False :
|
||||
Default is False. If enabled, the final range of tensor during calibration will be explicitly
|
||||
set to symmetric to central point "0".
|
||||
CalibStridedMinMax = Optional[int] :
|
||||
Default is None. If set to an integer, during calculation of the min-max, only stride amount of
|
||||
data will be used and then all results will be merged in the end.
|
||||
CalibMovingAverage = True/False :
|
||||
Default is False. If enabled, the moving average of the minimum and maximum values will be
|
||||
computed when the calibration method selected is MinMax.
|
||||
|
|
@ -522,7 +525,19 @@ def quantize_static(
|
|||
use_external_data_format=use_external_data_format,
|
||||
extra_options=calib_extra_options,
|
||||
)
|
||||
calibrator.collect_data(calibration_data_reader)
|
||||
|
||||
stride = extra_options.get("CalibStridedMinMax", None)
|
||||
if stride:
|
||||
total_data_size = len(calibration_data_reader)
|
||||
if total_data_size % stride != 0:
|
||||
raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).")
|
||||
|
||||
for start in range(0, total_data_size, stride):
|
||||
end_index = start + stride
|
||||
calibration_data_reader.set_range(start_index=start, end_index=end_index)
|
||||
calibrator.collect_data(calibration_data_reader)
|
||||
else:
|
||||
calibrator.collect_data(calibration_data_reader)
|
||||
tensors_range = calibrator.compute_data()
|
||||
if not isinstance(tensors_range, TensorsData):
|
||||
raise TypeError(
|
||||
|
|
|
|||
|
|
@ -217,10 +217,13 @@ class TestDataFeeds(CalibrationDataReader):
|
|||
self.iter_next = iter(self.data_feeds)
|
||||
|
||||
|
||||
def input_feeds_neg_one_zero_one(n, name2shape):
|
||||
def input_feeds_neg_one_zero_one(n, name2shape, seed=None):
|
||||
"""
|
||||
randomize n feed according to shape, its values are from -1, 0, and 1
|
||||
"""
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
|
||||
input_data_list = []
|
||||
for _i in range(n):
|
||||
inputs = {}
|
||||
|
|
@ -231,6 +234,120 @@ def input_feeds_neg_one_zero_one(n, name2shape):
|
|||
return dr
|
||||
|
||||
|
||||
def input_feeds_neg_one_zero_one_list(n, name2shape, seed=None):
|
||||
"""
|
||||
randomize n feed according to shape, its values are from -1, 0, and 1
|
||||
"""
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
|
||||
input_data_list = []
|
||||
for _i in range(n):
|
||||
inputs = {}
|
||||
for name, shape in name2shape.items():
|
||||
inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)})
|
||||
input_data_list.extend([inputs])
|
||||
return input_data_list
|
||||
|
||||
|
||||
class GenerateCalibrationData(CalibrationDataReader):
|
||||
def __init__(self, data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last=None):
|
||||
print("Generating calibration dataset from " + str(data_list))
|
||||
print("input nodes are ", input_nodes, "input shapes are ", input_shapes)
|
||||
if inputs_conv_channel_last:
|
||||
print(f"Inputs that will be converted to channel last: {inputs_conv_channel_last}")
|
||||
|
||||
self.enum_data_dicts = []
|
||||
self.input_nodes = input_nodes
|
||||
self.input_shapes = input_shapes
|
||||
self.inputs_conv_channel_last = inputs_conv_channel_last
|
||||
self.calibration_dataset = data_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.calibration_dataset)
|
||||
|
||||
def get_next(self):
|
||||
feed_dict = {}
|
||||
inp = next(self.calibration_dataset, None)
|
||||
if inp is not None:
|
||||
for i in range(len(self.input_nodes)):
|
||||
input_data = inp[i].reshape(self.input_shapes[i])
|
||||
if self.inputs_conv_channel_last is not None and self.input_nodes[i] in self.inputs_conv_channel_last:
|
||||
input_data = np.moveaxis(input_data, 1, -1)
|
||||
dict_item = {self.input_nodes[i]: input_data}
|
||||
feed_dict.update(dict_item)
|
||||
return feed_dict
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class StridedDataReader(GenerateCalibrationData):
|
||||
def __init__(
|
||||
self,
|
||||
data_list,
|
||||
input_nodes,
|
||||
input_shapes,
|
||||
no_tensor_num,
|
||||
in_dtypes,
|
||||
inputs_conv_channel_last=None,
|
||||
stride=1,
|
||||
start_index=0,
|
||||
end_index=None,
|
||||
):
|
||||
super().__init__(data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last)
|
||||
|
||||
self.stride = max(1, stride) # Ensure stride is at least 1
|
||||
self.start_index = start_index
|
||||
self.end_index = (
|
||||
end_index if end_index is not None else len(self.calibration_dataset)
|
||||
) # Default to the end of the dataset
|
||||
self.enum_data_dicts = iter([])
|
||||
|
||||
def get_next(self):
|
||||
iter_data = next(self.enum_data_dicts, None)
|
||||
if iter_data:
|
||||
return iter_data
|
||||
|
||||
self.enum_data_dicts = None
|
||||
if self.start_index < self.end_index:
|
||||
print(f"start index is {self.start_index}")
|
||||
data = self.load_serial()
|
||||
|
||||
self.start_index += self.stride
|
||||
self.enum_data_dicts = iter(data)
|
||||
|
||||
return next(self.enum_data_dicts, None)
|
||||
else:
|
||||
return None
|
||||
|
||||
def load_serial(self):
|
||||
batch_data = []
|
||||
end_loop = min(self.end_index, self.start_index + self.stride)
|
||||
for i in range(self.start_index, end_loop):
|
||||
print(f"debugging the load serial index {i}")
|
||||
data_item = self.calibration_dataset[i]
|
||||
processed_item = self.process_data_item(data_item)
|
||||
batch_data.append(processed_item)
|
||||
return batch_data
|
||||
|
||||
def process_data_item(self, data_item):
|
||||
feed_dict = {}
|
||||
for _, node in enumerate(self.input_nodes):
|
||||
# input_data = data_item[i].reshape(self.input_shapes[i])
|
||||
feed_dict[node] = data_item["input"]
|
||||
return feed_dict
|
||||
|
||||
def set_range(self, start_index, end_index=None):
|
||||
self.start_index = start_index
|
||||
self.end_index = end_index if end_index is not None else len(self.calibration_dataset)
|
||||
self.enum_data_dicts = iter([])
|
||||
|
||||
def rewind(self):
|
||||
"""Rewind the data reader to the beginning of the dataset."""
|
||||
self.start_index = 0
|
||||
self.enum_data_dicts = iter([])
|
||||
|
||||
|
||||
def check_op_type_order(testcase, model_to_check, ops):
|
||||
if isinstance(model_to_check, str):
|
||||
model = onnx.load(model_to_check)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,15 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
from op_test_utils import check_model_correctness, generate_random_initializer, input_feeds_neg_one_zero_one
|
||||
from op_test_utils import (
|
||||
StridedDataReader,
|
||||
check_model_correctness,
|
||||
generate_random_initializer,
|
||||
input_feeds_neg_one_zero_one,
|
||||
input_feeds_neg_one_zero_one_list,
|
||||
)
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.quantization import QuantType, StaticQuantConfig, quantize, quantize_static
|
||||
|
||||
|
||||
|
|
@ -89,6 +96,51 @@ class TestStaticQuantization(unittest.TestCase):
|
|||
check_model_correctness(self, self._model_fp32_path, quant_model_path, data_reader.get_next())
|
||||
data_reader.rewind()
|
||||
|
||||
def run_inference(self, model_path, input_data):
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
output_name = session.get_outputs()[0].name
|
||||
result = session.run([output_name], {input_name: input_data})
|
||||
return result
|
||||
|
||||
def test_stride_effect_on_data_collection(self):
|
||||
# Define the stride and test quantize_static with different stride values
|
||||
stride = 5
|
||||
input_shapes = [1, self._channel_size, 1, 3]
|
||||
data_list = input_feeds_neg_one_zero_one_list(10, {"input": [1, self._channel_size, 1, 3]}, 123)
|
||||
input_nodes = ["input"]
|
||||
in_dtypes = [np.float32] # Example dtype, adjust as needed
|
||||
|
||||
# strided calibration
|
||||
quant_model_path_1 = str(Path(self._tmp_model_dir.name) / "quant.strided.onnx")
|
||||
data_reader_1 = StridedDataReader(
|
||||
data_list, input_nodes, input_shapes, no_tensor_num=0, in_dtypes=in_dtypes, stride=stride
|
||||
)
|
||||
quant_config_1 = StaticQuantConfig(data_reader_1, extra_options={"CalibStridedMinMax": stride})
|
||||
quantize(str(self._model_fp32_path), str(quant_model_path_1), quant_config_1)
|
||||
|
||||
# non-strided calibration
|
||||
quant_model_path_2 = str(Path(self._tmp_model_dir.name) / "quant.non.strided.onnx")
|
||||
data_reader_2 = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}, 123)
|
||||
quant_config_2 = StaticQuantConfig(data_reader_2)
|
||||
quantize(str(self._model_fp32_path), str(quant_model_path_2), quant_config_2)
|
||||
|
||||
# Inference with both models and assert output closeness
|
||||
np.random.seed(123)
|
||||
input_data = np.random.choice([-1, 0, 1], size=[1, self._channel_size, 1, 3]).astype(np.float32)
|
||||
|
||||
result_1 = self.run_inference(quant_model_path_1, input_data)
|
||||
result_2 = self.run_inference(quant_model_path_2, input_data)
|
||||
|
||||
# Assert that the outputs are close
|
||||
np.testing.assert_allclose(
|
||||
result_1,
|
||||
result_2,
|
||||
rtol=0.01,
|
||||
atol=0.01,
|
||||
err_msg="Outputs from strided and non-strided models are not close enough.",
|
||||
)
|
||||
|
||||
def test_static_quant_config(self):
|
||||
data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]})
|
||||
quant_config = StaticQuantConfig(data_reader)
|
||||
|
|
|
|||
Loading…
Reference in a new issue