diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 2b440390db..0ed11c7698 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -356,6 +356,8 @@ class ONNXQuantizer: new_list += self._handle_activation_ops(node, new_list) elif node.op_type == 'Attention': new_list += self._quantize_attention(node, new_list) + elif node.op_type == 'EmbedLayerNormalization': + new_list += self._quantize_embed_layernorm(node, new_list) else: new_list += self._handle_other_ops(node, new_list) @@ -1132,6 +1134,15 @@ class ONNXQuantizer: return nodes + def _quantize_embed_layernorm(self, node, new_nodes_list): + assert (node.op_type == "EmbedLayerNormalization") + (quantized_input_names, zero_point_names, scale_names, nodes) = \ + self._quantize_inputs(node, [2, 3, 4], new_nodes_list) + + nodes.append(node) + + return nodes + def _quantize_convolution_integer_ops(self, node, new_nodes_list): ''' Used when self.mode is QuantizationMode.IntegerOps.