From ef42fd09fb67af8fe321d01238a90fd1f5ac42e1 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 11 Apr 2023 00:09:11 -0700 Subject: [PATCH] google/mt5 optimization and fix (#15454) ### Description 1. enabled self-attention fusion in mt-5 decoder graph 2. fix a parity issue https://github.com/microsoft/onnxruntime/issues/15042 ### Motivation and Context --------- Co-authored-by: Ubuntu --- .../transformers/models/t5/t5_decoder.py | 12 ++++- .../tools/transformers/onnx_model_t5.py | 47 +++++++++++++------ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py index ccfe690f0d..e7a9f7c4a8 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py @@ -47,6 +47,9 @@ class T5DecoderInit(torch.nn.Module): self.decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id ) + self.tie_word_embeddings = ( + self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True + ) def forward( self, @@ -76,7 +79,8 @@ class T5DecoderInit(torch.nn.Module): sequence_output = decoder_outputs.last_hidden_state present_key_values = decoder_outputs.past_key_values - sequence_output = sequence_output * (self.config.d_model**-0.5) + if self.tie_word_embeddings: + sequence_output = sequence_output * (self.config.d_model**-0.5) lm_logits = self.lm_head(sequence_output) past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values) @@ -91,6 +95,9 @@ class T5Decoder(torch.nn.Module): self.decoder = decoder self.lm_head = lm_head self.config = config + self.tie_word_embeddings = ( + self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True + ) def forward(self, decoder_input_ids, encoder_attention_mask, *past): past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers) @@ -109,7 +116,8 @@ class T5Decoder(torch.nn.Module): sequence_output = decoder_outputs.last_hidden_state present_key_values = decoder_outputs.past_key_values - sequence_output = sequence_output * (self.config.d_model**-0.5) + if self.tie_word_embeddings: + sequence_output = sequence_output * (self.config.d_model**-0.5) lm_logits = self.lm_head(sequence_output) present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values) diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index a0819612b7..5be31dbb26 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -386,28 +386,39 @@ class FusionT5Attention(FusionAttention): if "past_key_cross" not in past_key: return else: - k_nodes = self.model.match_parent_path( + idx, k_nodes, _ = self.model.match_parent_paths( matmul_qk, - ["Transpose", "Concat", "Reshape", "MatMul"], - [1, 0, 1, 0], + [ + (["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]), + (["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]), + ], + output_name_to_node, ) + past_key_transpose_node = None + present_key_transpose_nodes = None if k_nodes is not None: - _, concat_k, reshape_k, _ = k_nodes + concat_k, reshape_k = k_nodes[1], k_nodes[-2] key = reshape_k.input[0] - past_key_transpose_node = output_name_to_node[concat_k.input[0]] - past_key = past_key_transpose_node.input[0] + + if idx == 0: + past_key_transpose_node = output_name_to_node[concat_k.input[0]] + past_key = past_key_transpose_node.input[0] + else: + past_key = concat_k.input[0] if past_key in output_name_to_node: return if "past_key_self" not in past_key: return - present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]] - for present_key_transpose_node in present_key_transpose_nodes: - # print("present_key_transpose_node:", present_key_transpose_node) - present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0]) - # print("present_key_candidate:", present_key_candidate) - if present_key_candidate is not None: - present_key = present_key_candidate.name - break + + if idx == 0: + present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]] + for present_key_transpose_node in present_key_transpose_nodes: + present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0]) + if present_key_candidate is not None: + present_key = present_key_candidate.name + break + else: + present_key = concat_k.output[0] if present_key is None: return if "present_key_self" not in present_key: @@ -583,7 +594,13 @@ class FusionSimplifiedLayerNormalization(Fusion): [1, 1, 1, 0, 0, 0, 0], ) if sim_ln_nodes is None: - return + sim_ln_nodes = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], + [1, 1, 1, 0, 0, 0, 0], + ) + if sim_ln_nodes is None: + return pow_node = sim_ln_nodes[-2] if self.model.find_constant_input(pow_node, 2.0) != 1: