mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Add support of EmbeddingLayerNorm (#4562)
This commit is contained in:
parent
bf78e4d18b
commit
822b23ff2f
1 changed files with 11 additions and 0 deletions
|
|
@ -356,6 +356,8 @@ class ONNXQuantizer:
|
|||
new_list += self._handle_activation_ops(node, new_list)
|
||||
elif node.op_type == 'Attention':
|
||||
new_list += self._quantize_attention(node, new_list)
|
||||
elif node.op_type == 'EmbedLayerNormalization':
|
||||
new_list += self._quantize_embed_layernorm(node, new_list)
|
||||
else:
|
||||
new_list += self._handle_other_ops(node, new_list)
|
||||
|
||||
|
|
@ -1132,6 +1134,15 @@ class ONNXQuantizer:
|
|||
|
||||
return nodes
|
||||
|
||||
def _quantize_embed_layernorm(self, node, new_nodes_list):
|
||||
assert (node.op_type == "EmbedLayerNormalization")
|
||||
(quantized_input_names, zero_point_names, scale_names, nodes) = \
|
||||
self._quantize_inputs(node, [2, 3, 4], new_nodes_list)
|
||||
|
||||
nodes.append(node)
|
||||
|
||||
return nodes
|
||||
|
||||
def _quantize_convolution_integer_ops(self, node, new_nodes_list):
|
||||
'''
|
||||
Used when self.mode is QuantizationMode.IntegerOps.
|
||||
|
|
|
|||
Loading…
Reference in a new issue