mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-17 01:44:45 +00:00
QDQ tool modification part3 (#9904)
* refine per channel quantization for qdq * remove old option * add comment * add import itertools
This commit is contained in:
parent
4ff78aae45
commit
02aa16e3ea
6 changed files with 34 additions and 16 deletions
|
|
@ -47,8 +47,6 @@ class ONNXQuantizer:
|
|||
is_weight_int8 = weight_qType == QuantType.QInt8
|
||||
self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric']
|
||||
self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric']
|
||||
self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \
|
||||
else extra_options['OpTypesSupportPerChannelQuantization']
|
||||
|
||||
self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
|
||||
self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import onnx
|
||||
import itertools
|
||||
from .base_operator import QuantOperatorBase
|
||||
from .qdq_base_operator import QDQOperatorBase
|
||||
from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
'''
|
||||
|
|
@ -98,3 +100,24 @@ class QLinearMatMul(QuantOperatorBase):
|
|||
self.quantizer.quantized_value_map[node.output[0]] = q_output
|
||||
|
||||
self.quantizer.new_nodes += nodes
|
||||
|
||||
class QDQMatMul(QDQOperatorBase):
|
||||
def __init__(self, onnx_quantizer, onnx_node):
|
||||
super().__init__(onnx_quantizer, onnx_node)
|
||||
|
||||
def quantize(self):
|
||||
node = self.node
|
||||
assert (node.op_type == "MatMul")
|
||||
|
||||
if self.disable_qdq_for_node_output:
|
||||
nodes_to_iterate = node.input
|
||||
else:
|
||||
nodes_to_iterate = itertools.chain(node.input, node.output)
|
||||
|
||||
for tensor_name in nodes_to_iterate:
|
||||
# only support per-channel quantization on weight
|
||||
if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()) :
|
||||
channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1)
|
||||
self.quantizer.quantize_tensor_per_channel(tensor_name, channel_axis)
|
||||
else:
|
||||
self.quantizer.quantize_tensor(tensor_name)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,4 @@ class QDQOperatorBase:
|
|||
nodes_to_iterate = itertools.chain(node.input, node.output)
|
||||
|
||||
for tensor_name in nodes_to_iterate:
|
||||
if self.quantizer.is_per_channel():
|
||||
if node.op_type in self.quantizer.op_types_support_per_channel_quantization :
|
||||
self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
|
||||
else:
|
||||
self.quantizer.quantize_tensor(tensor_name)
|
||||
else:
|
||||
self.quantizer.quantize_tensor(tensor_name)
|
||||
self.quantizer.quantize_tensor(tensor_name)
|
||||
|
|
|
|||
|
|
@ -43,10 +43,8 @@ class QDQQuantizer(ONNXQuantizer):
|
|||
self.op_types_to_exclude_output_quantization = [] if 'OpTypesToExcludeOutputQuantizatioin' not in extra_options \
|
||||
else extra_options['OpTypesToExcludeOutputQuantizatioin']
|
||||
|
||||
# In some cases, for example QDQ BERT model for TensorRT,
|
||||
# QDQ should always appear as a pair.
|
||||
# For our quantization tool, we do quantization on Dequantizelinear's input
|
||||
# to remove Quantizelinear as optimization for weight.
|
||||
# We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization.
|
||||
# In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
|
||||
# Therefore, we need to disable this optimization and add qdq pair to weight.
|
||||
self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \
|
||||
else extra_options['AddQDQPairToWeight']
|
||||
|
|
@ -57,8 +55,8 @@ class QDQQuantizer(ONNXQuantizer):
|
|||
if self.dedicated_qdq_pair:
|
||||
self.tensor_to_its_receiving_nodes = {}
|
||||
|
||||
# Channel axis when per_channel is True
|
||||
self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis']
|
||||
# Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True.
|
||||
self.qdq_op_type_per_channel_support_to_axis = {} if 'QDQOpTypePerChannelSupportToAxis' not in extra_options else extra_options['QDQOpTypePerChannelSupportToAxis']
|
||||
|
||||
def quantize_tensor(self, tensor_name):
|
||||
weight = find_by_name(tensor_name, self.model.initializer())
|
||||
|
|
|
|||
|
|
@ -200,6 +200,10 @@ def quantize_static(model_input,
|
|||
the output of ops with this specific op types.
|
||||
DedicatedQDQPair = True/False : Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
|
||||
If True, it will create identical and dedicated QDQ pair for each node.
|
||||
QDQOpTypePerChannelSupportToAxis = dictionary : Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1},
|
||||
and it's effective only when per channel quantization is supported and per_channel is True.
|
||||
If specific op type supports per channel quantization but not explicitly specified with channel axis,
|
||||
default channel axis will be used.
|
||||
'''
|
||||
|
||||
mode = QuantizationMode.QLinearOps
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from .quant_utils import QuantizationMode
|
||||
from .operators.base_operator import QuantOperatorBase
|
||||
from .operators.qdq_base_operator import QDQOperatorBase
|
||||
from .operators.matmul import MatMulInteger, QLinearMatMul
|
||||
from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul
|
||||
from .operators.attention import AttentionQuant
|
||||
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
|
||||
from .operators.gather import GatherQuant
|
||||
|
|
@ -66,6 +66,7 @@ QDQRegistry = {
|
|||
"MaxPool": QDQMaxPool,
|
||||
"AveragePool" : QDQDirect8BitOp,
|
||||
"Concat": QDQConcat,
|
||||
"MatMul": QDQMatMul,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue