diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index 3f5e4e6600..10492ae419 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -128,6 +128,10 @@ class TensorsData:
     def values(self):
         return self.data.values()
 
+    def items(self):
+        """Return the (key, value) pairs of the wrapped ``data`` mapping."""
+        return self.data.items()
+
 
 class CalibrationMethod(Enum):
     MinMax = 0
@@ -155,6 +159,22 @@ class CalibrationDataReader(metaclass=abc.ABCMeta):
             raise StopIteration
         return result
 
+    def __len__(self):
+        """Return the number of calibration samples; quantize_static needs it for strided calibration."""
+        raise NotImplementedError(f"{type(self).__name__} does not implement __len__().")
+
+    def __bool__(self):
+        # Defining __len__ makes truth-testing call it. Readers that do not override
+        # __len__ must stay truthy (as before __len__ existed) instead of raising.
+        try:
+            return len(self) > 0
+        except NotImplementedError:
+            return True
+
+    def set_range(self, start_index: int, end_index: int):
+        """Restrict the samples served by get_next() to indices in [start_index, end_index)."""
+        raise NotImplementedError(f"{type(self).__name__} does not implement set_range().")
+
 
 class CalibraterBase:
     def __init__(
@@ -409,13 +429,33 @@ class MinMaxCalibrater(CalibraterBase):
             return new_range
 
         for key, value in old_range.items():
-            if self.moving_average:
-                min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0])
-                max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1])
-            else:
-                min_value = min(value[0], new_range[key][0])
-                max_value = max(value[1], new_range[key][1])
-            new_range[key] = (min_value, max_value)
+            # A range is either a plain (min, max) pair or a TensorData exposing range_value.
+            new_value = new_range[key]
+            old_is_tensor_data = isinstance(value, TensorData)
+            new_is_tensor_data = isinstance(new_value, TensorData)
+
+            if old_is_tensor_data:
+                old_min, old_max = value.range_value[0], value.range_value[1]
+            else:
+                old_min, old_max = value
+
+            if new_is_tensor_data:
+                new_min, new_max = new_value.range_value[0], new_value.range_value[1]
+            else:
+                new_min, new_max = new_value
+
+            if self.moving_average:
+                min_value = old_min + self.averaging_constant * (new_min - old_min)
+                max_value = old_max + self.averaging_constant * (new_max - old_max)
+            else:
+                min_value = min(old_min, new_min)
+                max_value = max(old_max, new_max)
+
+            # Keep the TensorData wrapper when either side used it.
+            if old_is_tensor_data or new_is_tensor_data:
+                new_range[key] = TensorData(lowest=min_value, highest=max_value)
+            else:
+                new_range[key] = (min_value, max_value)
 
         return new_range
 
diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
index 1ad56dc3ac..eac5b3b786 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
@@ -52,6 +52,7 @@ def get_qnn_qdq_config(
     activation_symmetric: bool = False,
     weight_symmetric: bool | None = None,
     keep_removable_activations: bool = False,
+    stride: int | None = None,
 ) -> StaticQuantConfig:
     """
     Returns a static quantization configuration suitable for running QDQ models on QNN EP.
@@ -171,6 +172,7 @@ def get_qnn_qdq_config(
         "TensorQuantOverrides": overrides_helper.get_dict(),
         "ActivationSymmetric": activation_symmetric,
         "WeightSymmetric": weight_symmetric,
+        "CalibStridedMinMax": stride,
     }
 
     # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index f8b74a7ae4..2340c995d3 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -381,6 +381,9 @@ def quantize_static(
                 CalibTensorRangeSymmetric = True/False :
                     Default is False. If enabled, the final range of tensor during calibration will be explicitly
                     set to symmetric to central point "0".
+                CalibStridedMinMax = Optional[int] :
+                    Default is None. If set to N, calibration data is collected N samples at a time and the results
+                    are merged. The data reader must implement __len__() and set_range(); len must be divisible by N.
                 CalibMovingAverage = True/False :
                     Default is False. If enabled, the moving average of the minimum and maximum values will be
                     computed when the calibration method selected is MinMax.
@@ -522,7 +525,22 @@ def quantize_static(
             use_external_data_format=use_external_data_format,
             extra_options=calib_extra_options,
         )
-        calibrator.collect_data(calibration_data_reader)
+
+        stride = extra_options.get("CalibStridedMinMax", None)
+        if stride:
+            # A negative stride would pass the divisibility check and then collect nothing.
+            if not isinstance(stride, int) or stride <= 0:
+                raise ValueError(f"CalibStridedMinMax must be a positive integer, got {stride!r}.")
+
+            total_data_size = len(calibration_data_reader)
+            if total_data_size % stride != 0:
+                raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).")
+
+            # Calibrate one [start, start + stride) window of the reader at a time.
+            for start in range(0, total_data_size, stride):
+                calibration_data_reader.set_range(start_index=start, end_index=start + stride)
+                calibrator.collect_data(calibration_data_reader)
+        else:
+            calibrator.collect_data(calibration_data_reader)
         tensors_range = calibrator.compute_data()
         if not isinstance(tensors_range, TensorsData):
             raise TypeError(
diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py
index b30282f2ab..cf7fc292ea 100644
--- a/onnxruntime/test/python/quantization/op_test_utils.py
+++ b/onnxruntime/test/python/quantization/op_test_utils.py
@@ -217,10 +217,13 @@ class TestDataFeeds(CalibrationDataReader):
         self.iter_next = iter(self.data_feeds)
 
 
-def input_feeds_neg_one_zero_one(n, name2shape):
+def input_feeds_neg_one_zero_one(n, name2shape, seed=None):
     """
     randomize n feed according to shape, its values are from -1, 0, and 1
     """
+    if seed is not None:
+        np.random.seed(seed)
+
     input_data_list = []
     for _i in range(n):
         inputs = {}
@@ -231,6 +234,115 @@ def input_feeds_neg_one_zero_one(n, name2shape):
     return dr
 
 
+def input_feeds_neg_one_zero_one_list(n, name2shape, seed=None):
+    """
+    randomize n feed according to shape, its values are from -1, 0, and 1;
+    the feeds are returned as a plain list of {input name: array} dicts
+    """
+    if seed is not None:
+        np.random.seed(seed)
+
+    return [
+        {name: np.random.randint(-1, 2, shape).astype(np.float32) for name, shape in name2shape.items()}
+        for _i in range(n)
+    ]
+
+
+class GenerateCalibrationData(CalibrationDataReader):
+    """Calibration reader that hands out an in-memory dataset one sample at a time.
+
+    get_next() reads sample[i] for the i-th name in input_nodes and reshapes it to input_shapes[i].
+    no_tensor_num and in_dtypes are accepted but not used by this class.
+    """
+
+    def __init__(self, data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last=None):
+        print("Generating calibration dataset from " + str(data_list))
+        print("input nodes are ", input_nodes, "input shapes are ", input_shapes)
+        if inputs_conv_channel_last:
+            print(f"Inputs that will be converted to channel last: {inputs_conv_channel_last}")
+
+        self.enum_data_dicts = []
+        self.input_nodes = input_nodes
+        self.input_shapes = input_shapes
+        self.inputs_conv_channel_last = inputs_conv_channel_last
+        self.calibration_dataset = data_list
+        # get_next() needs an iterator: calling next() on a list raises TypeError.
+        self._sample_iter = iter(data_list)
+
+    def __len__(self):
+        return len(self.calibration_dataset)
+
+    def get_next(self):
+        inp = next(self._sample_iter, None)
+        if inp is None:
+            return None
+
+        feed_dict = {}
+        for i, node in enumerate(self.input_nodes):
+            input_data = inp[i].reshape(self.input_shapes[i])
+            if self.inputs_conv_channel_last is not None and node in self.inputs_conv_channel_last:
+                input_data = np.moveaxis(input_data, 1, -1)
+            feed_dict[node] = input_data
+        return feed_dict
+
+
+class StridedDataReader(GenerateCalibrationData):
+    """Reader serving samples with indices in [start_index, end_index), loading `stride` samples at a time."""
+
+    def __init__(
+        self,
+        data_list,
+        input_nodes,
+        input_shapes,
+        no_tensor_num,
+        in_dtypes,
+        inputs_conv_channel_last=None,
+        stride=1,
+        start_index=0,
+        end_index=None,
+    ):
+        super().__init__(data_list, input_nodes, input_shapes, no_tensor_num, in_dtypes, inputs_conv_channel_last)
+
+        self.stride = max(1, stride)  # Ensure stride is at least 1
+        self.start_index = start_index
+        # Default to the end of the dataset
+        self.end_index = end_index if end_index is not None else len(self.calibration_dataset)
+        self.enum_data_dicts = iter([])
+
+    def get_next(self):
+        iter_data = next(self.enum_data_dicts, None)
+        if iter_data:
+            return iter_data
+
+        if self.start_index >= self.end_index:
+            # Range exhausted: keep a valid (empty) iterator so repeated calls keep returning None.
+            self.enum_data_dicts = iter([])
+            return None
+
+        data = self.load_serial()
+        self.start_index += self.stride
+        self.enum_data_dicts = iter(data)
+        return next(self.enum_data_dicts, None)
+
+    def load_serial(self):
+        end_loop = min(self.end_index, self.start_index + self.stride)
+        return [self.process_data_item(self.calibration_dataset[i]) for i in range(self.start_index, end_loop)]
+
+    def process_data_item(self, data_item):
+        # Samples are dicts keyed by input name; pick the entry of every graph input.
+        return {node: data_item[node] for node in self.input_nodes}
+
+    def set_range(self, start_index, end_index=None):
+        self.start_index = start_index
+        self.end_index = end_index if end_index is not None else len(self.calibration_dataset)
+        self.enum_data_dicts = iter([])
+
+    def rewind(self):
+        """Rewind the data reader to the beginning of the dataset."""
+        # Also restores the full [0, len) range, which an earlier set_range() may have narrowed.
+        self.set_range(0)
+
+
 def check_op_type_order(testcase, model_to_check, ops):
     if isinstance(model_to_check, str):
         model = onnx.load(model_to_check)
diff --git a/onnxruntime/test/python/quantization/test_quantize_static.py b/onnxruntime/test/python/quantization/test_quantize_static.py
index 5ad5a49f00..01976ba633 100644
--- a/onnxruntime/test/python/quantization/test_quantize_static.py
+++ b/onnxruntime/test/python/quantization/test_quantize_static.py
@@ -13,8 +13,15 @@ from pathlib import Path
 import numpy as np
 import onnx
 from onnx import TensorProto, helper
-from op_test_utils import check_model_correctness, generate_random_initializer, input_feeds_neg_one_zero_one
+from op_test_utils import (
+    StridedDataReader,
+    check_model_correctness,
+    generate_random_initializer,
+    input_feeds_neg_one_zero_one,
+    input_feeds_neg_one_zero_one_list,
+)
 
+import onnxruntime as ort
 from onnxruntime.quantization import QuantType, StaticQuantConfig, quantize, quantize_static
 
 
@@ -89,6 +96,51 @@ class TestStaticQuantization(unittest.TestCase):
         check_model_correctness(self, self._model_fp32_path, quant_model_path, data_reader.get_next())
         data_reader.rewind()
 
+    def run_inference(self, model_path, input_data):
+        session = ort.InferenceSession(model_path)
+        input_name = session.get_inputs()[0].name
+        output_name = session.get_outputs()[0].name
+        result = session.run([output_name], {input_name: input_data})
+        return result
+
+    def test_stride_effect_on_data_collection(self):
+        # Calibrate once in strides of 5 samples and once in a single pass; both models should agree
+        stride = 5
+        input_shapes = [1, self._channel_size, 1, 3]
+        data_list = input_feeds_neg_one_zero_one_list(10, {"input": [1, self._channel_size, 1, 3]}, 123)
+        input_nodes = ["input"]
+        in_dtypes = [np.float32]  # Example dtype, adjust as needed
+
+        # strided calibration
+        quant_model_path_1 = str(Path(self._tmp_model_dir.name) / "quant.strided.onnx")
+        data_reader_1 = StridedDataReader(
+            data_list, input_nodes, input_shapes, no_tensor_num=0, in_dtypes=in_dtypes, stride=stride
+        )
+        quant_config_1 = StaticQuantConfig(data_reader_1, extra_options={"CalibStridedMinMax": stride})
+        quantize(str(self._model_fp32_path), str(quant_model_path_1), quant_config_1)
+
+        # non-strided calibration
+        quant_model_path_2 = str(Path(self._tmp_model_dir.name) / "quant.non.strided.onnx")
+        data_reader_2 = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]}, 123)
+        quant_config_2 = StaticQuantConfig(data_reader_2)
+        quantize(str(self._model_fp32_path), str(quant_model_path_2), quant_config_2)
+
+        # Inference with both models and assert output closeness
+        np.random.seed(123)
+        input_data = np.random.choice([-1, 0, 1], size=[1, self._channel_size, 1, 3]).astype(np.float32)
+
+        result_1 = self.run_inference(quant_model_path_1, input_data)
+        result_2 = self.run_inference(quant_model_path_2, input_data)
+
+        # Assert that the outputs are close
+        np.testing.assert_allclose(
+            result_1,
+            result_2,
+            rtol=0.01,
+            atol=0.01,
+            err_msg="Outputs from strided and non-strided models are not close enough.",
+        )
+
     def test_static_quant_config(self):
         data_reader = input_feeds_neg_one_zero_one(10, {"input": [1, self._channel_size, 1, 3]})
         quant_config = StaticQuantConfig(data_reader)