mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
clean up quantization of QAT model (#10549)
This commit is contained in:
parent
8e47bb9a4a
commit
05d6805830
3 changed files with 12 additions and 60 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from .quantize import quantize, quantize_static, quantize_dynamic, quantize_qat
|
||||
from .quantize import quantize, quantize_static, quantize_dynamic
|
||||
from .quantize import QuantizationMode
|
||||
from .calibrate import CalibrationDataReader, CalibraterBase, MinMaxCalibrater, create_calibrator, CalibrationMethod
|
||||
from .quant_utils import QuantType, QuantFormat, write_calibration_table
|
||||
|
|
|
|||
|
|
@ -170,6 +170,12 @@ class ONNXQuantizer:
|
|||
self.fuse_dynamic_quant = True
|
||||
return opset_version
|
||||
|
||||
def has_QDQ_nodes(self):
    '''
    Report whether the model already contains QuantizeLinear or DequantizeLinear nodes.

    :return: True if at least one node yielded by self.model.nodes() has one of those op types, False otherwise.
    '''
    qdq_op_types = ('QuantizeLinear', 'DequantizeLinear')
    for candidate in self.model.nodes():
        if candidate.op_type in qdq_op_types:
            return True
    return False
|
||||
|
||||
def remove_fake_quantized_nodes(self):
|
||||
'''
|
||||
Detect and remove the quantize/dequantizelinear node pairs (fake quantized nodes in Quantization-Aware training)
|
||||
|
|
@ -270,7 +276,10 @@ class ONNXQuantizer:
|
|||
self.generated_value_names.add(output_name)
|
||||
|
||||
def quantize_model(self):
|
||||
self.remove_fake_quantized_nodes()
|
||||
if self.has_QDQ_nodes():
|
||||
logging.warning(
|
||||
"Please check if the model is already quantized."
|
||||
"Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly.")
|
||||
|
||||
for node in self.model.nodes():
|
||||
# quantize subgraphs, if any
|
||||
|
|
|
|||
|
|
@ -322,61 +322,4 @@ def quantize_dynamic(model_input: Path,
|
|||
extra_options)
|
||||
|
||||
quantizer.quantize_model()
|
||||
quantizer.model.save_model_to_file(model_output, use_external_data_format)
|
||||
|
||||
|
||||
def quantize_qat(model_input: Path,
                 model_output: Path,
                 op_types_to_quantize=None,
                 per_channel=False,
                 reduce_range=False,
                 activation_type=QuantType.QUInt8,
                 weight_type=QuantType.QUInt8,
                 nodes_to_quantize=None,
                 nodes_to_exclude=None,
                 use_external_data_format=False):
    '''
        Given a quantization-aware training (QAT) onnx model, create a quantized onnx model and save it into a file
    :param model_input: file path of model to quantize
    :param model_output: file path of quantized model
    :param op_types_to_quantize: specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
        When empty or None, every operator type registered in IntegerOpsRegistry is used.
    :param per_channel: quantize weights per channel
    :param reduce_range: quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode
    :param activation_type: quantization data type of activation
    :param weight_type: quantization data type of weight
    :param nodes_to_quantize:
        List of nodes names to quantize. When this list is not None only the nodes in this list
        are quantized. Defaults to an empty list.
        example:
        [
            'Conv__224',
            'Conv__252'
        ]
    :param nodes_to_exclude:
        List of nodes names to exclude. The nodes in this list will be excluded from quantization
        when it is not None. Defaults to an empty list.
    :param use_external_data_format: option used for large size (>2GB) model. Set to False by default.
    '''
    # Build fresh lists on every call: a `[]` default in the signature would be
    # one shared object handed to every ONNXQuantizer created with defaults.
    nodes_to_quantize = [] if nodes_to_quantize is None else nodes_to_quantize
    nodes_to_exclude = [] if nodes_to_exclude is None else nodes_to_exclude

    mode = QuantizationMode.IntegerOps

    # optimize the original model
    optimized_model = optimize_model(Path(model_input))

    # An empty or missing selection means "all operator types in the registry".
    if not op_types_to_quantize:
        op_types_to_quantize = list(IntegerOpsRegistry.keys())

    # Arguments are positional; keep this order in sync with ONNXQuantizer.__init__.
    quantizer = ONNXQuantizer(
        optimized_model,
        per_channel,
        reduce_range,
        mode,
        False,  # static
        weight_type,
        activation_type,
        None,  # NOTE(review): meaning of this slot is not visible here - confirm against ONNXQuantizer
        nodes_to_quantize,
        nodes_to_exclude,
        op_types_to_quantize)

    quantizer.quantize_model()
    quantizer.model.save_model_to_file(model_output, use_external_data_format)
|
||||
quantizer.model.save_model_to_file(model_output, use_external_data_format)
|
||||
Loading…
Reference in a new issue