google/mt5 optimization and fix (#15454)

### Description
<!-- Describe your changes. -->
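1. Enable self-attention fusion in the mT5 decoder graph: the T5 attention fusion now also matches a key path of `Transpose -> Concat -> Transpose -> Reshape -> MatMul`, and the simplified layer-normalization fusion falls back to a pattern whose root input comes from `Gather`.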
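2. Fix a parity issue (https://github.com/microsoft/onnxruntime/issues/15042): the decoder output is now scaled by `d_model**-0.5` only when `tie_word_embeddings` is enabled in the model config (treated as `True` when the config has no such attribute).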


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-04-11 00:09:11 -07:00 committed by GitHub
parent c5b6ee1a99
commit ef42fd09fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 17 deletions

View file

@ -47,6 +47,9 @@ class T5DecoderInit(torch.nn.Module):
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(
self,
@ -76,7 +79,8 @@ class T5DecoderInit(torch.nn.Module):
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
sequence_output = sequence_output * (self.config.d_model**-0.5)
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
@ -91,6 +95,9 @@ class T5Decoder(torch.nn.Module):
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
@ -109,7 +116,8 @@ class T5Decoder(torch.nn.Module):
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
sequence_output = sequence_output * (self.config.d_model**-0.5)
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)

View file

@ -386,28 +386,39 @@ class FusionT5Attention(FusionAttention):
if "past_key_cross" not in past_key:
return
else:
k_nodes = self.model.match_parent_path(
idx, k_nodes, _ = self.model.match_parent_paths(
matmul_qk,
["Transpose", "Concat", "Reshape", "MatMul"],
[1, 0, 1, 0],
[
(["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
(["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
],
output_name_to_node,
)
past_key_transpose_node = None
present_key_transpose_nodes = None
if k_nodes is not None:
_, concat_k, reshape_k, _ = k_nodes
concat_k, reshape_k = k_nodes[1], k_nodes[-2]
key = reshape_k.input[0]
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
past_key = past_key_transpose_node.input[0]
if idx == 0:
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
past_key = past_key_transpose_node.input[0]
else:
past_key = concat_k.input[0]
if past_key in output_name_to_node:
return
if "past_key_self" not in past_key:
return
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
# print("present_key_transpose_node:", present_key_transpose_node)
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
# print("present_key_candidate:", present_key_candidate)
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if idx == 0:
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
else:
present_key = concat_k.output[0]
if present_key is None:
return
if "present_key_self" not in present_key:
@ -583,7 +594,13 @@ class FusionSimplifiedLayerNormalization(Fusion):
[1, 1, 1, 0, 0, 0, 0],
)
if sim_ln_nodes is None:
return
sim_ln_nodes = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"],
[1, 1, 1, 0, 0, 0, 0],
)
if sim_ln_nodes is None:
return
pow_node = sim_ln_nodes[-2]
if self.model.find_constant_input(pow_node, 2.0) != 1: