mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
google/mt5 optimization and fix (#15454)
### Description <!-- Describe your changes. --> 1. enabled self-attention fusion in mt-5 decoder graph 2. fix a parity issue https://github.com/microsoft/onnxruntime/issues/15042 ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
c5b6ee1a99
commit
ef42fd09fb
2 changed files with 42 additions and 17 deletions
|
|
@ -47,6 +47,9 @@ class T5DecoderInit(torch.nn.Module):
|
|||
self.decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
self.tie_word_embeddings = (
|
||||
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -76,7 +79,8 @@ class T5DecoderInit(torch.nn.Module):
|
|||
sequence_output = decoder_outputs.last_hidden_state
|
||||
present_key_values = decoder_outputs.past_key_values
|
||||
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
if self.tie_word_embeddings:
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
|
||||
|
|
@ -91,6 +95,9 @@ class T5Decoder(torch.nn.Module):
|
|||
self.decoder = decoder
|
||||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
self.tie_word_embeddings = (
|
||||
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
|
||||
)
|
||||
|
||||
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
|
||||
past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
|
||||
|
|
@ -109,7 +116,8 @@ class T5Decoder(torch.nn.Module):
|
|||
sequence_output = decoder_outputs.last_hidden_state
|
||||
present_key_values = decoder_outputs.past_key_values
|
||||
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
if self.tie_word_embeddings:
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
|
||||
|
|
|
|||
|
|
@ -386,28 +386,39 @@ class FusionT5Attention(FusionAttention):
|
|||
if "past_key_cross" not in past_key:
|
||||
return
|
||||
else:
|
||||
k_nodes = self.model.match_parent_path(
|
||||
idx, k_nodes, _ = self.model.match_parent_paths(
|
||||
matmul_qk,
|
||||
["Transpose", "Concat", "Reshape", "MatMul"],
|
||||
[1, 0, 1, 0],
|
||||
[
|
||||
(["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
|
||||
(["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
|
||||
],
|
||||
output_name_to_node,
|
||||
)
|
||||
past_key_transpose_node = None
|
||||
present_key_transpose_nodes = None
|
||||
if k_nodes is not None:
|
||||
_, concat_k, reshape_k, _ = k_nodes
|
||||
concat_k, reshape_k = k_nodes[1], k_nodes[-2]
|
||||
key = reshape_k.input[0]
|
||||
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
|
||||
past_key = past_key_transpose_node.input[0]
|
||||
|
||||
if idx == 0:
|
||||
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
|
||||
past_key = past_key_transpose_node.input[0]
|
||||
else:
|
||||
past_key = concat_k.input[0]
|
||||
if past_key in output_name_to_node:
|
||||
return
|
||||
if "past_key_self" not in past_key:
|
||||
return
|
||||
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
|
||||
for present_key_transpose_node in present_key_transpose_nodes:
|
||||
# print("present_key_transpose_node:", present_key_transpose_node)
|
||||
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
|
||||
# print("present_key_candidate:", present_key_candidate)
|
||||
if present_key_candidate is not None:
|
||||
present_key = present_key_candidate.name
|
||||
break
|
||||
|
||||
if idx == 0:
|
||||
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
|
||||
for present_key_transpose_node in present_key_transpose_nodes:
|
||||
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
|
||||
if present_key_candidate is not None:
|
||||
present_key = present_key_candidate.name
|
||||
break
|
||||
else:
|
||||
present_key = concat_k.output[0]
|
||||
if present_key is None:
|
||||
return
|
||||
if "present_key_self" not in present_key:
|
||||
|
|
@ -583,7 +594,13 @@ class FusionSimplifiedLayerNormalization(Fusion):
|
|||
[1, 1, 1, 0, 0, 0, 0],
|
||||
)
|
||||
if sim_ln_nodes is None:
|
||||
return
|
||||
sim_ln_nodes = self.model.match_parent_path(
|
||||
node,
|
||||
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"],
|
||||
[1, 1, 1, 0, 0, 0, 0],
|
||||
)
|
||||
if sim_ln_nodes is None:
|
||||
return
|
||||
|
||||
pow_node = sim_ln_nodes[-2]
|
||||
if self.model.find_constant_input(pow_node, 2.0) != 1:
|
||||
|
|
|
|||
Loading…
Reference in a new issue