From 7cf9263ee741222bbc1744ca2970d5786d23e9e6 Mon Sep 17 00:00:00 2001 From: RuomeiMS Date: Fri, 21 Jun 2024 16:23:23 +0100 Subject: [PATCH] Add changes for strided calibration (#20949) Context and motivation: When quantizing large transformer models, we faced OOM issue when the number of calibration samples goes up. To resolve this, in the PR we want to add support for reading quantization data in chunck, calculating ranges for intermediate tensors, then accumulating results for the final ranges. --- .../python/tools/quantization/calibrate.py | 39 +++++- .../execution_providers/qnn/quant_config.py | 2 + .../python/tools/quantization/quantize.py | 17 ++- .../test/python/quantization/op_test_utils.py | 119 +++++++++++++++++- .../quantization/test_quantize_static.py | 54 +++++++- 5 files changed, 222 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 3f5e4e6600..10492ae419 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -128,6 +128,9 @@ class TensorsData: def values(self): return self.data.values() + def items(self): + return self.data.items() + class CalibrationMethod(Enum): MinMax = 0 @@ -155,6 +158,12 @@ class CalibrationDataReader(metaclass=abc.ABCMeta): raise StopIteration return result + def __len__(self): + raise NotImplementedError + + def set_range(self, start_index: int, end_index: int): + raise NotImplementedError + class CalibraterBase: def __init__( @@ -409,13 +418,31 @@ class MinMaxCalibrater(CalibraterBase): return new_range for key, value in old_range.items(): - if self.moving_average: - min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0]) - max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1]) + # Handling for structured data types with TensorData + if isinstance(value, TensorData): + old_min = value.range_value[0] + old_max = value.range_value[1] else: - min_value = min(value[0], new_range[key][0]) - max_value = max(value[1], new_range[key][1]) - new_range[key] = (min_value, max_value) + old_min, old_max = value + + if isinstance(new_range[key], TensorData): + new_min = new_range[key].range_value[0] + new_max = new_range[key].range_value[1] + else: + new_min, new_max = new_range[key] + + if self.moving_average: + min_value = old_min + self.averaging_constant * (new_min - old_min) + max_value = old_max + self.averaging_constant * (new_max - old_max) + else: + min_value = min(old_min, new_min) + max_value = max(old_max, new_max) + + # If structured as TensorData, wrap the result accordingly + if isinstance(value, TensorData) or isinstance(new_range[key], TensorData): + new_range[key] = TensorData(lowest=min_value, highest=max_value) + else: + new_range[key] = (min_value, max_value) return new_range diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 1ad56dc3ac..eac5b3b786 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -52,6 +52,7 @@ def get_qnn_qdq_config( activation_symmetric: bool = False, weight_symmetric: bool | None = None, keep_removable_activations: bool = False, + stride: int | None = None, ) -> StaticQuantConfig: """ Returns a static quantization configuration suitable for running QDQ models on QNN EP. @@ -171,6 +172,7 @@ def get_qnn_qdq_config( "TensorQuantOverrides": overrides_helper.get_dict(), "ActivationSymmetric": activation_symmetric, "WeightSymmetric": weight_symmetric, + "CalibStridedMinMax": stride, } # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index f8b74a7ae4..2340c995d3 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -381,6 +381,9 @@ def quantize_static( CalibTensorRangeSymmetric = True/False : Default is False. If enabled, the final range of tensor during calibration will be explicitly set to symmetric to central point "0". + CalibStridedMinMax = Optional[int] : + Default is None. If set to an integer, during calculation of the min-max, only stride amount of + data will be used and then all results will be merged in the end. CalibMovingAverage = True/False : Default is False. If enabled, the moving average of the minimum and maximum values will be computed when the calibration method selected is MinMax. @@ -522,7 +525,19 @@ def quantize_static( use_external_data_format=use_external_data_format, extra_options=calib_extra_options, ) - calibrator.collect_data(calibration_data_reader) + + stride = extra_options.get("CalibStridedMinMax", None) + if stride: + total_data_size = len(calibration_data_reader) + if total_data_size % stride != 0: + raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") + + for start in range(0, total_data_size, stride): + end_index = start + stride + calibration_data_reader.set_range(start_index=start, end_index=end_index) + calibrator.collect_data(calibration_data_reader) + else: + calibrator.collect_data(calibration_data_reader) tensors_range = calibrator.compute_data() if not isinstance(tensors_range, TensorsData): raise TypeError( diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index b30282f2ab..cf7fc292ea 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -217,10 +217,13 @@ class TestDataFeeds(CalibrationDataReader): self.iter_next = iter(self.data_feeds) -def input_feeds_neg_one_zero_one(n, name2shape): +def input_feeds_neg_one_zero_one(n, name2shape, seed=None): """ randomize n feed according to shape, its values are from -1, 0, and 1 """ + if seed is not None: + np.random.seed(seed) + input_data_list = [] for _i in range(n): inputs = {} @@ -231,6 +234,120 @@ def input_feeds_neg_one_zero_one(n, name2shape): return dr +def input_feeds_neg_one_zero_one_list(n, name2shape, seed=None): + """ + randomize n feed according to shape, its values are from -1, 0, and 1 + """ + if seed is not None: + np.random.seed(seed) + + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + return input_data_list + + +class GenerateCalibrationData(CalibrationDataReader): + def __init__(self, data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last=None): + print("Generating calibration dataset from " + str(data_list)) + print("input nodes are ", input_nodes, "input shapes are ", input_shapes) + if inputs_conv_channel_last: + print(f"Inputs that will be converted to channel last: {inputs_conv_channel_last}") + + self.enum_data_dicts = [] + self.input_nodes = input_nodes + self.input_shapes = input_shapes + self.inputs_conv_channel_last = inputs_conv_channel_last + self.calibration_dataset = data_list + + def __len__(self): + return len(self.calibration_dataset) + + def get_next(self): + feed_dict = {} + inp = next(self.calibration_dataset, None) + if inp is not None: + for i in range(len(self.input_nodes)): + input_data = inp[i].reshape(self.input_shapes[i]) + if self.inputs_conv_channel_last is not None and self.input_nodes[i] in self.inputs_conv_channel_last: + input_data = np.moveaxis(input_data, 1, -1) + dict_item = {self.input_nodes[i]: input_data} + feed_dict.update(dict_item) + return feed_dict + else: + return None + + +class StridedDataReader(GenerateCalibrationData): + def __init__( + self, + data_list, + input_nodes, + input_shapes, + no_tensor_num, + in_dtypes, + inputs_conv_channel_last=None, + stride=1, + start_index=0, + end_index=None, + ): + super().__init__(data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last) + + self.stride = max(1, stride) # Ensure stride is at least 1 + self.start_index = start_index + self.end_index = ( + end_index if end_index is not None else len(self.calibration_dataset) + ) # Default to the end of the dataset + self.enum_data_dicts = iter([]) + + def get_next(self): + iter_data = next(self.enum_data_dicts, None) + if iter_data: + return iter_data + + self.enum_data_dicts = None + if self.start_index < self.end_index: + print(f"start index is {self.start_index}") + data = self.load_serial() + + self.start_index += self.stride + self.enum_data_dicts = iter(data) + + return next(self.enum_data_dicts, None) + else: + return None + + def load_serial(self): + batch_data = [] + end_loop = min(self.end_index, self.start_index + self.stride) + for i in range(self.start_index, end_loop): + print(f"debugging the load serial index {i}") + data_item = self.calibration_dataset[i] + processed_item = self.process_data_item(data_item) + batch_data.append(processed_item) + return batch_data + + def process_data_item(self, data_item): + feed_dict = {} + for _, node in enumerate(self.input_nodes): + # input_data = data_item[i].reshape(self.input_shapes[i]) + feed_dict[node] = data_item["input"] + return feed_dict + + def set_range(self, start_index, end_index=None): + self.start_index = start_index + self.end_index = end_index if end_index is not None else len(self.calibration_dataset) + self.enum_data_dicts = iter([]) + + def rewind(self): + """Rewind the data reader to the beginning of the dataset.""" + self.start_index = 0 + self.enum_data_dicts = iter([]) + + def check_op_type_order(testcase, model_to_check, ops): if isinstance(model_to_check, str): model = onnx.load(model_to_check) diff --git a/onnxruntime/test/python/quantization/test_quantize_static.py b/onnxruntime/test/python/quantization/test_quantize_static.py index 5ad5a49f00..01976ba633 100644 --- a/onnxruntime/test/python/quantization/test_quantize_static.py +++ b/onnxruntime/test/python/quantization/test_quantize_static.py @@ -13,8 +13,15 @@ from pathlib import Path import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import check_model_correctness, generate_random_initializer, input_feeds_neg_one_zero_one +from op_test_utils import ( + StridedDataReader, + check_model_correctness, + generate_random_initializer, + input_feeds_neg_one_zero_one, + input_feeds_neg_one_zero_one_list, +) +import onnxruntime as ort from onnxruntime.quantization import QuantType, StaticQuantConfig, quantize, quantize_static @@ -89,6 +96,51 @@ class TestStaticQuantization(unittest.TestCase): check_model_correctness(self, self._model_fp32_path, quant_model_path, data_reader.get_next()) data_reader.rewind() + def run_inference(self, model_path, input_data): + session = ort.InferenceSession(model_path) + input_name = session.get_inputs()[0].name + output_name = session.get_outputs()[0].name + result = session.run([output_name], {input_name: input_data}) + return result + + def test_stride_effect_on_data_collection(self): + # Define the stride and test quantize_static with different stride values + stride = 5 + input_shapes = [1, self._channel_size, 1, 3] + data_list = input_feeds_neg_one_zero_one_list(10, {"input": [1, self._channel_size, 1, 3]}, 123) + input_nodes = ["input"] + in_dtypes = [np.float32] # Example dtype, adjust as needed + + # strided calibration + quant_model_path_1 = str(Path(self._tmp_model_dir.name) / "quant.strided.onnx") + data_reader_1 = StridedDataReader( + data_list, input_nodes, input_shapes, no_tensor_num=0, in_dtypes=in_dtypes, stride=stride + ) + quant_config_1 = StaticQuantConfig(data_reader_1, extra_options={"CalibStridedMinMax": stride}) + quantize(str(self._model_fp32_path), str(quant_model_path_1), quant_config_1) + + # non-strided calibration + quant_model_path_2 = str(Path(self._tmp_model_dir.name) / "quant.non.strided.onnx") + data_reader_2 = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}, 123) + quant_config_2 = StaticQuantConfig(data_reader_2) + quantize(str(self._model_fp32_path), str(quant_model_path_2), quant_config_2) + + # Inference with both models and assert output closeness + np.random.seed(123) + input_data = np.random.choice([-1, 0, 1], size=[1, self._channel_size, 1, 3]).astype(np.float32) + + result_1 = self.run_inference(quant_model_path_1, input_data) + result_2 = self.run_inference(quant_model_path_2, input_data) + + # Assert that the outputs are close + np.testing.assert_allclose( + result_1, + result_2, + rtol=0.01, + atol=0.01, + err_msg="Outputs from strided and non-strided models are not close enough.", + ) + def test_static_quant_config(self): data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}) quant_config = StaticQuantConfig(data_reader)