diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 11a830dc6d..40a4a4d26d 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -18,31 +18,36 @@ import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version -from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel -from .quant_utils import attribute_to_kwarg +from .quant_utils import QuantFormat, attribute_to_kwarg logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) class WeightOnlyQuantConfig: - def __init__(self, algorithm): + def __init__(self, algorithm, quant_format): """This is the Base class for Weight Only Quant Configuration. Args: algorithm: weight only quantize algorithm name. + quant_format: QuantFormat{QOperator, QDQ}. + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. """ self.algorithm = algorithm + self.quant_format = quant_format class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, ratios=None, + quant_format=QuantFormat.QOperator, ): """ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. @@ -51,11 +56,18 @@ class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): Args: ratios: percentile of clip. Defaults to {}. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. 
+ QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format" + if ratios is None: ratios = {} super().__init__( algorithm="RTN", + quant_format=quant_format, ) self.ratios = ratios @@ -69,6 +81,7 @@ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): actorder=False, mse=False, perchannel=True, + quant_format=QuantFormat.QOperator, ): """ This is a class for GPTQ algorithm Weight Only Quant Configuration. @@ -87,9 +100,16 @@ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): whether get scale and zero point with mse error. perchannel (bool, optional): whether quantize weight per-channel. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format" + super().__init__( algorithm="GPTQ", + quant_format=quant_format, ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp @@ -105,6 +125,7 @@ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): block_size=128, bits=4, axis=1, + quant_format=QuantFormat.QOperator, ): """ This is a class for HQQ algorithm Weight Only Quant Configuration. @@ -112,14 +133,21 @@ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): Args: block_size (int, optional): - channel number in one block to execute a GPTQ quantization iteration. + channel number in one block to execute a HQQ quantization iteration. bits (int, optional): how many bits to represent weight. axis (int, optional): 0 or 1. which axis to quantize. 
https://arxiv.org/pdf/2309.15531.pdf + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. """ + assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format" + super().__init__( algorithm="HQQ", + quant_format=quant_format, ) self.block_size = block_size self.bits = bits @@ -132,8 +160,26 @@ class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, + quant_format=QuantFormat.QOperator, ): - super().__init__(algorithm="DEFAULT") + """ + This is a class for weight only affine quantization configuration. + + Args: + block_size (int, optional): + channel number in one block to execute an affine quantization iteration. + is_symmetric (bool, optional): + whether quantize weight symmetrically. + accuracy_level (int, optional): + Accuracy level of the 4-bit quantized MatMul computation. + Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details. + (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits) + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. 
+ """ + super().__init__(algorithm="DEFAULT", quant_format=quant_format) self.block_size = block_size self.is_symmetric = is_symmetric self.bits = 4 @@ -287,23 +333,26 @@ class HQQWeightOnlyQuantizer: return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now import torch logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - b_pb, bs_graph = get_initializer(inputB, graph_stack) + input_b = node.input[1] + b_pb, bs_graph = get_initializer(input_b, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. 
Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() @@ -334,7 +383,7 @@ class HQQWeightOnlyQuantizer: b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) b_quant.name = b_pb.name + "_Q4" for input in bs_graph.input: - if input.name == inputB: + if input.name == input_b: bs_graph.input.remove(input) break @@ -366,7 +415,7 @@ class HQQWeightOnlyQuantizer: logger.info(f"complete quantization of {node.name} ...") - return matmul_q4_node + return [matmul_q4_node] def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -382,7 +431,7 @@ class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config - def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: + def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """4b quantize fp32 weight to a blob""" if len(fp32weight.shape) != 2: @@ -390,83 +439,136 @@ class DefaultWeightOnlyQuantizer: rows, cols = fp32weight.shape block_size = self.config.block_size - blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - # block wise quantization, each block comes from a single column - packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") - scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) - zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) + if self.config.quant_format == QuantFormat.QOperator: + blob_size = block_size // 2 + padded_rows = k_blocks * block_size + pad_len = padded_rows - 
rows + if pad_len > 0: + fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") + + # block wise quantization, each block comes from a single column + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) + quantize_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) + else: + packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") + zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype) + quantize_qdq_matmul_4bits( + packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric + ) return (packed, scales, zero_point) - def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: + """ + If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. + If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. + """ if node.op_type != "MatMul": - return node # only care about MatMul for now + return [node] # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - inputB = node.input[1] # noqa: N806 - B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: + qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4 + input_b = node.input[1] + b_tensor, b_graph = get_initializer(input_b, graph_stack) + if b_tensor is None: logger.info("MatMul doesn't have const weight. 
Skip to quantize") - return node # only care about constant weight + return [node] # only care about constant weight - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: + b_ndarray = onnx.numpy_helper.to_array(b_tensor) + if len(b_ndarray.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") - return node # can only process 2-D matrix + return [node] # can only process 2-D matrix - packed, scales, zero_points = self.int4_block_quant(B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) + packed, scales, zero_points = self.int4_block_quant(b_ndarray) + + if self.config.quant_format == QuantFormat.QOperator: + b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4") + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales") + else: + b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") + + for input in b_graph.input: + if input.name == input_b: + b_graph.input.remove(input) break - scales_tensor = onnx.numpy_helper.from_array(scales) - scales_tensor.name = B.name + "_scales" - Bs_graph.initializer.extend([B_quant, scales_tensor]) + b_graph.initializer.extend([b_quant, scales_tensor]) - input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.config.is_symmetric: - zp_tensor = onnx.numpy_helper.from_array(zero_points) - zp_tensor.name = B.name + "_zero_points" - Bs_graph.initializer.extend([zp_tensor]) - input_names.append(zp_tensor.name) + output_nodes = [] - kwargs = {} - rows, cols = B_array.shape - kwargs["K"] = rows - kwargs["N"] = cols - kwargs["bits"] = 4 - kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: - kwargs["accuracy_level"] = 
self.config.accuracy_level + if self.config.quant_format == QuantFormat.QOperator: + input_names = [node.input[0], b_quant.name, scales_tensor.name] + if not self.config.is_symmetric: + zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points") + input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + kwargs = {} + rows, cols = b_ndarray.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = 4 + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: + kwargs["accuracy_level"] = self.config.accuracy_level - matmul_q4_node = onnx.helper.make_node( - "MatMulNBits", - inputs=input_names, - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + output_nodes.append(matmul_q4_node) + else: + dq_input_names = [b_quant.name, scales_tensor.name] + dq_output_names = [b_quant.name + "_output"] + matmul_input_names = [node.input[0], dq_output_names[0]] + matmul_output_names = [node.output[0]] + if not self.config.is_symmetric: + zp_tensor = onnx.helper.make_tensor( + b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True + ) + dq_input_names.append(zp_tensor.name) + b_graph.initializer.extend([zp_tensor]) + dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + dq_node = onnx.helper.make_node( + "DequantizeLinear", + inputs=dq_input_names, + outputs=dq_output_names, + name=node.name + "_DQ_Q4" if node.name else "", + **dq_kwargs, + ) + matmul_node = onnx.helper.make_node( + "MatMul", + inputs=matmul_input_names, + outputs=matmul_output_names, + name=node.name + "_matmul_Q4" if node.name else "", + ) + output_nodes.extend([dq_node, matmul_node]) logger.info(f"complete quantization of 
{node.name} ...") - - return matmul_q4_node + return output_nodes class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" + """ + Perform 4b quantization of constant MatMul weights. + If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which replaces the + MatMul node. + If algo_config.quant_format is QDQ, the quantized weight is stored in a DequantizeLinear node. The MatMul node is + replaced by the DequantizeLinear + MatMul nodes. + """ def __init__( self, @@ -475,7 +577,8 @@ class MatMul4BitsQuantizer: is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, + quant_format=QuantFormat.QOperator, + algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] @@ -488,7 +591,10 @@ class MatMul4BitsQuantizer: self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + block_size=block_size, + is_symmetric=is_symmetric, + accuracy_level=accuracy_level, + quant_format=quant_format, ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": @@ -526,15 +632,15 @@ class MatMul4BitsQuantizer: node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - out_node = None + out_nodes = [] if node.name in self.nodes_to_exclude: logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - out_node = node + out_nodes = [node] elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": - out_node = self.node_quantizer.quantize(node, graph_stack) + out_nodes = self.node_quantizer.quantize(node, graph_stack) else: - out_node = self.node_quantizer.quantize(node, graph_stack) - new_nodes.append(out_node) + out_nodes = self.node_quantizer.quantize(node, graph_stack) + 
new_nodes.extend(out_nodes) graph.ClearField("node") graph.node.extend(new_nodes) @@ -688,6 +794,15 @@ set of 4b integers with a scaling factor and an optional offset. default=[], help="Specify the nodes to be excluded from quantization with node names", ) + parser.add_argument( + "--quant_format", + default="QOperator", + type=QuantFormat.from_string, + choices=list(QuantFormat), + help="QuantFormat {QOperator, QDQ}. " + "QOperator format quantizes the model with quantized operators directly. " + "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", + ) return parser.parse_args() @@ -699,6 +814,7 @@ if __name__ == "__main__": input_model_path = args.input_model output_model_path = args.output_model + quant_format = args.quant_format if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") @@ -713,7 +829,10 @@ if __name__ == "__main__": quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( - block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + block_size=args.block_size, + is_symmetric=args.symmetric, + accuracy_level=args.accuracy_level, + quant_format=quant_format, ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig() diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 88e5052db4..4cc8a0c151 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -14,7 +14,7 @@ from typing import Dict, Tuple, Union import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type 
from onnxruntime.quantization import quant_utils @@ -105,8 +105,9 @@ class TestOpMatMul4Bits(unittest.TestCase): [output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + # blocked quantization requires DQ op set >= 21 + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version onnx.save(model, output_model_path) @@ -116,9 +117,12 @@ class TestOpMatMul4Bits(unittest.TestCase): data_reader: TestDataFeeds, block_size: int, is_symmetric: bool, + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, ): + use_qdq = quant_format == quant_utils.QuantFormat.QDQ + name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits" model_int4_path = str( - Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute() ) # Quantize fp32 model to int4 model @@ -126,15 +130,33 @@ class TestOpMatMul4Bits(unittest.TestCase): model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( - block_size=block_size, is_symmetric=is_symmetric + block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format ) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) - quant_nodes = {"MatMulNBits": 1} + quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} check_op_type_count(self, model_int4_path, **quant_nodes) + if use_qdq: + dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4 + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, 
dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) + check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes) + data_reader.rewind() try: @@ -211,6 +233,26 @@ class TestOpMatMul4Bits(unittest.TestCase): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_symmetric_qdq(self): + np.random.seed(13) + + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=True) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_offsets_qdq(self): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ) + @unittest.skipIf( find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" )