mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[Quant tool] Improve performance of int4 weight quantization (#20935)
### Description - Uses our own quantization functions instead of the ONNX reference implementation of QuantizeLinear when quantizing weights to int4. - Uses a custom function that packs bytes into 4-bit elements. ### Motivation and Context Running the quantization tool to create QDQ models with int4 weights could take up to 7x longer. This PR uses our own quantization and byte packing utilities to improve performance. #### Measurements Model with ~5M parameters to quantize to int4. - Current implementation: **84.5s** - Only replace ONNX QuantizeLinear implementation: **50.3s** (1.68x speedup) - This PR (replace onnx Q impl, custom packing func): **13.5s** (6.26x speedup) --------- Signed-off-by: adrianlizarraga <adlizarraga@microsoft.com>
This commit is contained in:
parent
4cb23b020c
commit
df28c7d73b
3 changed files with 137 additions and 49 deletions
|
|
@ -25,6 +25,7 @@ from .quant_utils import (
|
|||
find_by_name,
|
||||
model_has_infer_metadata,
|
||||
normalize_axis,
|
||||
pack_bytes_to_4bit,
|
||||
quantize_data,
|
||||
quantize_nparray,
|
||||
save_and_reload_model_with_shape_infer,
|
||||
|
|
@ -340,13 +341,17 @@ class BaseQuantizer:
|
|||
f"\nraw={str(q_weight_initializer)[:200]}."
|
||||
)
|
||||
elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
|
||||
# TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed
|
||||
# within int32_data is fixed.
|
||||
# q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data)
|
||||
packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4)
|
||||
q_weight_initializer = onnx.helper.make_tensor(
|
||||
q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True
|
||||
)
|
||||
if q_weight_data.dtype not in (np.int8, np.uint8):
|
||||
raise RuntimeError(
|
||||
f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
|
||||
)
|
||||
|
||||
# We do not use onnx.helper.pack_float32_to_4bit() due to performance.
|
||||
# This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
|
||||
packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
|
||||
|
||||
# We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
|
||||
q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
|
||||
else:
|
||||
q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
|
||||
weight.dims
|
||||
|
|
@ -483,16 +488,18 @@ class BaseQuantizer:
|
|||
|
||||
if not keep_float_weight:
|
||||
if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
|
||||
# TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed
|
||||
# within int32_data is fixed.
|
||||
# q_weight_initializer = onnx.helper.make_tensor(
|
||||
# q_weight_name, weight_qType, weights_shape, quantized_weights
|
||||
# )
|
||||
packed_data = onnx.helper.pack_float32_to_4bit(
|
||||
quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4
|
||||
)
|
||||
if quantized_weights.dtype not in (np.int8, np.uint8):
|
||||
raise RuntimeError(
|
||||
f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
|
||||
)
|
||||
|
||||
# We do not use onnx.helper.pack_float32_to_4bit() due to performance.
|
||||
# This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
|
||||
packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))
|
||||
|
||||
# We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
|
||||
q_weight_initializer = onnx.helper.make_tensor(
|
||||
q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True
|
||||
q_weight_name, weight_qType, weights_shape, packed_data, raw=True
|
||||
)
|
||||
self.model.initializer_extend([q_weight_initializer])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -21,10 +21,18 @@ from onnx.reference import ReferenceEvaluator
|
|||
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
|
||||
|
||||
try:
|
||||
from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4
|
||||
from onnx.reference.custom_element_types import float8e4m3fn
|
||||
except ImportError:
|
||||
float8e4m3fn = None
|
||||
|
||||
# INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy
|
||||
# does not support sub-byte types.
|
||||
try:
|
||||
from onnx.reference.custom_element_types import int4, uint4
|
||||
except ImportError:
|
||||
int4 = None
|
||||
uint4 = None
|
||||
|
||||
|
||||
__producer__ = "onnx.quantize"
|
||||
__version__ = "0.1.0"
|
||||
|
|
@ -134,8 +142,8 @@ ONNX_TYPE_TO_NP_TYPE = {
|
|||
onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
|
||||
onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
|
||||
onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn,
|
||||
onnx_proto.TensorProto.INT4: int4,
|
||||
onnx_proto.TensorProto.UINT4: uint4,
|
||||
onnx_proto.TensorProto.INT4: int4, # base_dtype is np.int8
|
||||
onnx_proto.TensorProto.UINT4: uint4, # base_dtype is np.uint8
|
||||
}
|
||||
|
||||
ONNX_INT_TYPE_RANGE = {
|
||||
|
|
@ -212,36 +220,12 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
|
|||
)
|
||||
ref = ReferenceEvaluator(onnx_model)
|
||||
return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
|
||||
elif qType in (
|
||||
onnx_proto.TensorProto.INT4,
|
||||
onnx_proto.TensorProto.UINT4,
|
||||
):
|
||||
if arr.dtype == numpy.float32:
|
||||
onnx_type = TensorProto.FLOAT
|
||||
elif arr.dtype == numpy.float16:
|
||||
onnx_type = TensorProto.FLOAT16
|
||||
else:
|
||||
raise ValueError(f"Unexpected dtype {arr.dtype}.")
|
||||
onnx_model = make_model(
|
||||
make_graph(
|
||||
[
|
||||
make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
|
||||
],
|
||||
"qu",
|
||||
[
|
||||
make_tensor_value_info("X", onnx_type, None),
|
||||
make_tensor_value_info("scale", onnx_type, None),
|
||||
make_tensor_value_info("zero_point", qType, None),
|
||||
],
|
||||
[make_tensor_value_info("Y", qType, None)],
|
||||
)
|
||||
)
|
||||
# The reference ONNX implementation of QuantizeLinear<int4> returns "unpacked" int8 numpy values
|
||||
# because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this).
|
||||
# These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor().
|
||||
ref = ReferenceEvaluator(onnx_model)
|
||||
return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0])
|
||||
else:
|
||||
# Quantizes data for all integer types.
|
||||
#
|
||||
# For int4 types, the quantized data is returned as either np.int8 or np.uint8,
|
||||
# which matches the python reference ONNX implementation of QuantizeLinear.
|
||||
# This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
|
||||
dtype = ONNX_TYPE_TO_NP_TYPE[qType]
|
||||
(qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True)
|
||||
|
||||
|
|
@ -482,6 +466,36 @@ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
|
|||
return is_valid, axis_norm
|
||||
|
||||
|
||||
def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
|
||||
"""
|
||||
Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
|
||||
Assumes that the source values are already in the appropriate int4 range.
|
||||
:parameter src_8bit: The 8-bit element values to pack.
|
||||
:return A bytearray with every two 8-bit src elements packed into a single byte.
|
||||
"""
|
||||
num_elems = len(src_8bit)
|
||||
if num_elems == 0:
|
||||
return bytearray()
|
||||
|
||||
dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes
|
||||
dst = bytearray(dst_size)
|
||||
|
||||
src_i: int = 0
|
||||
dst_i: int = 0
|
||||
|
||||
# Pack two 8-bit elements into a single byte in each iteration.
|
||||
while src_i < num_elems - 1:
|
||||
dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
|
||||
dst_i += 1
|
||||
src_i += 2
|
||||
|
||||
if src_i < num_elems:
|
||||
# Odd number of elements.
|
||||
dst[dst_i] = src_8bit[src_i] & 0xF
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
class QuantizedInitializer:
|
||||
"""
|
||||
Represents a linearly quantized weight input from ONNX operators
|
||||
|
|
|
|||
|
|
@ -13,7 +13,13 @@ import numpy
|
|||
import onnx
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
|
||||
from onnxruntime.quantization.quant_utils import compute_scale_zp, load_model_with_shape_infer, model_has_infer_metadata
|
||||
from onnxruntime.quantization.quant_utils import (
|
||||
compute_scale_zp,
|
||||
load_model_with_shape_infer,
|
||||
model_has_infer_metadata,
|
||||
pack_bytes_to_4bit,
|
||||
quantize_data,
|
||||
)
|
||||
|
||||
|
||||
class TestQuantUtil(unittest.TestCase):
|
||||
|
|
@ -101,6 +107,67 @@ class TestQuantUtil(unittest.TestCase):
|
|||
model_reloaded = load_model_with_shape_infer(Path(model_file_path))
|
||||
self.assertTrue(model_has_infer_metadata(model_reloaded))
|
||||
|
||||
def test_pack_bytes_to_4bit(self):
|
||||
"""
|
||||
Tests the pack_bytes_to_4bit() utility.
|
||||
"""
|
||||
subtest_configs = [
|
||||
(-8, 6, True), # Odd num elems, signed
|
||||
(-8, 7, True), # Even num elems, signed
|
||||
(0, 14, False), # Odd num elems, unsigned
|
||||
(0, 15, False), # Even num elems, unsigned
|
||||
]
|
||||
for min_val, max_val, signed in subtest_configs:
|
||||
with self.subTest(min_val=min_val, max_val=max_val, signed=signed):
|
||||
src_float = numpy.arange(min_val, max_val + 1).astype(numpy.float32)
|
||||
src_int = src_float.astype(numpy.int8 if signed else numpy.uint8)
|
||||
|
||||
actual_packed_vals = bytes(pack_bytes_to_4bit(src_int.tobytes()))
|
||||
expected_packed_vals = onnx.helper.pack_float32_to_4bit(src_float, signed).tobytes()
|
||||
self.assertEqual(actual_packed_vals, expected_packed_vals)
|
||||
|
||||
def test_quantize_data_4bit(self):
|
||||
"""
|
||||
Test that calling quantize_data for int4 quantization returns data of the correct type and range.
|
||||
"""
|
||||
data_float = numpy.arange(-20, 17).astype(numpy.float32)
|
||||
|
||||
subtest_configs = [
|
||||
(onnx.TensorProto.INT4, True), # int4, symmetric quant
|
||||
(onnx.TensorProto.INT4, False), # int4, symmetric quant
|
||||
(onnx.TensorProto.UINT4, True), # uint4, symmetric quant
|
||||
(onnx.TensorProto.UINT4, False), # uint4, symmetric quant
|
||||
]
|
||||
|
||||
for onnx_type, symmetric in subtest_configs:
|
||||
with self.subTest(onnx_type=onnx_type, symmetric=symmetric):
|
||||
_, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric)
|
||||
is_signed = onnx_type == onnx.TensorProto.INT4
|
||||
np_int_type = numpy.int8 if is_signed else numpy.uint8
|
||||
qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type)
|
||||
qmax = numpy.array(7 if is_signed else 15, dtype=np_int_type)
|
||||
|
||||
self.assertEqual(zero_point.dtype, np_int_type)
|
||||
self.assertEqual(scale.dtype, data_float.dtype)
|
||||
|
||||
expected_zp, expected_scale = compute_scale_zp(
|
||||
data_float.min(), data_float.max(), qmin, qmax, symmetric=symmetric
|
||||
)
|
||||
self.assertEqual(zero_point, expected_zp)
|
||||
self.assertEqual(scale, expected_scale)
|
||||
|
||||
# Even int4 quantization generates 8-bit numpy values.
|
||||
self.assertEqual(data_quant.dtype, np_int_type)
|
||||
for index, actual_quant_val in enumerate(data_quant.flatten()):
|
||||
self.assertTrue(actual_quant_val >= qmin and actual_quant_val <= qmax)
|
||||
|
||||
expected_quant_val = numpy.asarray((data_float[index] / scale).round() + zero_point).astype(
|
||||
np_int_type
|
||||
)
|
||||
numpy.clip(expected_quant_val, qmin, qmax, out=expected_quant_val)
|
||||
|
||||
self.assertEqual(numpy.array(actual_quant_val), expected_quant_val)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue