From df28c7d73b72440f115ccf80f3840ea0ca5bb3a9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 5 Jun 2024 16:48:40 -0700 Subject: [PATCH] [Quant tool] Improve performance of int4 weight quantization (#20935) ### Description - Uses our own quantization functions instead of the ONNX reference implementation of QuantizeLinear when quantizing weights to int4. - Uses a custom function that packs bytes into 4-bit elements. ### Motivation and Context Running the quantization tool to create QDQ models with int4 weights could take up to 7x longer. This PR uses our own quantization and byte packing utilities to improve performance. #### Measurements Model with ~5M parameters to quantize to int4. - Current implementation: **84.5s** - Only replace ONNX QuantizeLinear implementation: **50.3s** (1.68x speedup) - This PR (replace onnx Q impl, custom packing func): **13.5s** (6.26x speedup) --------- Signed-off-by: adrianlizarraga --- .../tools/quantization/base_quantizer.py | 39 ++++++---- .../python/tools/quantization/quant_utils.py | 78 +++++++++++-------- .../python/quantization/test_quant_util.py | 69 +++++++++++++++- 3 files changed, 137 insertions(+), 49 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 74e213fa61..06d2ce30b9 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -25,6 +25,7 @@ from .quant_utils import ( find_by_name, model_has_infer_metadata, normalize_axis, + pack_bytes_to_4bit, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, @@ -340,13 +341,17 @@ class BaseQuantizer: f"\nraw={str(q_weight_initializer)[:200]}." ) elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) - packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4) - q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True - ) + if q_weight_data.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -483,16 +488,18 @@ class BaseQuantizer: if not keep_float_weight: if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor( - # q_weight_name, weight_qType, weights_shape, quantized_weights - # ) - packed_data = onnx.helper.pack_float32_to_4bit( - quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4 - ) + if quantized_weights.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True + q_weight_name, weight_qType, weights_shape, packed_data, raw=True ) self.model.initializer_extend([q_weight_initializer]) else: diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index bdf6d5a355..53d2eaeaba 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -21,10 +21,18 @@ from onnx.reference import ReferenceEvaluator from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions try: - from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4 + from onnx.reference.custom_element_types import float8e4m3fn except ImportError: float8e4m3fn = None +# INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy +# does not support sub-byte types. +try: + from onnx.reference.custom_element_types import int4, uint4 +except ImportError: + int4 = None + uint4 = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -134,8 +142,8 @@ ONNX_TYPE_TO_NP_TYPE = { onnx_proto.TensorProto.INT16: numpy.dtype("int16"), onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, - onnx_proto.TensorProto.INT4: int4, - onnx_proto.TensorProto.UINT4: uint4, + onnx_proto.TensorProto.INT4: int4, # base_dtype is np.int8 + onnx_proto.TensorProto.UINT4: uint4, # base_dtype is np.uint8 } ONNX_INT_TYPE_RANGE = { @@ -212,36 +220,12 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): ) ref = ReferenceEvaluator(onnx_model) return _check_type(ref.run(None, {"X": arr, "scale": scale})[0]) - elif qType in ( - onnx_proto.TensorProto.INT4, - onnx_proto.TensorProto.UINT4, - ): - if arr.dtype == numpy.float32: - onnx_type = TensorProto.FLOAT - elif arr.dtype == numpy.float16: - onnx_type = TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype {arr.dtype}.") - onnx_model = make_model( - make_graph( - [ - make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]), - ], - "qu", - [ - make_tensor_value_info("X", onnx_type, None), - make_tensor_value_info("scale", onnx_type, None), - make_tensor_value_info("zero_point", qType, None), - ], - [make_tensor_value_info("Y", qType, None)], - ) - ) - # The reference ONNX implementation of QuantizeLinear returns "unpacked" int8 numpy values - # because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this). - # These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor(). - ref = ReferenceEvaluator(onnx_model) - return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0]) else: + # Quantizes data for all integer types. + # + # For int4 types, the quantized data is returned as either np.int8 or np.uint8, + # which matches the python reference ONNX implementation of QuantizeLinear. + # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). dtype = ONNX_TYPE_TO_NP_TYPE[qType] (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) @@ -482,6 +466,36 @@ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]: return is_valid, axis_norm +def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray: + """ + Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values. + Assumes that the source values are already in the appropriate int4 range. + :parameter src_8bit: The 8-bit element values to pack. + :return A bytearray with every two 8-bit src elements packed into a single byte. + """ + num_elems = len(src_8bit) + if num_elems == 0: + return bytearray() + + dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes + dst = bytearray(dst_size) + + src_i: int = 0 + dst_i: int = 0 + + # Pack two 8-bit elements into a single byte in each iteration. + while src_i < num_elems - 1: + dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF) + dst_i += 1 + src_i += 2 + + if src_i < num_elems: + # Odd number of elements. + dst[dst_i] = src_8bit[src_i] & 0xF + + return dst + + class QuantizedInitializer: """ Represents a linearly quantized weight input from ONNX operators diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 848857ceb2..7b3fc08982 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -13,7 +13,13 @@ import numpy import onnx from onnx import TensorProto, helper, numpy_helper -from onnxruntime.quantization.quant_utils import compute_scale_zp, load_model_with_shape_infer, model_has_infer_metadata +from onnxruntime.quantization.quant_utils import ( + compute_scale_zp, + load_model_with_shape_infer, + model_has_infer_metadata, + pack_bytes_to_4bit, + quantize_data, +) class TestQuantUtil(unittest.TestCase): @@ -101,6 +107,67 @@ class TestQuantUtil(unittest.TestCase): model_reloaded = load_model_with_shape_infer(Path(model_file_path)) self.assertTrue(model_has_infer_metadata(model_reloaded)) + def test_pack_bytes_to_4bit(self): + """ + Tests the pack_bytes_to_4bit() utility. + """ + subtest_configs = [ + (-8, 6, True), # Odd num elems, signed + (-8, 7, True), # Even num elems, signed + (0, 14, False), # Odd num elems, unsigned + (0, 15, False), # Even num elems, unsigned + ] + for min_val, max_val, signed in subtest_configs: + with self.subTest(min_val=min_val, max_val=max_val, signed=signed): + src_float = numpy.arange(min_val, max_val + 1).astype(numpy.float32) + src_int = src_float.astype(numpy.int8 if signed else numpy.uint8) + + actual_packed_vals = bytes(pack_bytes_to_4bit(src_int.tobytes())) + expected_packed_vals = onnx.helper.pack_float32_to_4bit(src_float, signed).tobytes() + self.assertEqual(actual_packed_vals, expected_packed_vals) + + def test_quantize_data_4bit(self): + """ + Test that calling quantize_data for int4 quantization returns data of the correct type and range. + """ + data_float = numpy.arange(-20, 17).astype(numpy.float32) + + subtest_configs = [ + (onnx.TensorProto.INT4, True), # int4, symmetric quant + (onnx.TensorProto.INT4, False), # int4, symmetric quant + (onnx.TensorProto.UINT4, True), # uint4, symmetric quant + (onnx.TensorProto.UINT4, False), # uint4, symmetric quant + ] + + for onnx_type, symmetric in subtest_configs: + with self.subTest(onnx_type=onnx_type, symmetric=symmetric): + _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + is_signed = onnx_type == onnx.TensorProto.INT4 + np_int_type = numpy.int8 if is_signed else numpy.uint8 + qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) + qmax = numpy.array(7 if is_signed else 15, dtype=np_int_type) + + self.assertEqual(zero_point.dtype, np_int_type) + self.assertEqual(scale.dtype, data_float.dtype) + + expected_zp, expected_scale = compute_scale_zp( + data_float.min(), data_float.max(), qmin, qmax, symmetric=symmetric + ) + self.assertEqual(zero_point, expected_zp) + self.assertEqual(scale, expected_scale) + + # Even int4 quantization generates 8-bit numpy values. + self.assertEqual(data_quant.dtype, np_int_type) + for index, actual_quant_val in enumerate(data_quant.flatten()): + self.assertTrue(actual_quant_val >= qmin and actual_quant_val <= qmax) + + expected_quant_val = numpy.asarray((data_float[index] / scale).round() + zero_point).astype( + np_int_type + ) + numpy.clip(expected_quant_val, qmin, qmax, out=expected_quant_val) + + self.assertEqual(numpy.array(actual_quant_val), expected_quant_val) + if __name__ == "__main__": unittest.main()