mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
matmul 4bit tool chain support qdq (#21362)
### Description This is a partial change ported from fajin/qdqmatmulnbitstoolchain. That branch has issues resolving the web CI. MatMulNBits is a heavily optimized matmul operation. Currently a MatMul can be converted to MatMulNBits to speed up the model inference. However, MatMulNBits is an ORT only op. To make the graph compatible with ONNX ops and utilize MatMulNBits at the same time, we introduce Q/DQ support for MatMulNBits. To convert MatMul ops in a model to MatMulNBits: use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using QDQ mode. In ORT session, DQ + MatMul is fused to MatMulNBits #### Note MatMulNBits assume B weight is uint4. When no zp is provided, zp defaults to 8, which is different from DQ. DQ defaults zp to 0 when no zp provided. And DQ supports int4. Therefore some conversions are introduced during DQ + MatMul --> MatMulNBits step. #### Perf Using QDQ format will increase the model initialization time and memory consumption. With current implement, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB. The memory increase is due to 1. in optimizer, after transpose the B weight, a in-memory tensor proto is created using protobuf's arena. 2. in finalize step, when saving initializer and prepacking, ORT arena is used to create buffers for initializers. The memory allocated by arenas cannot be fully deallocated. If disable ORT arena memory allocation, the memory consumptions of both QDQ format and original format are ~2.2GB. The time increase is mainly due to multiple memory copy, but can be further optimized. ### Motivation and Context Please see description for details.
This commit is contained in:
parent
8568a67673
commit
5df4ddd1c3
2 changed files with 244 additions and 83 deletions
|
|
@ -18,31 +18,36 @@ import onnx
|
|||
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.capi._pybind_state import quantize_matmul_4bits
|
||||
from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits
|
||||
|
||||
from .calibrate import CalibrationDataReader
|
||||
from .onnx_model import ONNXModel
|
||||
from .quant_utils import attribute_to_kwarg
|
||||
from .quant_utils import QuantFormat, attribute_to_kwarg
|
||||
|
||||
logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeightOnlyQuantConfig:
|
||||
def __init__(self, algorithm):
|
||||
def __init__(self, algorithm, quant_format):
|
||||
"""This is the Base class for Weight Only Quant Configuration.
|
||||
|
||||
Args:
|
||||
algorithm:
|
||||
weight only quantize algorithm name.
|
||||
quant_format: QuantFormat{QOperator, QDQ}.
|
||||
QOperator format quantizes the model with quantized operators directly.
|
||||
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
|
||||
"""
|
||||
self.algorithm = algorithm
|
||||
self.quant_format = quant_format
|
||||
|
||||
|
||||
class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
||||
def __init__(
|
||||
self,
|
||||
ratios=None,
|
||||
quant_format=QuantFormat.QOperator,
|
||||
):
|
||||
"""
|
||||
This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
|
||||
|
|
@ -51,11 +56,18 @@ class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
Args:
|
||||
ratios:
|
||||
percentile of clip. Defaults to {}.
|
||||
quant_format (QuantFormat{QOperator, QDQ}, optional):
|
||||
QOperator format quantizes the model with quantized operators directly.
|
||||
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
|
||||
Defaults to QuantFormat.QOperator.
|
||||
"""
|
||||
assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
|
||||
|
||||
if ratios is None:
|
||||
ratios = {}
|
||||
super().__init__(
|
||||
algorithm="RTN",
|
||||
quant_format=quant_format,
|
||||
)
|
||||
self.ratios = ratios
|
||||
|
||||
|
|
@ -69,6 +81,7 @@ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
actorder=False,
|
||||
mse=False,
|
||||
perchannel=True,
|
||||
quant_format=QuantFormat.QOperator,
|
||||
):
|
||||
"""
|
||||
This is a class for GPTQ algorithm Weight Only Quant Configuration.
|
||||
|
|
@ -87,9 +100,16 @@ class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
whether get scale and zero point with mse error.
|
||||
perchannel (bool, optional):
|
||||
whether quantize weight per-channel.
|
||||
quant_format (QuantFormat{QOperator, QDQ}, optional):
|
||||
QOperator format quantizes the model with quantized operators directly.
|
||||
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
|
||||
Defaults to QuantFormat.QOperator.
|
||||
"""
|
||||
assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
|
||||
|
||||
super().__init__(
|
||||
algorithm="GPTQ",
|
||||
quant_format=quant_format,
|
||||
)
|
||||
self.calibration_data_reader = calibration_data_reader
|
||||
self.percdamp = percdamp
|
||||
|
|
@ -105,6 +125,7 @@ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
block_size=128,
|
||||
bits=4,
|
||||
axis=1,
|
||||
quant_format=QuantFormat.QOperator,
|
||||
):
|
||||
"""
|
||||
This is a class for HQQ algorithm Weight Only Quant Configuration.
|
||||
|
|
@ -112,14 +133,21 @@ class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
|
||||
Args:
|
||||
block_size (int, optional):
|
||||
channel number in one block to execute a GPTQ quantization iteration.
|
||||
channel number in one block to execute a HQQ quantization iteration.
|
||||
bits (int, optional):
|
||||
how many bits to represent weight.
|
||||
axis (int, optional):
|
||||
0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
|
||||
quant_format (QuantFormat{QOperator, QDQ}, optional):
|
||||
QOperator format quantizes the model with quantized operators directly.
|
||||
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
|
||||
Defaults to QuantFormat.QOperator.
|
||||
"""
|
||||
assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
|
||||
|
||||
super().__init__(
|
||||
algorithm="HQQ",
|
||||
quant_format=quant_format,
|
||||
)
|
||||
self.block_size = block_size
|
||||
self.bits = bits
|
||||
|
|
@ -132,8 +160,26 @@ class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
|||
block_size: int = 128,
|
||||
is_symmetric: bool = False,
|
||||
accuracy_level: int | None = None,
|
||||
quant_format=QuantFormat.QOperator,
|
||||
):
|
||||
super().__init__(algorithm="DEFAULT")
|
||||
"""
|
||||
This is a class for weight only affine quantization configuration.
|
||||
|
||||
Args:
|
||||
block_size (int, optional):
|
||||
channel number in one block to execute an affine quantization iteration.
|
||||
is_symmetric (bool, optional):
|
||||
whether quantize weight symmetrically.
|
||||
accuracy_level (int, optional):
|
||||
Accuracy level of the 4-bit quantized MatMul computation.
|
||||
Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details.
|
||||
(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)
|
||||
quant_format (QuantFormat{QOperator, QDQ}, optional):
|
||||
QOperator format quantizes the model with quantized operators directly.
|
||||
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
|
||||
Defaults to QuantFormat.QOperator.
|
||||
"""
|
||||
super().__init__(algorithm="DEFAULT", quant_format=quant_format)
|
||||
self.block_size = block_size
|
||||
self.is_symmetric = is_symmetric
|
||||
self.bits = 4
|
||||
|
|
@ -287,23 +333,26 @@ class HQQWeightOnlyQuantizer:
|
|||
|
||||
return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
|
||||
|
||||
def quantize(self, node: NodeProto, graph_stack: list[GraphProto]):
|
||||
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
|
||||
def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
|
||||
"""
|
||||
If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.
|
||||
If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul.
|
||||
"""
|
||||
if node.op_type != "MatMul":
|
||||
return node # only care about MatMul for now
|
||||
return [node] # only care about MatMul for now
|
||||
import torch
|
||||
|
||||
logger.info(f"start to quantize {node.name} ...")
|
||||
inputB = node.input[1] # noqa: N806
|
||||
b_pb, bs_graph = get_initializer(inputB, graph_stack)
|
||||
input_b = node.input[1]
|
||||
b_pb, bs_graph = get_initializer(input_b, graph_stack)
|
||||
if b_pb is None:
|
||||
logger.info("MatMul doesn't have const weight. Skip to quantize")
|
||||
return node # only care about constant weight
|
||||
return [node] # only care about constant weight
|
||||
|
||||
b_array = onnx.numpy_helper.to_array(b_pb)
|
||||
if len(b_array.shape) != 2:
|
||||
logger.info("MatMul weight is not 2D. Skip to quantize")
|
||||
return node # can only process 2-D matrix
|
||||
return [node] # can only process 2-D matrix
|
||||
b_array_torch = torch.from_numpy(b_array)
|
||||
if torch.cuda.is_available():
|
||||
b_array_torch = b_array_torch.cuda()
|
||||
|
|
@ -334,7 +383,7 @@ class HQQWeightOnlyQuantizer:
|
|||
b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
|
||||
b_quant.name = b_pb.name + "_Q4"
|
||||
for input in bs_graph.input:
|
||||
if input.name == inputB:
|
||||
if input.name == input_b:
|
||||
bs_graph.input.remove(input)
|
||||
break
|
||||
|
||||
|
|
@ -366,7 +415,7 @@ class HQQWeightOnlyQuantizer:
|
|||
|
||||
logger.info(f"complete quantization of {node.name} ...")
|
||||
|
||||
return matmul_q4_node
|
||||
return [matmul_q4_node]
|
||||
|
||||
|
||||
def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
|
||||
|
|
@ -382,7 +431,7 @@ class DefaultWeightOnlyQuantizer:
|
|||
def __init__(self, config: DefaultWeightOnlyQuantConfig):
|
||||
self.config = config
|
||||
|
||||
def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
|
||||
def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""4b quantize fp32 weight to a blob"""
|
||||
|
||||
if len(fp32weight.shape) != 2:
|
||||
|
|
@ -390,83 +439,136 @@ class DefaultWeightOnlyQuantizer:
|
|||
rows, cols = fp32weight.shape
|
||||
|
||||
block_size = self.config.block_size
|
||||
blob_size = block_size // 2
|
||||
k_blocks = (rows + block_size - 1) // block_size
|
||||
padded_rows = k_blocks * block_size
|
||||
pad_len = padded_rows - rows
|
||||
if pad_len > 0:
|
||||
fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
|
||||
|
||||
# block wise quantization, each block comes from a single column
|
||||
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
|
||||
scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
|
||||
zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
|
||||
quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric)
|
||||
if self.config.quant_format == QuantFormat.QOperator:
|
||||
blob_size = block_size // 2
|
||||
padded_rows = k_blocks * block_size
|
||||
pad_len = padded_rows - rows
|
||||
if pad_len > 0:
|
||||
fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")
|
||||
|
||||
# block wise quantization, each block comes from a single column
|
||||
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
|
||||
zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
|
||||
scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
|
||||
quantize_matmul_4bits(
|
||||
packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
|
||||
)
|
||||
else:
|
||||
packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
|
||||
zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
|
||||
scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
|
||||
quantize_qdq_matmul_4bits(
|
||||
packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
|
||||
)
|
||||
|
||||
return (packed, scales, zero_point)
|
||||
|
||||
def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
|
||||
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
|
||||
def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
|
||||
"""
|
||||
If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.
|
||||
If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul.
|
||||
"""
|
||||
|
||||
if node.op_type != "MatMul":
|
||||
return node # only care about MatMul for now
|
||||
return [node] # only care about MatMul for now
|
||||
|
||||
logger.info(f"start to quantize {node.name} ...")
|
||||
inputB = node.input[1] # noqa: N806
|
||||
B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806
|
||||
if B is None:
|
||||
qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
|
||||
input_b = node.input[1]
|
||||
b_tensor, b_graph = get_initializer(input_b, graph_stack)
|
||||
if b_tensor is None:
|
||||
logger.info("MatMul doesn't have const weight. Skip to quantize")
|
||||
return node # only care about constant weight
|
||||
return [node] # only care about constant weight
|
||||
|
||||
B_array = onnx.numpy_helper.to_array(B) # noqa: N806
|
||||
if len(B_array.shape) != 2:
|
||||
b_ndarray = onnx.numpy_helper.to_array(b_tensor)
|
||||
if len(b_ndarray.shape) != 2:
|
||||
logger.info("MatMul weight is not 2D. Skip to quantize")
|
||||
return node # can only process 2-D matrix
|
||||
return [node] # can only process 2-D matrix
|
||||
|
||||
packed, scales, zero_points = self.int4_block_quant(B_array)
|
||||
B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
|
||||
B_quant.name = B.name + "_Q4"
|
||||
for input in Bs_graph.input:
|
||||
if input.name == inputB:
|
||||
Bs_graph.input.remove(input)
|
||||
packed, scales, zero_points = self.int4_block_quant(b_ndarray)
|
||||
|
||||
if self.config.quant_format == QuantFormat.QOperator:
|
||||
b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4")
|
||||
scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales")
|
||||
else:
|
||||
b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True)
|
||||
scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
|
||||
|
||||
for input in b_graph.input:
|
||||
if input.name == input_b:
|
||||
b_graph.input.remove(input)
|
||||
break
|
||||
|
||||
scales_tensor = onnx.numpy_helper.from_array(scales)
|
||||
scales_tensor.name = B.name + "_scales"
|
||||
Bs_graph.initializer.extend([B_quant, scales_tensor])
|
||||
b_graph.initializer.extend([b_quant, scales_tensor])
|
||||
|
||||
input_names = [node.input[0], B_quant.name, scales_tensor.name]
|
||||
if not self.config.is_symmetric:
|
||||
zp_tensor = onnx.numpy_helper.from_array(zero_points)
|
||||
zp_tensor.name = B.name + "_zero_points"
|
||||
Bs_graph.initializer.extend([zp_tensor])
|
||||
input_names.append(zp_tensor.name)
|
||||
output_nodes = []
|
||||
|
||||
kwargs = {}
|
||||
rows, cols = B_array.shape
|
||||
kwargs["K"] = rows
|
||||
kwargs["N"] = cols
|
||||
kwargs["bits"] = 4
|
||||
kwargs["block_size"] = self.config.block_size
|
||||
if self.config.accuracy_level is not None:
|
||||
kwargs["accuracy_level"] = self.config.accuracy_level
|
||||
if self.config.quant_format == QuantFormat.QOperator:
|
||||
input_names = [node.input[0], b_quant.name, scales_tensor.name]
|
||||
if not self.config.is_symmetric:
|
||||
zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points")
|
||||
input_names.append(zp_tensor.name)
|
||||
b_graph.initializer.extend([zp_tensor])
|
||||
kwargs = {}
|
||||
rows, cols = b_ndarray.shape
|
||||
kwargs["K"] = rows
|
||||
kwargs["N"] = cols
|
||||
kwargs["bits"] = 4
|
||||
kwargs["block_size"] = self.config.block_size
|
||||
if self.config.accuracy_level is not None:
|
||||
kwargs["accuracy_level"] = self.config.accuracy_level
|
||||
|
||||
matmul_q4_node = onnx.helper.make_node(
|
||||
"MatMulNBits",
|
||||
inputs=input_names,
|
||||
outputs=[node.output[0]],
|
||||
name=node.name + "_Q4" if node.name else "",
|
||||
domain="com.microsoft",
|
||||
**kwargs,
|
||||
)
|
||||
matmul_q4_node = onnx.helper.make_node(
|
||||
"MatMulNBits",
|
||||
inputs=input_names,
|
||||
outputs=[node.output[0]],
|
||||
name=node.name + "_Q4" if node.name else "",
|
||||
domain="com.microsoft",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output_nodes.append(matmul_q4_node)
|
||||
else:
|
||||
dq_input_names = [b_quant.name, scales_tensor.name]
|
||||
dq_output_names = [b_quant.name + "_output"]
|
||||
matmul_input_names = [node.input[0], dq_output_names[0]]
|
||||
matmul_output_names = [node.output[0]]
|
||||
if not self.config.is_symmetric:
|
||||
zp_tensor = onnx.helper.make_tensor(
|
||||
b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
|
||||
)
|
||||
dq_input_names.append(zp_tensor.name)
|
||||
b_graph.initializer.extend([zp_tensor])
|
||||
dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
|
||||
dq_node = onnx.helper.make_node(
|
||||
"DequantizeLinear",
|
||||
inputs=dq_input_names,
|
||||
outputs=dq_output_names,
|
||||
name=node.name + "_DQ_Q4" if node.name else "",
|
||||
**dq_kwargs,
|
||||
)
|
||||
matmul_node = onnx.helper.make_node(
|
||||
"MatMul",
|
||||
inputs=matmul_input_names,
|
||||
outputs=matmul_output_names,
|
||||
name=node.name + "_matmul_Q4" if node.name else "",
|
||||
)
|
||||
output_nodes.extend([dq_node, matmul_node])
|
||||
|
||||
logger.info(f"complete quantization of {node.name} ...")
|
||||
|
||||
return matmul_q4_node
|
||||
return output_nodes
|
||||
|
||||
|
||||
class MatMul4BitsQuantizer:
|
||||
"""Perform 4b quantization of constant MatMul weights"""
|
||||
"""
|
||||
Perform 4b quantization of constant MatMul weights.
|
||||
If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which relaces the
|
||||
MatMul node.
|
||||
If algo_config.quant_format is QDQ, the quantized weight is stored in a DeQuantizeLinear node. The MatMul node is
|
||||
replaced by the DequantizeLinear + MatMul nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -475,7 +577,8 @@ class MatMul4BitsQuantizer:
|
|||
is_symmetric: bool = False,
|
||||
accuracy_level: int | None = None,
|
||||
nodes_to_exclude=None,
|
||||
algo_config: WeightOnlyQuantConfig = None,
|
||||
quant_format=QuantFormat.QOperator,
|
||||
algo_config: WeightOnlyQuantConfig | None = None,
|
||||
):
|
||||
if nodes_to_exclude is None:
|
||||
nodes_to_exclude = []
|
||||
|
|
@ -488,7 +591,10 @@ class MatMul4BitsQuantizer:
|
|||
self.node_quantizer = None
|
||||
if algo_config is None:
|
||||
algo_config = DefaultWeightOnlyQuantConfig(
|
||||
block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level
|
||||
block_size=block_size,
|
||||
is_symmetric=is_symmetric,
|
||||
accuracy_level=accuracy_level,
|
||||
quant_format=quant_format,
|
||||
)
|
||||
self.algo_config = algo_config
|
||||
if algo_config.algorithm == "HQQ":
|
||||
|
|
@ -526,15 +632,15 @@ class MatMul4BitsQuantizer:
|
|||
node = onnx.helper.make_node( # noqa: PLW2901
|
||||
node.op_type, node.input, node.output, name=node.name, **kwargs
|
||||
)
|
||||
out_node = None
|
||||
out_nodes = []
|
||||
if node.name in self.nodes_to_exclude:
|
||||
logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
|
||||
out_node = node
|
||||
out_nodes = [node]
|
||||
elif self.algo_config is not None and self.algo_config.algorithm == "HQQ":
|
||||
out_node = self.node_quantizer.quantize(node, graph_stack)
|
||||
out_nodes = self.node_quantizer.quantize(node, graph_stack)
|
||||
else:
|
||||
out_node = self.node_quantizer.quantize(node, graph_stack)
|
||||
new_nodes.append(out_node)
|
||||
out_nodes = self.node_quantizer.quantize(node, graph_stack)
|
||||
new_nodes.extend(out_nodes)
|
||||
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(new_nodes)
|
||||
|
|
@ -688,6 +794,15 @@ set of 4b integers with a scaling factor and an optional offset.
|
|||
default=[],
|
||||
help="Specify the nodes to be excluded from quantization with node names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quant_format",
|
||||
default="QOperator",
|
||||
type=QuantFormat,
|
||||
choices=list(QuantFormat),
|
||||
help="QuantFormat {QOperator, QDQ}"
|
||||
"QOperator format quantizes the model with quantized operators directly."
|
||||
"QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -699,6 +814,7 @@ if __name__ == "__main__":
|
|||
|
||||
input_model_path = args.input_model
|
||||
output_model_path = args.output_model
|
||||
quant_format = args.quant_format
|
||||
|
||||
if os.path.exists(output_model_path):
|
||||
logger.error(f"file {output_model_path} already exists")
|
||||
|
|
@ -713,7 +829,10 @@ if __name__ == "__main__":
|
|||
quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
|
||||
elif args.quant_method == "default":
|
||||
quant_config = DefaultWeightOnlyQuantConfig(
|
||||
block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level
|
||||
block_size=args.block_size,
|
||||
is_symmetric=args.symmetric,
|
||||
accuracy_level=args.accuracy_level,
|
||||
quant_format=quant_format,
|
||||
)
|
||||
elif args.quant_method == "rtn":
|
||||
quant_config = RTNWeightOnlyQuantConfig()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from typing import Dict, Tuple, Union
|
|||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count
|
||||
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
|
||||
|
||||
from onnxruntime.quantization import quant_utils
|
||||
|
||||
|
|
@ -105,8 +105,9 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
[output_tensor],
|
||||
initializer=initializers,
|
||||
)
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
model.ir_version = 7 # use stable onnx ir version
|
||||
# blocked quantization requires DQ op set >= 21
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
|
||||
model.ir_version = 10 # use stable onnx ir version
|
||||
|
||||
onnx.save(model, output_model_path)
|
||||
|
||||
|
|
@ -116,9 +117,12 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
data_reader: TestDataFeeds,
|
||||
block_size: int,
|
||||
is_symmetric: bool,
|
||||
quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator,
|
||||
):
|
||||
use_qdq = quant_format == quant_utils.QuantFormat.QDQ
|
||||
name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits"
|
||||
model_int4_path = str(
|
||||
Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute()
|
||||
Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute()
|
||||
)
|
||||
|
||||
# Quantize fp32 model to int4 model
|
||||
|
|
@ -126,15 +130,33 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
|
||||
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
|
||||
quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
|
||||
block_size=block_size, is_symmetric=is_symmetric
|
||||
block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format
|
||||
)
|
||||
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(model_int4_path, False)
|
||||
|
||||
quant_nodes = {"MatMulNBits": 1}
|
||||
quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1}
|
||||
check_op_type_count(self, model_int4_path, **quant_nodes)
|
||||
|
||||
if use_qdq:
|
||||
dq_qtype = TensorProto.INT4 if is_symmetric else TensorProto.UINT4
|
||||
dqnode_io_qtypes = (
|
||||
{
|
||||
"DequantizeLinear": [
|
||||
["i", 0, dq_qtype],
|
||||
]
|
||||
}
|
||||
if is_symmetric
|
||||
else {
|
||||
"DequantizeLinear": [
|
||||
["i", 0, dq_qtype],
|
||||
["i", 2, dq_qtype],
|
||||
]
|
||||
}
|
||||
)
|
||||
check_qtype_by_node_type(self, model_int4_path, dqnode_io_qtypes)
|
||||
|
||||
data_reader.rewind()
|
||||
|
||||
try:
|
||||
|
|
@ -211,6 +233,26 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test(model_fp32_path, data_reader, 32, False)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
def test_quantize_matmul_int4_symmetric_qdq(self):
|
||||
np.random.seed(13)
|
||||
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute())
|
||||
self.construct_model_matmul(model_fp32_path, symmetric=True)
|
||||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
def test_quantize_matmul_int4_offsets_qdq(self):
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
|
||||
self.construct_model_matmul(model_fp32_path, symmetric=False)
|
||||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue