diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index aa0ee6156f..38e3c5e30e 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -47,8 +47,6 @@ class ONNXQuantizer: is_weight_int8 = weight_qType == QuantType.QInt8 self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric'] self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric'] - self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \ - else extra_options['OpTypesSupportPerChannelQuantization'] self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 diff --git a/onnxruntime/python/tools/quantization/operators/matmul.py b/onnxruntime/python/tools/quantization/operators/matmul.py index 16ee6d12e5..2d37eeb46e 100644 --- a/onnxruntime/python/tools/quantization/operators/matmul.py +++ b/onnxruntime/python/tools/quantization/operators/matmul.py @@ -1,5 +1,7 @@ import onnx +import itertools from .base_operator import QuantOperatorBase +from .qdq_base_operator import QDQOperatorBase from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType from onnx import onnx_pb as onnx_proto ''' @@ -98,3 +100,24 @@ class QLinearMatMul(QuantOperatorBase): self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes + +class QDQMatMul(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert (node.op_type == "MatMul") + + if self.disable_qdq_for_node_output: + nodes_to_iterate = node.input + else: + nodes_to_iterate = itertools.chain(node.input, node.output) + + for tensor_name in nodes_to_iterate: + # only support per-channel quantization on weight + if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()) : + channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1) + self.quantizer.quantize_tensor_per_channel(tensor_name, channel_axis) + else: + self.quantizer.quantize_tensor(tensor_name) diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py index ebe3b7c71a..f8f5546b15 100644 --- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py +++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py @@ -19,10 +19,4 @@ class QDQOperatorBase: nodes_to_iterate = itertools.chain(node.input, node.output) for tensor_name in nodes_to_iterate: - if self.quantizer.is_per_channel(): - if node.op_type in self.quantizer.op_types_support_per_channel_quantization : - self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis) - else: - self.quantizer.quantize_tensor(tensor_name) - else: - self.quantizer.quantize_tensor(tensor_name) + self.quantizer.quantize_tensor(tensor_name) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 423e8d5c8d..f5797282dd 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -43,10 +43,8 @@ class QDQQuantizer(ONNXQuantizer): self.op_types_to_exclude_output_quantization = [] if 'OpTypesToExcludeOutputQuantizatioin' not in extra_options \ else extra_options['OpTypesToExcludeOutputQuantizatioin'] - # In some cases, for example QDQ BERT model for TensorRT, - # QDQ should always appear as a pair. - # For our quantization tool, we do quantization on Dequantizelinear's input - # to remove Quantizelinear as optimization for weight. + # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization. + # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair. # Therefore, we need to disable this optimization and add qdq pair to weight. self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \ else extra_options['AddQDQPairToWeight'] @@ -57,8 +55,8 @@ class QDQQuantizer(ONNXQuantizer): if self.dedicated_qdq_pair: self.tensor_to_its_receiving_nodes = {} - # Channel axis when per_channel is True - self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis'] + # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. + self.qdq_op_type_per_channel_support_to_axis = {} if 'QDQOpTypePerChannelSupportToAxis' not in extra_options else extra_options['QDQOpTypePerChannelSupportToAxis'] def quantize_tensor(self, tensor_name): weight = find_by_name(tensor_name, self.model.initializer()) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index bc0a57a425..955a74e525 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -200,6 +200,10 @@ def quantize_static(model_input, the output of ops with this specific op types. DedicatedQDQPair = True/False : Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their inputs. If True, it will create identical and dedicated QDQ pair for each node. + QDQOpTypePerChannelSupportToAxis = dictionary : Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, + and it's effective only when per channel quantization is supported and per_channel is True. + If specific op type supports per channel quantization but not explicitly specified with channel axis, + default channel axis will be used. ''' mode = QuantizationMode.QLinearOps diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 3628bd2ec9..e5da380978 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -1,7 +1,7 @@ from .quant_utils import QuantizationMode from .operators.base_operator import QuantOperatorBase from .operators.qdq_base_operator import QDQOperatorBase -from .operators.matmul import MatMulInteger, QLinearMatMul +from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul from .operators.attention import AttentionQuant from .operators.embed_layernorm import EmbedLayerNormalizationQuant from .operators.gather import GatherQuant @@ -66,6 +66,7 @@ QDQRegistry = { "MaxPool": QDQMaxPool, "AveragePool" : QDQDirect8BitOp, "Concat": QDQConcat, + "MatMul": QDQMatMul, }