mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support INT4 weight only quantize, including RTN and GPTQ 2 algorithms (#17390)
### Description Support INT4 weight only quantize (WOQ) via Intel Neural Compressor, including RTN and GPTQ 2 algorithms. **Note:** Please install `neural-compressor==2.3` for weight only quantize. ### Motivation and Context As large language models (LLMs) become more prevalent, there is a growing need for new and improved quantization methods that can meet the computational demands of these modern architectures while maintaining the accuracy. Compared to normal quantization like W8A8, weight only quantization is probably a better trade-off to balance the performance and the accuracy. RTN is the most straightforward way to quantize weight. GPTQ algorithm provides more accurate quantization but requires more computational resources. ### Evaluation results The following table shows the accuracy results of Llama-2 models evaluated on [lambada_openai](https://huggingface.co/datasets/lambada) task. `GPTQ W4G32Asym` in configuration column means GPTQ algorithm is used for 4-bit weight only quantization, setting group_size=32 and scheme=asym. <table class="tg"> <thead> <tr> <th rowspan="2">Model name</th> <th rowspan="2">Configuration</th> <th colspan="2">Lambada_openai</th> <th rowspan="2">Accuracy Ratio<br>[WOQ/FP32]</th> </tr> <tr> <th>Accuracy</th> <th>Perplexity</th> </tr> </thead> <tbody> <tr> <td rowspan="2">meta-llama/Llama-2-7b-chat-hf</td> <td>FP32</td> <td>0.7058</td> <td>3.2788</td> <td>/</td> </tr> <tr> <td>GPTQ<br>W4G32Asym</td> <td>0.7025</td> <td>3.4489</td> <td>99.53%</td> </tr> <tr> <td rowspan="2">meta-llama/Llama-2-7b-hf</td> <td>FP32</td> <td>0.7392</td> <td>3.3950</td> <td>/</td> </tr> <tr> <td>GPTQ<br>W4G32Asym</td> <td>0.7326</td> <td>3.5286</td> <td>99.11%</td> </tr> <tr> <td rowspan="2">meta-llama/Llama-2-13b-chat-hf</td> <td>FP32</td> <td>0.7312</td> <td>2.9163</td> <td>/</td> </tr> <tr> <td>GPTQ<br>W4G128Asym</td> <td>0.7289</td> <td>3.0061</td> <td>99.56%</td> <tr> <td rowspan="2">meta-llama/Llama-2-13b-hf</td> <td>FP32</td> <td>0.7677</td> <td>3.0438</td> <td>/</td> </tr> <tr> <td>GPTQ<br>W4G32Asym</td> <td>0.7607</td> <td>3.1562</td> <td>99.09%</td> </tr> <tr> <td rowspan="2">meta-llama/Llama-2-70b-chat-hf</td> <td>FP32</td> <td>0.7543</td> <td>2.6181</td> <td>/</td> </tr> <tr> <td>RTN<br>W4G32Sym</td> <td>0.7489</td> <td>2.6850</td> <td>99.28%</td> </tr> <tr> <td rowspan="2">meta-llama/Llama-2-70b-hf</td> <td>FP32</td> <td>0.7964</td> <td>2.6612</td> <td>/</td> </tr> <tr> <td>RTN<br>W4G32Sym</td> <td>0.7896</td> <td>2.7546</td> <td>99.15%</td> </tr> </tbody> </table> --------- Signed-off-by: yuwenzho <yuwen.zhou@intel.com> Co-authored-by: Wang, Mengni <mengni.wang@intel.com>
This commit is contained in:
parent
df116b82c7
commit
731b50dfc4
3 changed files with 244 additions and 22 deletions
|
|
@ -7,6 +7,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
|
@ -14,9 +16,11 @@ import numpy as np
|
|||
import numpy.typing as npt
|
||||
import onnx
|
||||
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.capi._pybind_state import quantize_matmul_4bits
|
||||
|
||||
from .calibrate import CalibrationDataReader
|
||||
from .onnx_model import ONNXModel
|
||||
from .quant_utils import attribute_to_kwarg
|
||||
|
||||
|
|
@ -24,24 +28,98 @@ logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s",
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeightOnlyQuantConfig:
|
||||
def __init__(self, algorithm):
|
||||
"""This is the Base class for Weight Only Quant Configuration.
|
||||
|
||||
Args:
|
||||
algorithm:
|
||||
weight only quantize algorithm name.
|
||||
"""
|
||||
self.algorithm = algorithm
|
||||
|
||||
|
||||
class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
||||
def __init__(
|
||||
self,
|
||||
ratios=None,
|
||||
):
|
||||
"""
|
||||
This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
|
||||
RTN is the most straightforward way to quantize weight using scale maps.
|
||||
|
||||
Args:
|
||||
ratios:
|
||||
percentile of clip. Defaults to {}.
|
||||
"""
|
||||
if ratios is None:
|
||||
ratios = {}
|
||||
super().__init__(
|
||||
algorithm="RTN",
|
||||
)
|
||||
self.ratios = ratios
|
||||
|
||||
|
||||
class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
|
||||
def __init__(
|
||||
self,
|
||||
calibration_data_reader: CalibrationDataReader,
|
||||
percdamp=0.01,
|
||||
blocksize=128,
|
||||
actorder=False,
|
||||
mse=False,
|
||||
perchannel=True,
|
||||
):
|
||||
"""
|
||||
This is a class for GPTQ algorithm Weight Only Quant Configuration.
|
||||
GPTQ algorithm provides more accurate quantization but requires more computational resources.
|
||||
|
||||
Args:
|
||||
calibration_data_reader:
|
||||
a calibration data reader. It enumerates calibration data and generates inputs for the original model.
|
||||
percdamp:
|
||||
percent of the average Hessian diagonal to use for dampening.
|
||||
blocksize (int, optional):
|
||||
channel number in one block to execute a GPTQ quantization iteration.
|
||||
actorder (bool, optional):
|
||||
whether rearrange Hessian matrix considering the diag's value.
|
||||
mse (bool, optional):
|
||||
whether get scale and zero point with mse error.
|
||||
perchannel (bool, optional):
|
||||
whether quantize weight per-channel.
|
||||
"""
|
||||
super().__init__(
|
||||
algorithm="GPTQ",
|
||||
)
|
||||
self.calibration_data_reader = calibration_data_reader
|
||||
self.percdamp = percdamp
|
||||
self.blocksize = blocksize
|
||||
self.actorder = actorder
|
||||
self.mse = mse
|
||||
self.perchannel = perchannel
|
||||
|
||||
|
||||
class MatMul4BitsQuantizer:
|
||||
"""Perform 4b quantization of constant MatMul weights"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: ModelProto,
|
||||
model: ModelProto | str,
|
||||
block_size: int,
|
||||
is_symmetric: bool,
|
||||
accuracy_level: int | None = None,
|
||||
nodes_to_exclude: list[str] | None = None,
|
||||
nodes_to_exclude=None,
|
||||
algo_config: WeightOnlyQuantConfig = None,
|
||||
):
|
||||
if nodes_to_exclude is None:
|
||||
nodes_to_exclude = []
|
||||
self.model = ONNXModel(model)
|
||||
self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
|
||||
self.model_path = model if isinstance(model, str) else None
|
||||
self.block_size = block_size
|
||||
self.is_symmetric = is_symmetric
|
||||
self.accuracy_level = accuracy_level
|
||||
self.nodes_to_exclude = set(nodes_to_exclude)
|
||||
self.algo_config = algo_config
|
||||
|
||||
@staticmethod
|
||||
def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
|
||||
|
|
@ -176,20 +254,99 @@ class MatMul4BitsQuantizer:
|
|||
graph_stack.pop()
|
||||
return graph
|
||||
|
||||
def _generate_q4_node_config(self):
|
||||
"""Generate weight only quant configuration for nodes."""
|
||||
q4_node_config = {}
|
||||
template_config_q4 = {
|
||||
"bits": 4,
|
||||
"group_size": self.block_size,
|
||||
"scheme": "sym" if self.is_symmetric else "asym",
|
||||
}
|
||||
for node in self.model.model.graph.node:
|
||||
if node.op_type in ["MatMul"]:
|
||||
if not all([self.model.get_initializer(i) is None for i in node.input]):
|
||||
q4_node_config[node.name] = template_config_q4
|
||||
return q4_node_config
|
||||
|
||||
def int4_quant_algo(self):
|
||||
"""4b quantize a model with RTN or GPTQ algorithm. Please refer to
|
||||
https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md
|
||||
for more details on weight only quantization using Intel® Neural Compressor.
|
||||
"""
|
||||
|
||||
def inc_dataloader():
|
||||
data_reader = copy.deepcopy(self.algo_config.calibration_data_reader)
|
||||
for data in data_reader:
|
||||
yield data, None
|
||||
|
||||
kwargs = {}
|
||||
if self.accuracy_level is not None:
|
||||
kwargs["accuracy_level"] = self.accuracy_level
|
||||
weight_only_node_config = self._generate_q4_node_config()
|
||||
|
||||
algorithm = self.algo_config.algorithm
|
||||
logger.info(f"start to quantize model with {algorithm} algorithm...")
|
||||
if algorithm == "RTN":
|
||||
from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
|
||||
|
||||
kwargs["ratios"] = self.algo_config.ratios
|
||||
|
||||
self.model = rtn_quantize(
|
||||
model=self.model_path if self.model_path is not None else self.model.model,
|
||||
weight_config=weight_only_node_config,
|
||||
**kwargs,
|
||||
)
|
||||
elif algorithm == "GPTQ":
|
||||
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
|
||||
|
||||
kwargs["percdamp"] = self.algo_config.percdamp
|
||||
kwargs["blocksize"] = self.algo_config.blocksize
|
||||
kwargs["actorder"] = self.algo_config.actorder
|
||||
kwargs["mse"] = self.algo_config.mse
|
||||
kwargs["perchannel"] = self.algo_config.perchannel
|
||||
kwargs["n_samples"] = -1
|
||||
dataloader = inc_dataloader()
|
||||
|
||||
self.model = gptq_quantize(
|
||||
model=self.model_path if self.model_path is not None else self.model.model,
|
||||
weight_config=weight_only_node_config,
|
||||
dataloader=dataloader,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info(f"complete quantization of model with {algorithm} algorithm.")
|
||||
|
||||
def process(self):
|
||||
# use a stack to keep track of sub-graphs
|
||||
graph_stack = [self.model.graph()]
|
||||
opset_import = self.model.opset_import()
|
||||
if self.algo_config is None:
|
||||
# use a stack to keep track of sub-graphs
|
||||
graph_stack = [self.model.graph()]
|
||||
opset_import = self.model.opset_import()
|
||||
|
||||
has_ms_domain = False
|
||||
for opset in opset_import:
|
||||
if opset.domain == "com.microsoft":
|
||||
has_ms_domain = True
|
||||
if not has_ms_domain:
|
||||
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
|
||||
has_ms_domain = False
|
||||
for opset in opset_import:
|
||||
if opset.domain == "com.microsoft":
|
||||
has_ms_domain = True
|
||||
if not has_ms_domain:
|
||||
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
|
||||
|
||||
self._process_subgraph(graph_stack)
|
||||
self.model.clean_initializers()
|
||||
self._process_subgraph(graph_stack)
|
||||
self.model.clean_initializers()
|
||||
else:
|
||||
# use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm
|
||||
try:
|
||||
importlib.import_module("neural_compressor")
|
||||
except Exception as e:
|
||||
logging.error(f"{e}.")
|
||||
raise RuntimeError(
|
||||
"neural-compressor is not correctly installed. Please check your environment."
|
||||
) from e
|
||||
|
||||
import neural_compressor
|
||||
|
||||
assert version.parse(neural_compressor.__version__) >= version.parse(
|
||||
"2.3.2"
|
||||
), "Require neural-compressor >= 2.3.2 to support weight only quantization!"
|
||||
|
||||
self.int4_quant_algo()
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
|
|
|||
|
|
@ -466,7 +466,6 @@ def quantize_static(
|
|||
|
||||
import copy
|
||||
|
||||
import onnx
|
||||
from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant
|
||||
|
||||
def inc_dataloader():
|
||||
|
|
@ -478,13 +477,11 @@ def quantize_static(
|
|||
dataloader = inc_dataloader()
|
||||
sq = ORTSmoothQuant(model_input, dataloader, reduce_range)
|
||||
del dataloader
|
||||
model = sq.transform(
|
||||
extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)
|
||||
).model
|
||||
nodes_to_exclude.extend([i.name for i in model.graph.node if i.name not in orig_nodes])
|
||||
model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
|
||||
sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
|
||||
model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
|
||||
onnx.save_model(model, model_input, save_as_external_data=True)
|
||||
model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix()
|
||||
model.save(model_input)
|
||||
nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
|
||||
model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
|
||||
|
|
|
|||
|
|
@ -71,13 +71,16 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
output_name = "output"
|
||||
initializers = []
|
||||
|
||||
def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str):
|
||||
def make_matmul(
|
||||
input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str, node_name: str
|
||||
):
|
||||
weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32)
|
||||
initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name))
|
||||
return onnx.helper.make_node(
|
||||
"MatMul",
|
||||
[input_name, weight_name],
|
||||
[output_name],
|
||||
node_name,
|
||||
)
|
||||
|
||||
in_features = 52
|
||||
|
|
@ -88,6 +91,7 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
[in_features, out_features],
|
||||
"linear1.weight",
|
||||
output_name,
|
||||
"MatMul_0",
|
||||
)
|
||||
|
||||
# make graph
|
||||
|
|
@ -139,6 +143,48 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
else:
|
||||
raise exception
|
||||
|
||||
def quant_test_with_algo(
|
||||
self,
|
||||
algorithm: str,
|
||||
model_fp32_path: str,
|
||||
data_reader: TestDataFeeds,
|
||||
block_size: int,
|
||||
is_symmetric: bool,
|
||||
):
|
||||
model_int4_path = str(
|
||||
Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute()
|
||||
)
|
||||
|
||||
# Quantize fp32 model to int4 model
|
||||
from onnxruntime.quantization import matmul_4bits_quantizer
|
||||
|
||||
algo_config = None
|
||||
if algorithm == "RTN":
|
||||
# test RTN algorithm
|
||||
algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig()
|
||||
elif algorithm == "GPTQ":
|
||||
# test GPTQ algorithm
|
||||
algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
|
||||
|
||||
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
|
||||
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(model_int4_path, False)
|
||||
|
||||
quant_nodes = {"MatMulNBits": 1}
|
||||
check_op_type_count(self, model_int4_path, **quant_nodes)
|
||||
|
||||
data_reader.rewind()
|
||||
|
||||
try:
|
||||
check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next())
|
||||
except Exception as exception:
|
||||
if "4b quantization not yet supported on this hardware platform!" in exception.args[0]:
|
||||
# Currently we don't have int4 quantization support on all platforms, has to tolerate this exception
|
||||
pass
|
||||
else:
|
||||
raise exception
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
|
|
@ -159,6 +205,28 @@ class TestOpMatMul4Bits(unittest.TestCase):
|
|||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test(model_fp32_path, data_reader, 32, False)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
def test_quantize_matmul_int4_using_rtn_algo(self):
|
||||
if not find_spec("neural_compressor"):
|
||||
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
|
||||
self.construct_model_matmul(model_fp32_path, symmetric=False)
|
||||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
|
||||
)
|
||||
def test_quantize_matmul_int4_using_gptq_algo(self):
|
||||
if not find_spec("neural_compressor"):
|
||||
self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
|
||||
self.construct_model_matmul(model_fp32_path, symmetric=False)
|
||||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue