diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 74e213fa61..06d2ce30b9 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -25,6 +25,7 @@ from .quant_utils import ( find_by_name, model_has_infer_metadata, normalize_axis, + pack_bytes_to_4bit, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, @@ -340,13 +341,17 @@ class BaseQuantizer: f"\nraw={str(q_weight_initializer)[:200]}." ) elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) - packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4) - q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True - ) + if q_weight_data.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. 
+ packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -483,16 +488,18 @@ class BaseQuantizer: if not keep_float_weight: if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor( - # q_weight_name, weight_qType, weights_shape, quantized_weights - # ) - packed_data = onnx.helper.pack_float32_to_4bit( - quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4 - ) + if quantized_weights.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. 
+ packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True + q_weight_name, weight_qType, weights_shape, packed_data, raw=True ) self.model.initializer_extend([q_weight_initializer]) else: diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index bdf6d5a355..53d2eaeaba 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -21,10 +21,18 @@ from onnx.reference import ReferenceEvaluator from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions try: - from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4 + from onnx.reference.custom_element_types import float8e4m3fn except ImportError: float8e4m3fn = None +# INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy +# does not support sub-byte types. 
+try: + from onnx.reference.custom_element_types import int4, uint4 +except ImportError: + int4 = None + uint4 = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -134,8 +142,8 @@ ONNX_TYPE_TO_NP_TYPE = { onnx_proto.TensorProto.INT16: numpy.dtype("int16"), onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, - onnx_proto.TensorProto.INT4: int4, - onnx_proto.TensorProto.UINT4: uint4, + onnx_proto.TensorProto.INT4: int4, # base_dtype is np.int8 + onnx_proto.TensorProto.UINT4: uint4, # base_dtype is np.uint8 } ONNX_INT_TYPE_RANGE = { @@ -212,36 +220,12 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): ) ref = ReferenceEvaluator(onnx_model) return _check_type(ref.run(None, {"X": arr, "scale": scale})[0]) - elif qType in ( - onnx_proto.TensorProto.INT4, - onnx_proto.TensorProto.UINT4, - ): - if arr.dtype == numpy.float32: - onnx_type = TensorProto.FLOAT - elif arr.dtype == numpy.float16: - onnx_type = TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype {arr.dtype}.") - onnx_model = make_model( - make_graph( - [ - make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]), - ], - "qu", - [ - make_tensor_value_info("X", onnx_type, None), - make_tensor_value_info("scale", onnx_type, None), - make_tensor_value_info("zero_point", qType, None), - ], - [make_tensor_value_info("Y", qType, None)], - ) - ) - # The reference ONNX implementation of QuantizeLinear returns "unpacked" int8 numpy values - # because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this). - # These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor(). - ref = ReferenceEvaluator(onnx_model) - return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0]) else: + # Quantizes data for all integer types. 
+ # + # For int4 types, the quantized data is returned as either np.int8 or np.uint8, + # which matches the Python reference ONNX implementation of QuantizeLinear. + # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). dtype = ONNX_TYPE_TO_NP_TYPE[qType] (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) @@ -482,6 +466,36 @@ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]: return is_valid, axis_norm +def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray: + """ + Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values. + Assumes that the source values are already in the appropriate int4 range. + :param src_8bit: The 8-bit element values to pack. + :return: A bytearray with every two 8-bit src elements packed into a single byte. + """ + num_elems = len(src_8bit) + if num_elems == 0: + return bytearray() + + dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes + dst = bytearray(dst_size) + + src_i: int = 0 + dst_i: int = 0 + + # Pack two 8-bit elements into a single byte in each iteration. + while src_i < num_elems - 1: + dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF) + dst_i += 1 + src_i += 2 + + if src_i < num_elems: + # Odd number of elements. 
+ dst[dst_i] = src_8bit[src_i] & 0xF + + return dst + + class QuantizedInitializer: """ Represents a linearly quantized weight input from ONNX operators diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 848857ceb2..7b3fc08982 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -13,7 +13,13 @@ import numpy import onnx from onnx import TensorProto, helper, numpy_helper -from onnxruntime.quantization.quant_utils import compute_scale_zp, load_model_with_shape_infer, model_has_infer_metadata +from onnxruntime.quantization.quant_utils import ( + compute_scale_zp, + load_model_with_shape_infer, + model_has_infer_metadata, + pack_bytes_to_4bit, + quantize_data, +) class TestQuantUtil(unittest.TestCase): @@ -101,6 +107,67 @@ class TestQuantUtil(unittest.TestCase): model_reloaded = load_model_with_shape_infer(Path(model_file_path)) self.assertTrue(model_has_infer_metadata(model_reloaded)) + def test_pack_bytes_to_4bit(self): + """ + Tests the pack_bytes_to_4bit() utility. + """ + subtest_configs = [ + (-8, 6, True), # Odd num elems, signed + (-8, 7, True), # Even num elems, signed + (0, 14, False), # Odd num elems, unsigned + (0, 15, False), # Even num elems, unsigned + ] + for min_val, max_val, signed in subtest_configs: + with self.subTest(min_val=min_val, max_val=max_val, signed=signed): + src_float = numpy.arange(min_val, max_val + 1).astype(numpy.float32) + src_int = src_float.astype(numpy.int8 if signed else numpy.uint8) + + actual_packed_vals = bytes(pack_bytes_to_4bit(src_int.tobytes())) + expected_packed_vals = onnx.helper.pack_float32_to_4bit(src_float, signed).tobytes() + self.assertEqual(actual_packed_vals, expected_packed_vals) + + def test_quantize_data_4bit(self): + """ + Test that calling quantize_data for int4 quantization returns data of the correct type and range. 
+ """ + data_float = numpy.arange(-20, 17).astype(numpy.float32) + + subtest_configs = [ + (onnx.TensorProto.INT4, True), # int4, symmetric quant + (onnx.TensorProto.INT4, False), # int4, asymmetric quant + (onnx.TensorProto.UINT4, True), # uint4, symmetric quant + (onnx.TensorProto.UINT4, False), # uint4, asymmetric quant + ] + + for onnx_type, symmetric in subtest_configs: + with self.subTest(onnx_type=onnx_type, symmetric=symmetric): + _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + is_signed = onnx_type == onnx.TensorProto.INT4 + np_int_type = numpy.int8 if is_signed else numpy.uint8 + qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) + qmax = numpy.array(7 if is_signed else 15, dtype=np_int_type) + + self.assertEqual(zero_point.dtype, np_int_type) + self.assertEqual(scale.dtype, data_float.dtype) + + expected_zp, expected_scale = compute_scale_zp( + data_float.min(), data_float.max(), qmin, qmax, symmetric=symmetric + ) + self.assertEqual(zero_point, expected_zp) + self.assertEqual(scale, expected_scale) + + # Even int4 quantization generates 8-bit numpy values. + self.assertEqual(data_quant.dtype, np_int_type) + for index, actual_quant_val in enumerate(data_quant.flatten()): + self.assertTrue(actual_quant_val >= qmin and actual_quant_val <= qmax) + + expected_quant_val = numpy.asarray((data_float[index] / scale).round() + zero_point).astype( + np_int_type + ) + numpy.clip(expected_quant_val, qmin, qmax, out=expected_quant_val) + + self.assertEqual(numpy.array(actual_quant_val), expected_quant_val) + if __name__ == "__main__": unittest.main()