From 1f304fbee77f38fa2919b4e895f03c3e7b722433 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Oct 2020 20:12:02 -0700 Subject: [PATCH] Attention with past and no unidirectional mask (#5557) * Update fusions to support shared node, and mask of all ones --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 4 - .../core/optimizer/attention_fusion.cc | 7 +- .../core/optimizer/attention_fusion_helper.h | 151 +++++++++++++++--- .../transformers/fusion_gpt_attention.py | 27 +++- .../test/optimizer/graph_transform_test.cc | 57 +++++++ .../fusion/attention_past_no_unidir.onnx | Bin 0 -> 9754 bytes .../fusion/attention_past_unidir.onnx | Bin 0 -> 9754 bytes 7 files changed, 208 insertions(+), 38 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 24a8cb18fc..0928588a59 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -98,10 +98,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, int past_sequence_length = 0; if (past != nullptr) { // past is optional - if (!is_unidirectional_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is only allowed for unidirectional"); - } - const auto& past_dims = past->Shape().GetDims(); if (past_dims.size() != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ", diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 8af79734cf..f6c4144c71 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -208,7 +208,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if ((node.GetOutputEdgesCount() >= 4 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert + if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1}, kOnnxDomain) && graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { // Get hidden size from layer norm bias tensor shape. @@ -236,13 +236,14 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, reshape_count++; } } + if (add_count == 1 && matmul_count == 3 && shape_count == node.GetOutputEdgesCount() - 4) { // BERT or DistilBert if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_int32_map, logger)) { fused_count++; modified = true; } - } else if (reshape_count == 1 && shape_count == 3) { // GPT - if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) { + } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (reshape_count + shape_count) == node.GetOutputEdgesCount()) { // GPT + if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) { fused_count++; modified = true; } diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 2694530f64..4a3ccae6c0 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "onnx/defs/shape_inference.h" +#include "onnx/defs/tensor_proto_util.h" #pragma once @@ -57,11 +59,14 @@ bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vect +----> Shape --> Gather (indices=0) --> Unsqueeze (axes=0) -----------+ | | | +----> Shape --> Gather (indices=1) --> Unsqueeze (axes=0) --------------+ + +The 3 Shape nodes are merged into one node if use_shared_node is true. */ bool MatchGemmSubgraph(Graph& graph, Node& node_after_gemm_reshape, int dst_arg_index, MatchGemmResult& result, + bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start MatchGemmSubgraph"); // GPT Attention fusion supports opset version 9 or later. @@ -96,7 +101,7 @@ bool MatchGemmSubgraph(Graph& graph, return false; } - if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, 1) || + if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, use_shared_node ? 3 : 1) || !optimizer_utils::CheckOutputEdges(graph, slice, 1) || !optimizer_utils::CheckOutputEdges(graph, squeeze, 1) || !optimizer_utils::CheckOutputEdges(graph, unsqueeze, 1) || @@ -173,14 +178,21 @@ bool MatchGemmSubgraph(Graph& graph, if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze_after_gather, 1) || !optimizer_utils::CheckOutputEdges(graph, gather, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape, 1)) { //TODO: deal with shared Shape node which has output edges > 1 + !optimizer_utils::CheckOutputEdges(graph, shape, 1) && !use_shared_node) { DEBUG_LOG("Output edge count not expected for nodes in gemm gather path"); return false; } result.node_indices.push_back(unsqueeze_after_gather.Index()); result.node_indices.push_back(gather.Index()); - result.node_indices.push_back(shape.Index()); + + if (use_shared_node) { + if (shape.Index() != shape_before_slice.Index()) { + return false; + } + } else { + result.node_indices.push_back(shape.Index()); + } if (shape.InputDefs()[0]->Name() != subgraph_input->Name()) { return false; @@ -252,9 +264,83 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde struct MatchUnidirMaskResult { const Node* div_node; // the root node (Div) of the subgraph + bool is_unidirectional; // whether the mask is unidirectional. std::vector node_indices; // id of all nodes in the subgraph for removing later. }; +// Return true when mask is unidirectionl (lower trigular) or all elements are 1. +template +bool ValidateUnidirMask(std::vector mask_data, int64_t w, bool& is_undirectional) { + // The mask data has shape 1x1xWxW + if (mask_data.size() == static_cast(w * w)) { + bool is_one = true; + is_undirectional = true; + + const T* p = mask_data.data(); + for (int i = 0; i < w; i++) { + for (int j = 0; j < w; j++) { + if (*p != static_cast(1)) { + is_one = false; + } + + if (*p != ((i >= j) ? static_cast(1) : static_cast(0))) { + is_undirectional = false; + } + + p++; + } + } + + if (is_undirectional || is_one) + return true; + } + + return false; +} + +bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidirectional, const logging::Logger& logger) { + if (!graph_utils::IsInitializer(graph, mask.Name(), true)) { + DEBUG_LOG("unidir mask is not constant"); + return false; + } + + // Check that the mask shape is 1x1xWxW + auto shape = mask.Shape(); + if (shape == nullptr || static_cast(shape->dim_size()) != 4 || !utils::HasDimValue(shape->dim(0)) || static_cast(1) != shape->dim(0).dim_value() || !utils::HasDimValue(shape->dim(1)) || static_cast(1) != shape->dim(1).dim_value() || !utils::HasDimValue(shape->dim(2)) || !utils::HasDimValue(shape->dim(3)) || shape->dim(2).dim_value() != shape->dim(3).dim_value()) { + DEBUG_LOG("unidir mask shape not expected"); + return false; + } + + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!graph.GetInitializedTensor(mask.Name(), tensor_proto) || tensor_proto == nullptr) { + return false; + } + + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now"); + return false; + } + + if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + std::vector int32_data = ONNX_NAMESPACE::ParseData(tensor_proto); + if (!ValidateUnidirMask(int32_data, shape->dim(2).dim_value(), is_unidirectional)) { + DEBUG_LOG("Mask is neither unidirectional nor all ones"); + return false; + } + } else if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector float_data = ONNX_NAMESPACE::ParseData(tensor_proto); + if (!ValidateUnidirMask(float_data, shape->dim(2).dim_value(), is_unidirectional)) { + DEBUG_LOG("Mask is neither unidirectional nor all ones"); + return false; + } + } else { + DEBUG_LOG("Expect mask data type is uint8 or float"); + return false; + } + + return true; +} + /** Match Unidirectional Mask subgraph. In the below graph, ':' is followed by variable name in code. * means the input on the left side. @@ -274,8 +360,10 @@ struct MatchUnidirMaskResult { +----> Shape --> Slice ---------> Squeeze-------+ | | :shape2 :slice2 :squeeze2 v condition +----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add] + + When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same. */ -bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, const logging::Logger& logger) { +bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start MatchUnidirMaskSubgraph"); std::vector root_path{ {0, 0, "Where", {9}, kOnnxDomain}, @@ -325,10 +413,9 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid !optimizer_utils::CheckOutputEdges(graph, mask_slice, 1) || !optimizer_utils::CheckOutputEdges(graph, unsqueeze1, 1) || !optimizer_utils::CheckOutputEdges(graph, sub, 1) || - !optimizer_utils::CheckOutputEdges(graph, squeeze1, 3) || + !optimizer_utils::CheckOutputEdges(graph, squeeze1, use_shared_node ? 2 : 3) || !optimizer_utils::CheckOutputEdges(graph, slice1, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape1, 1) || - !optimizer_utils::CheckOutputEdges(graph, mask_slice, 1)) { + !optimizer_utils::CheckOutputEdges(graph, shape1, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for nodes in path 1 of unidirectional mask"); return false; } @@ -348,6 +435,11 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid return false; } + if (!ValidateUnidirMask(graph, *(mask_slice.InputDefs()[0]), result.is_unidirectional, logger)) { + DEBUG_LOG("ValidateUnidirMask returns false for mask_slice"); + return false; + } + if (!CheckSliceParameters(graph, slice1, {1, 2, 3}, {-1, INT_MAX, 0}, logger)) { DEBUG_LOG("CheckSliceParameters returns false for slice1"); return false; @@ -364,7 +456,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } const Node& unsqueeze2 = edges[0]->GetNode(); - if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, 1)) { + if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for unsqueeze2 of unidirectional mask"); return false; } @@ -376,7 +468,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } const Node& unsqueeze3 = edges[0]->GetNode(); - if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, 1)) { + if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for unsqueeze3 of unidirectional mask"); return false; } @@ -401,7 +493,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid const Node& shape2 = edges[2]->GetNode(); if (!optimizer_utils::CheckOutputEdges(graph, squeeze2, 1) || !optimizer_utils::CheckOutputEdges(graph, slice2, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape2, 1)) { + !optimizer_utils::CheckOutputEdges(graph, shape2, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for squeeze_2/slices2/shape2 of unidirectional mask"); return false; } @@ -411,6 +503,10 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid return false; } + if (use_shared_node && (shape1.Index() != shape2.Index() || unsqueeze2.Index() != unsqueeze3.Index())) { + return false; + } + result.div_node = &div_node; result.node_indices = { where_node.Index(), @@ -423,10 +519,13 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid slice1.Index(), shape1.Index(), unsqueeze2.Index(), - unsqueeze3.Index(), squeeze2.Index(), - slice2.Index(), - shape2.Index()}; + slice2.Index()}; + + if (!use_shared_node) { + result.node_indices.push_back(unsqueeze3.Index()); + result.node_indices.push_back(shape2.Index()); + } DEBUG_LOG("Pass MatchUnidirMaskSubgraph"); return true; @@ -739,14 +838,14 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No } std::vector shape_value; if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[1]), shape_value, true) || - shape_value.size() != 1 || - shape_value[0] != 1) { + shape_value.size() != 1 || + shape_value[0] != 1) { return false; } shape_value.clear(); if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[2]), shape_value, true) || - shape_value.size() != 1 || - shape_value[0] != 1) { + shape_value.size() != 1 || + shape_value[0] != 1) { return false; } @@ -894,8 +993,8 @@ bool CheckDistilBertReshapeShape(const Graph& graph, const Node& reshape, int64_ // lazy check: record unqueeze first and then check in the mask path std::vector shape_path{ - {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; + {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(reshape, true, shape_path, edges, logger)) { DEBUG_LOG("Failed to find shape path"); @@ -1138,12 +1237,12 @@ NodeArg* GetOrCreateMaskInt32( | (0,2,1,3) (0,2,3,1) (perm=0,2,1,3) | \ / | [Past]? \ / | | - | \ p_Concat? <------|---------------------{Past_Subgraphj}? + | \ k_Concat? <------|---------------------{Past_Subgraphj}? | \ / | | | qk_MatMul | | | | [B=h] | | | | / | / - | qk_Div p_Concat? <------------------ + | qk_Div v_Concat? <------------------ | | | | {Unidir_Mask_Subgraph} | [Mask]? | | / | @@ -1178,7 +1277,7 @@ After Fusion: --------> Add TODO: replace Gemm_Subgraph by MatMul + Add */ -bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, const logging::Logger& logger) { +bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start FuseGptAttention"); const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0); if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) { @@ -1191,7 +1290,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: } MatchGemmResult gemm1_result; - if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, logger) || + if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, use_shared_node, logger) || !ValidateGemmInitializer(graph, *gemm1_result.gemm, hidden_size, false, logger)) { return false; } @@ -1233,7 +1332,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: const Node& v_split = edges[2]->GetNode(); MatchGemmResult gemm0_result; - if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, logger) || + if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, use_shared_node, logger) || !ValidateGemmInitializer(graph, *gemm0_result.gemm, hidden_size, true, logger)) { return false; } @@ -1263,7 +1362,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: } MatchUnidirMaskResult unidir_mask_result; - if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, logger)) { + if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, use_shared_node, logger)) { DEBUG_LOG("MatchUnidirMaskSubgraph returns NULL"); return false; } @@ -1365,7 +1464,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: nullptr, kMSDomain); attention_node.AddAttribute("num_heads", num_heads); - attention_node.AddAttribute("unidirectional", (int64_t)1); + attention_node.AddAttribute("unidirectional", static_cast(unidir_mask_result.is_unidirectional)); // Assign provider to this new node. attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType()); diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index 9089f8c5ad..687a78c150 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -23,16 +23,17 @@ class FusionGptAttention(Fusion): self.utils = FusionUtils(model) self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32 - def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask=''): + def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional): attention_node_name = self.model.create_node_name('GptAttention') attention_node = helper.make_node('Attention', inputs=[input, gemm.input[1], gemm.input[2], mask, past], outputs=[attention_node_name + "_output", present], name=attention_node_name) attention_node.domain = "com.microsoft" - attention_node.attribute.extend( - [helper.make_attribute("num_heads", self.num_heads), - helper.make_attribute("unidirectional", 1)]) + attention_node.attribute.extend([ + helper.make_attribute("num_heads", self.num_heads), + helper.make_attribute("unidirectional", 1 if is_unidirectional else 0) + ]) matmul_node = helper.make_node('MatMul', inputs=[attention_node_name + "_output", gemm_qkv.input[1]], @@ -115,6 +116,8 @@ class FusionGptAttention(Fusion): logger.debug("Add and LayerNormalization shall have one same input") return + is_unidirectional = True + slice_mask = None input_mask_nodes = None qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) if qk_nodes is not None: @@ -127,6 +130,7 @@ class FusionGptAttention(Fusion): logger.debug("fuse_attention: failed to match unidirectional mask path") return div_mask = mask_nodes[-1] + slice_mask = mask_nodes[3] if div_qk != div_mask: logger.debug("fuse_attention: skip since div_qk != div_mask") @@ -162,11 +166,24 @@ class FusionGptAttention(Fusion): logger.debug("fuse_attention: failed to match mask path") return div_mask = mask_nodes[-1] + slice_mask = mask_nodes[2] if div_qk != div_mask: logger.debug("fuse_attention: skip since div_qk != div_mask") return + # Validate that the mask data is either lower triangular (unidirectional) or all ones + mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0])) + if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) + and mask_data.shape[2] == mask_data.shape[3]): + logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW") + return + if np.allclose(mask_data, np.ones_like(mask_data)): + is_unidirectional = False + elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))): + logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones") + return + q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") @@ -219,7 +236,7 @@ class FusionGptAttention(Fusion): self.casted_attention_mask[input_name] = attention_mask_input_name self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0], - reshape_qkv.output[0], attention_mask_input_name) + reshape_qkv.output[0], attention_mask_input_name, is_unidirectional) # we rely on prune_graph() to clean old subgraph nodes: # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv] diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 85c23aae44..dffd88145d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1834,6 +1834,63 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) { ValidateAttention(graph); } +// Test GPT-2 Attention Fusion with past and unidirectional mask +TEST_F(GraphTransformationTests, AttentionFusionWithPastAndUnidirMaskTest) { + auto model_uri = MODEL_FOLDER "fusion/attention_past_unidir.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0); + EXPECT_EQ(op_to_count["Softmax"], 0); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + if (p_node->OpType().compare("Attention") == 0) { + EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 1); + } + } +} + +// Test Attention Fusion with past but no unidirectional mask +TEST_F(GraphTransformationTests, AttentionFusionWithPastAndNoUnidirMaskTest) { + auto model_uri = MODEL_FOLDER "fusion/attention_past_no_unidir.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0); + EXPECT_EQ(op_to_count["Softmax"], 0); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + if (p_node->OpType().compare("Attention") == 0) { + EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 0); + } + } +} + // Test GPT-2 Attention Fusion with float32 mask TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) { auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx b/onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a5ea98d2865418da6fa37906a00c50caf0b21958 GIT binary patch literal 9754 zcmb_i3wRS{)^5^UI<%!k$|YQ6!#&WDnaoTQnr~X#Ds1KAQbZqgpFC*^Hkw{YQtm1U zxWJF-FQ~ZU%F5y*h^v4X=zK%P6;MQ3b&-{ol?CyJup+GJ?yvtjlT0#c6PEw^h3A=0 z<~wK3Io~=> zAo0=GgUM+Pgqtd}Dh-v{mBx*Rd}a_+5NvL3i~52yBj#+@yJ9g*P=vX14_b3&0i5T9Eh@b3PEM1n2V^2b1Fzd67XC z)Z;A$)h*2p{wVlw&5PZphFiuILVUq!u%#ItM@=3z1q(AZ5K;YSd24?Bp7}X5Ac}J` zt0mkN2;1kj2JCYK!P^_7zHlHCTo6Eg0rXjqr63+0C#Ik1k-GW9IT3MAT$TdOtUI&S zXcw!XiB{c7s~P6BKX}Nx@!|mm&SA+LH*=kxraTC1f3MMd} zBAa5unb8CWAhLMju#^M`DkrN*gjr2jMK+8Dbp?z{z!&g1zXH^FT`1TP@Ch!MD7<4v zqXbs9;IuGx@tPJO+wuF(6^()@E$TxMJQ+iX^?(PhU?v2}M7#$?M|$svpu6a^fawt2 z$`jW*jGeCnUnW*B}Y(fXi267L$c>apaH4>UFTM9!Ln zH@5Cfu0Pb;=r7I6T3oTwV6YW4IrV{P@^cZBr@nuj3S2hrV^s{tFem|MK~xJs*UzGu z3~Dk82 zpy5pH4FIAn*mbgkrJL5!EbB-bMwn#*DV(gB0<$ce9%hXIvj(&PApta_He}byixwt< zaL@(u;_Mkm^dGf)&UM8D1tNxAYC{uu?V=K>Jp!HYk(0~ z0d#IK>T&2HItSIyO!7x3wS{~v51kXcX9mEyfI?W;PJy@N*97N)1Yi){AO^uh7=flO zh@lKC@EXjFf>CTxr-*YP@_513N)$j??KBp8Dc)*jIEs@jik5C#tFh2K>3dCmEjsRU zS&9-XjE{w`&$Qqq+z~LL&>e75LkIlV+~M@@2mlLC(V=(8nZX^LClN9M;;7?-^|Y-X zym45w;f7JYg@$#Co>Xt4f)dWBg2`lNP}+{vwIZ6s5_7p!{KNTm;rzOsiGTp1&6^=~IBUIvm9?M1fb| zkAb?uhns=C00j^|)?D?LvEbSXqpZiR5e^uO5e`@|;4xkhj2+-ke8GS&OG|zby&c#X z4woKn*~|r?YN&Iiu4N9$yVj4e?#J6Pgof__Fy!MdPMkyBt3y0M?#fQ5A}v!YG@c`ynqOZ-CuQvpJZ zsiYOvfNUg5t2=$(Mlu>=gE6NYw45Ft@H)BVY+|1}u{wE23h+9C6kHM8RN%!j=dpT$ z;3R;TPp+#P@cxhjya0<3J*uiR1emSCq6^4T92U_rH)4ky&?n&u%+!B!OzkMp=CMe4 zD3augCo^*8? zr{|gxKz^R(bw*$cW4)QE26Z>08i6X{oouR+%q$V$2c+>5qi~RG-U+7SZ3#7;zApjF zV>5sRnCDc>s0%DO4&|Ifzp}spjyjW24sDzP%JF3!#>~NLp8#hcv~|Y0VpZ4!@B$rF zOx7FVK>7%v5Fo>-fob8~$wGj(PIka{AaXF8k;t`%TkhzztI0&?F1}JaTd#Xt!4W(DU(? z0#hykhVjTecp(WnT&8_8&a;4)T!16Glmd;(wEQNGU>hPkiZ4eT093~E$Grldy2isE zE29xCe_X-AFRUOWN4yJ6Y30uYgr-}Tu>4~)fHjBAnaAw_4_rknf55pUv4Q~{1+>wK z73?9dQ2rvu3|py`xyrngSIZxX9_XMVWF!$mhOIo-36Fjw<>Qg^Y0Mk;ipe_pJ(fT4 zs$rd}P*bQCwlT9BI@SJ=jB=@k4+Je<_`r34vceY;*pY$pz`%3?0n)^yL&D7C!okCl zi&ojN$I(?b9=hvI- zBc0@dP9`hc1@)75dyZjvAZEdlx<0uwFM_ek2CM-cRMhpRE5p)+mF-p|SJR>|1eXp0 zw?SnK4kt4reZ{81^$b1G71jkfN{d&w3)uQku|bER;H`MUG63?MtZuh%wE|cNFb6SY zHv+I5m0fafXAIUM0Wnr3q`)xCXFlUWqNVOnd3OrY0{BYH023#isx2g!eCm#i&kSSW zI7q_cRIt|RT$>%DTV=$uhuZ^)d1oU zJ%Aox54*wffl!-|wVJg1u|*9nP4=c>L%1aZy9jehv?bioSawG+S{9hs+7gc12Ai@= z@~eOYIcoc4b^53Aaa&K50ff|J;#yOhr>?Y5J?Wkbj#^V;+>ULODF=_S`cbYns>UVVy!7V{@7MrV!-JUT(vn_mHqeY zq}*vI0j4HL(@j_x*>tiZS8FQiD(g%u3;rkG+iWtxfB08ov{f4ZpZ}`5bc}j*sViKY z&s11qusX8OUy?fR&D;Cn!Q+bf$yDX(w~Wf2Ew9m0TX!jY)^DKabG-Dfm;2JlGQa$i zizro@cybp57Y(Ky)kcS*YQqI(@1K?{(+Hlf9(j%~5IjGi#-rY{M!(*;6K6`qvU#|Hu(3@X){H?`FR% zBZA`m)}yehMYd!E`79X9j)Q|Dz$63dA~My&@(4Z%1zhLBfJ0Zm(GnY zp?A*rDu<5LOEWK3DzB}dB|ozJFVeUZ_efvt-KZ#XcWJ)$A!)_&{iJc;Au@gNcv(5Q zRI0lA17TYQG>O4l!kp@N zgQRrTWqC@`0%gRqT4nktAJg9Fiby&8SLOBI1Et^9-bS{*+F$vLc3h`VU_fE z<+VK-z3#e7+T1a($N$*c{=%UX-jl4o2d|xQU9-C5f6B;)O}S)K$=dQGeV2ICT1|mn&FZ)P@}$Pwrc2+hu9KFm+E|g+ zY9~;)g-j(ly#*7jq;TOb+V(_uMS3Hj9(Z7>Jnd|&JgK-=>3R6yBsAOK-t_+Y^0tQu z(!BdCq~$mN7k%q`nGSz-JsoHX$yHaynr$x_A)jb)vGL;B4l4l9V@} zrC;4^D}QsThh1^V;1*O87z2IUW- zm2~oV-+A|Ms+J!fN9dy8O>Cce;~$7D?x4H()X{Fkz9A*yRtfr>k;U+c1a>%5c zU!-T+bLC9>U0P_T(NvmE^B(N_MEn7y^D!;j~w7l#yk9Y0$P=`0&M&&8{ z=ahHmETgpZ z%eUH(P95MaJH*fv{ya%}a#8!;YsQpsvu}|Pe|wSK6_MH}9v@nM`WH90pPKUz@@UpI z?NwL)3flIW(bQ9QwZr)QT_@?#qk|}YJf9pcD3H#S4iUEWRY7G?Wa zRno8KP4n*9w~dZiXC(I*uO`;F2PhS#xyr8|zC~I1T%NRY^-l8aludNYBR|rup;;vQ z#U?Vpeh3{`Go5@q;(PMgKhs6o?~~^(*D6ok=Jmc)Gp_x;TXV>}`Srxo zx{&l^+vE@S+)j$Wd5IomD#+bs9(il`&%9GVBxKF}JbHE!qnzsYGt$H7lx)TK(Mbow z@veC`W|B=)6L9R&)l1$3 zy8+F%7CD@aBqN`iHDk2yXDa4vnE|GJb6=*=%+&j%4UInaH!lU|{77K7FBE8=F(&or zH2+z_)%n{n^H656$pDI}HoBXOqAgK>$fsEvtNpmi>|pGmU}KCXV<#niQfw|%e*_iR zL#cX+N$$nXEHehwM`cj&Y**FqhQ|@CIR`!%s*EMws(Lfmq-;*@%?*un8=&p4W%{KK zG=*jb{L!{B{8lI&0fAzM@dxerD=vFesKi*+lPS?Y!`?sG9BjTFr1~>`w5M>0Feez8 PYuAnTU{bfnC5Hb6TrTHv literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx b/onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx new file mode 100644 index 0000000000000000000000000000000000000000..22a0d1284d63f13f7f7d05f1155bb212f7086dbe GIT binary patch literal 9754 zcmb_i3z!qt_1|P)*@0b_$nppeX?O=VWF|9{gw35@c2Vfc!vdlo+Am*r32rp|Aj$Gp zLBIlkME`<{6)Tlu5yUFs19t9!SOGQ=`GASW9y(1mi_} zQSz&!E0fm}iZ)i|RGF%BtIQirg-m~@DBRT28V`hL#H_ij&t)r|(hzJ31>E+0^@edW z*o+r2hV^kai)Ks~Q`Dhg3g>d$I^Wb3o7EZ$%?DFBdr|UR!uU|E3XCsgTuh!@#11)igKN2jgJBy&!R$?ru5L8T<>!!_7@-I7;#;DcG3Fp_pnv%i9aJd*P?&2cn@^czy`w1(0Wbwj#|rPRu;f!wmCb)AFz=1;>8C7oXb`)dd7?Z=dtI)Cxg8lOgqOxggB1VdVPuB-q3-0 z#PNd7l8CL~(t|q^I|aOeBODKo=)vVO`;TgPv|8kh(G~fu(M_qq;PMb$O$G0}f(i_$ z$fl`qXH|g?h%8<>OeMjE!l^70VOBF`kpq1}SpmHg@Buu|uK+ndB@(U=1q2UF6y7l- za0F(x;I=VSw44?o+S+~ljK+aDS~R*K__Df?r~w~p!AJ-Y39SZ1S7z%6qPb|ZfZ-6l zsni$QR1!XG!Vz9(n6UsFoGxcD9&d8i2h@-BVi^S4Osg8y>a!J&4K+3fM9!Xv zH;yh$elXI~5G>2dSyZ{fWO9@+d3B+9>U%L$pgw=J3S18TYjpz0&?y0DK~xjKP|u>6 za%v(D5Uhm|hUh#eYt6qigdc^ogOS#dlZEKDM8ffah|$BlBzOSinH{A}ehgcy^6(!4 zXb6*d0)QwBdflvG>!cSn%eqpE5oTFH3O6gJ!7R&WhFLSftO+%MO90I%4bgS;qK!!+ z98`gpoZU1;|55(~Q8P1?(ftIkeVE!lHWh2H0dNpO5u`Z)>CRJ=i+~rhF8R%{2AFXb zK;s6W9ETR7aghDYgkXF^Yb3z(P&tWvrZ4mhI0*CFE%3I&G2z)T0x$?}5QCr*dZ1{F z5-7t8ybd$Np%)9(E#e%AJYI0M5(OBnRvHVnq?KAZ_TnUqqOFr&YAn=F=2}x9i-vnV zw&LUp6JVk0vn@DDa|8@1R0mvC*8%@^b2y_p0>FY>bQ#TYXE6uoNrXtiaFp@DdfHkC z*0}7saKot9LczL4U%IuBK}qA&!DJ%SKVwB2N)gpzig`RL{^9(3aDF}Rq(gwv+WcY= zJ#KX%fS)ddCVS9LuPY^!gug-1>qUpWP+c`hPgv1w&%--M5o3-w_bkF(9fHHqMS)kJ zkB)l5hMR%B00j_z_I&l0vEkYYy{ylx6AtK05Dr){pcyX+#x5{NTQH!?GLj!eZxv0Y#$fDD)QR@7u2eCyo)lD| z3KorG(^fIon<>sJMyL{4zHHjk%4S?NRPv_M0thxppqOAQz+B_FM5?L1xV$0eaID8x zfHlqW_FVW>>)8v{2pt$CsAr@UtOxQGIRzG>7wg#zSU9D5X1p;t7kXi3(Y_=x6(Gc% zPFhh6$VQ5^dNapu7^5RLD06#3%Iz}%ubWHFCf2DNvy*qF0k0cK!4Z9%eh#+R3K%b1213#4+F^cY`XeEFp(8*CjxC zVg`@^^PH*~Wq}22P|mrGD+_dBuR8_hP{vuH93RG^&m64wNpKE8S!bOqR)swPFOWgO zRK5WYWcC0O0V0eN7#7Z(N(3nDR0V7UA_u)$iCjyx`HptGnruW)9coxTP^}NPbx7nm z*mC_OOb*@Rp!y(+@O-y{Q32wl+=As2b}TpyR(#$qCSV9QJISpLkDQzb+O4NQ)V#J* zV8{i)Fdmr)FQg!c%eGI(c@|KU2XMp?QlK%}mfw^TEJI{R+H%AJKxIvT+$-?uYdq|+ zvMRy!#}ypx!VE%k#Ct%Mp8h;QXr^Te(?2l-m~+USdE5^0z*Y402b@b0E9k&pKpUM{ z!5-oY=`W(su$4-itL!^@HT{w3feZ>lM3Nq4*~(*{@Mt$uJ{~Eb&b(o-n97siWBLQH z8r+@=HAY%s8#A-MUG5LeIxao&fuL!L4_x;rOMDT59T^x849wsVAdTi75@sG34jztN z^vs4mjv=%0(1k}n&jX)NCLPeOtV!nsjRKG$W3;^4!AgBTUc9jO@?M`oNtSa6jI0;R zQE;Urq9-FqQCAPhY#5Cfm2Y8Q2*9_JdEw0hi=jetHd=^~k>ysz@BACc*z#5Q2K|^hNvMf!Q*OykKr|-D<>@Ws~ z!)Um`Nm5ZuG!z5g9{`@OVHpd+lo5^RBo;G!`$$Hb0&F{c>ae9#$E0BsBC~aTa9$`n z5zZ}xk?{N=?ml1#Ul?kMg(J;PWzW7o!~=&c3@FCw&ZL;T3dNi=O)=o{)hdR)8bBPP z2hii|VK+E76lo2xc8h*Ls<^(n(b*WTk2c3(7hx@pH%IFm%I^rr%R_Tpnxk>Y084IZ zVKs0dS8dOnc7Hk_cXYFuV32xDTx-eD>dDyZ%hW2kYAu~LJ&xg)JX8+nDz*LF>CQ20 z4M}y{VA!hs7lT>;Dh4Or6Zl$-t~cehAkkO@8ttD4$c9fKwaC!t!ZC*}QpZ$hW z-i(s~Qs_sxR>d~bkaUB6m=Tej1m3#h@wD;bEJr5o{u85yZR-S&#tlZiB8Xdl6 zr?Pw9dU`(3Pw#rU7mY0q$}f3{Qk{(_cTjNAWXV%?>@`)dzo6{-(=ui1goAYZz#WS4 zVn~{EcqH`)j!GY`<0RAmDkb`Tg%sG-ZBM^hugKGw9mI69yELRJE-iazjdYV^aOEev z%cV>IT1@L6IU3;&TgI#=rw_MDAMIRA$8f!r+BIAK zUt9OmGbb&|P1nyQyZ#=O&W$Lgch2)GhmO=qGcHvrudSOYKeFpD(&!WSNMG#PpeS+| zX`cNdY5DQ}q+#wMGHt+ESvk2xs=oRIVMiqtiOF8doarUK^3T4?t%JKO3vv(9tn=1> zA#%iE+VBxokM-Fg{o?aVWz4GwN!iNF@}%PV%Fv~?%Ct{DrajITlM42)%IiJ)Nx!ST zjcj?fkNoASVPxaDHo5x4YU%IFYrC^r-Eoz4coSBS{jsI(g+nL&CxJI;(;=^}qy20Vx%#SDv*QIb#1j=R zGG9EqmyG4!AnWhSrOTJDAQeq#=~wqUD&AbVzM?^7l=UAxEe~Hbnu@>dBXvD-fINBM zDEdc#k$=TOf&TpMwPf!?lk$hi3Oez-@BI5W*2oW!CUoKN#<$J5@ef26x6@s_r_fG= zzagd37P9i|^`varSo(GGDoW4QkePR^_V*TUAjSK7(v}~tquZX}*j6|A--M1RmMZ=^ zRJvi+VEN+xW_tJX!}Ri$7E)5ueYfYnr|9~Ucj)W=x6s1huW1{?M?@&TJ%$8F15sS>Tr+!Eyu|Cr0>k7%a z!TbD^%gTsvM2s%py^Rd$X*66nnm$~4fIKk$19?yA9Id`;rtH{^%MT*YBGc#Pzm{!7 zYn9(zSW;2`na{uGddS0@Zlm&~ef0R^De`xNS61v?w~cnbAwX&_uWQ@;;M!eB2R%w_ z`W+*;zUvrOU-NxM*~-K6UvkXSjxXP8J36_qzx)tGPXr4j<;jI@cds5n;-d+Zi&n!@h>)#d36Kn=rPmC$3wp-k6n(^LyPQG&iyl8nEO6? z-gd3>#BF~6J7Y$-y?1LKdAG2R*jg5l-fXMGGL>@`r@1 zo>xH6E@YHboqk5T2HcXPb^G`3j;=`g z;O7jr7LAMK=uNjDtMKte-$-;mFVSI?9I{SIUYDsQeM>pXFdik z*QhKO-!dD<>5a%slTRqf*iP>?iP(+WpO z85+S?-M)Jq`baK=(Y#=!oT?i5A&cROT2lr;;5D3KoTkK$=IBTBx@LH5mo(?cWFu15 z_TIod$4nzZyEfZNZ{VF-$Ek&vVur@v7IPwb#M)9lSV@JD6ZX7s@?Na0hMl8Z3m4lV zC*ZK&@I-wk*$g=W#~vN6Ybmt$Vme!yx?sG% zA)x-|rN~+s3(X2dLQT_0rvIGgKMQ!;zYVhvVg^`DAegSBi?ul391lhUx~5V3kDIJ6 z#t8xr#%wXS6T+4fYiIRGP?{W$s*{-1UfjwuBSCz47V-9aRryYM9Kl-i;ESo+T-vF+ z2Xjr@=G56#-!P{h%KloWcX~%-WM(KBZ;irlg`zPSP{J_5uoHj9