diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 24a8cb18fc..0928588a59 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -98,10 +98,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, int past_sequence_length = 0; if (past != nullptr) { // past is optional - if (!is_unidirectional_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is only allowed for unidirectional"); - } - const auto& past_dims = past->Shape().GetDims(); if (past_dims.size() != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ", diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 8af79734cf..f6c4144c71 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -208,7 +208,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if ((node.GetOutputEdgesCount() >= 4 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert + if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1}, kOnnxDomain) && graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { // Get hidden size from layer norm bias tensor shape. 
@@ -236,13 +236,14 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, reshape_count++; } } + if (add_count == 1 && matmul_count == 3 && shape_count == node.GetOutputEdgesCount() - 4) { // BERT or DistilBert if (AttentionFusion::FuseSubGraph(node, *add_node, graph, hidden_size, mask_int32_map, logger)) { fused_count++; modified = true; } - } else if (reshape_count == 1 && shape_count == 3) { // GPT - if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, logger)) { + } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (reshape_count + shape_count) == node.GetOutputEdgesCount()) { // GPT + if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) { fused_count++; modified = true; } diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 2694530f64..4a3ccae6c0 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "onnx/defs/shape_inference.h" +#include "onnx/defs/tensor_proto_util.h" #pragma once @@ -57,11 +59,14 @@ bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vect +----> Shape --> Gather (indices=0) --> Unsqueeze (axes=0) -----------+ | | | +----> Shape --> Gather (indices=1) --> Unsqueeze (axes=0) --------------+ + +The 3 Shape nodes are merged into one node if use_shared_node is true. */ bool MatchGemmSubgraph(Graph& graph, Node& node_after_gemm_reshape, int dst_arg_index, MatchGemmResult& result, + bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start MatchGemmSubgraph"); // GPT Attention fusion supports opset version 9 or later. 
@@ -96,7 +101,7 @@ bool MatchGemmSubgraph(Graph& graph, return false; } - if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, 1) || + if (!optimizer_utils::CheckOutputEdges(graph, shape_before_slice, use_shared_node ? 3 : 1) || !optimizer_utils::CheckOutputEdges(graph, slice, 1) || !optimizer_utils::CheckOutputEdges(graph, squeeze, 1) || !optimizer_utils::CheckOutputEdges(graph, unsqueeze, 1) || @@ -173,14 +178,21 @@ bool MatchGemmSubgraph(Graph& graph, if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze_after_gather, 1) || !optimizer_utils::CheckOutputEdges(graph, gather, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape, 1)) { //TODO: deal with shared Shape node which has output edges > 1 + !optimizer_utils::CheckOutputEdges(graph, shape, 1) && !use_shared_node) { DEBUG_LOG("Output edge count not expected for nodes in gemm gather path"); return false; } result.node_indices.push_back(unsqueeze_after_gather.Index()); result.node_indices.push_back(gather.Index()); - result.node_indices.push_back(shape.Index()); + + if (use_shared_node) { + if (shape.Index() != shape_before_slice.Index()) { + return false; + } + } else { + result.node_indices.push_back(shape.Index()); + } if (shape.InputDefs()[0]->Name() != subgraph_input->Name()) { return false; @@ -252,9 +264,83 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde struct MatchUnidirMaskResult { const Node* div_node; // the root node (Div) of the subgraph + bool is_unidirectional; // whether the mask is unidirectional. std::vector node_indices; // id of all nodes in the subgraph for removing later. }; +// Return true when mask is unidirectional (lower triangular) or all elements are 1.
+template +bool ValidateUnidirMask(std::vector mask_data, int64_t w, bool& is_undirectional) { + // The mask data has shape 1x1xWxW + if (mask_data.size() == static_cast(w * w)) { + bool is_one = true; + is_undirectional = true; + + const T* p = mask_data.data(); + for (int i = 0; i < w; i++) { + for (int j = 0; j < w; j++) { + if (*p != static_cast(1)) { + is_one = false; + } + + if (*p != ((i >= j) ? static_cast(1) : static_cast(0))) { + is_undirectional = false; + } + + p++; + } + } + + if (is_undirectional || is_one) + return true; + } + + return false; +} + +bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidirectional, const logging::Logger& logger) { + if (!graph_utils::IsInitializer(graph, mask.Name(), true)) { + DEBUG_LOG("unidir mask is not constant"); + return false; + } + + // Check that the mask shape is 1x1xWxW + auto shape = mask.Shape(); + if (shape == nullptr || static_cast(shape->dim_size()) != 4 || !utils::HasDimValue(shape->dim(0)) || static_cast(1) != shape->dim(0).dim_value() || !utils::HasDimValue(shape->dim(1)) || static_cast(1) != shape->dim(1).dim_value() || !utils::HasDimValue(shape->dim(2)) || !utils::HasDimValue(shape->dim(3)) || shape->dim(2).dim_value() != shape->dim(3).dim_value()) { + DEBUG_LOG("unidir mask shape not expected"); + return false; + } + + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!graph.GetInitializedTensor(mask.Name(), tensor_proto) || tensor_proto == nullptr) { + return false; + } + + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now"); + return false; + } + + if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + std::vector int32_data = ONNX_NAMESPACE::ParseData(tensor_proto); + if (!ValidateUnidirMask(int32_data, shape->dim(2).dim_value(), is_unidirectional)) { + DEBUG_LOG("Mask is neither unidirectional nor 
all ones"); + return false; + } + } else if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector float_data = ONNX_NAMESPACE::ParseData(tensor_proto); + if (!ValidateUnidirMask(float_data, shape->dim(2).dim_value(), is_unidirectional)) { + DEBUG_LOG("Mask is neither unidirectional nor all ones"); + return false; + } + } else { + DEBUG_LOG("Expect mask data type is uint8 or float"); + return false; + } + + return true; +} + /** Match Unidirectional Mask subgraph. In the below graph, ':' is followed by variable name in code. * means the input on the left side. @@ -274,8 +360,10 @@ struct MatchUnidirMaskResult { +----> Shape --> Slice ---------> Squeeze-------+ | | :shape2 :slice2 :squeeze2 v condition +----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add] + + When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same. */ -bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, const logging::Logger& logger) { +bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start MatchUnidirMaskSubgraph"); std::vector root_path{ {0, 0, "Where", {9}, kOnnxDomain}, @@ -325,10 +413,9 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid !optimizer_utils::CheckOutputEdges(graph, mask_slice, 1) || !optimizer_utils::CheckOutputEdges(graph, unsqueeze1, 1) || !optimizer_utils::CheckOutputEdges(graph, sub, 1) || - !optimizer_utils::CheckOutputEdges(graph, squeeze1, 3) || + !optimizer_utils::CheckOutputEdges(graph, squeeze1, use_shared_node ? 
2 : 3) || !optimizer_utils::CheckOutputEdges(graph, slice1, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape1, 1) || - !optimizer_utils::CheckOutputEdges(graph, mask_slice, 1)) { + !optimizer_utils::CheckOutputEdges(graph, shape1, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for nodes in path 1 of unidirectional mask"); return false; } @@ -348,6 +435,11 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid return false; } + if (!ValidateUnidirMask(graph, *(mask_slice.InputDefs()[0]), result.is_unidirectional, logger)) { + DEBUG_LOG("ValidateUnidirMask returns false for mask_slice"); + return false; + } + if (!CheckSliceParameters(graph, slice1, {1, 2, 3}, {-1, INT_MAX, 0}, logger)) { DEBUG_LOG("CheckSliceParameters returns false for slice1"); return false; @@ -364,7 +456,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } const Node& unsqueeze2 = edges[0]->GetNode(); - if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, 1)) { + if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze2, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for unsqueeze2 of unidirectional mask"); return false; } @@ -376,7 +468,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } const Node& unsqueeze3 = edges[0]->GetNode(); - if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, 1)) { + if (!optimizer_utils::CheckOutputEdges(graph, unsqueeze3, use_shared_node ? 
2 : 1)) { DEBUG_LOG("Output edge count not expected for unsqueeze3 of unidirectional mask"); return false; } @@ -401,7 +493,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid const Node& shape2 = edges[2]->GetNode(); if (!optimizer_utils::CheckOutputEdges(graph, squeeze2, 1) || !optimizer_utils::CheckOutputEdges(graph, slice2, 1) || - !optimizer_utils::CheckOutputEdges(graph, shape2, 1)) { + !optimizer_utils::CheckOutputEdges(graph, shape2, use_shared_node ? 2 : 1)) { DEBUG_LOG("Output edge count not expected for squeeze_2/slices2/shape2 of unidirectional mask"); return false; } @@ -411,6 +503,10 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid return false; } + if (use_shared_node && (shape1.Index() != shape2.Index() || unsqueeze2.Index() != unsqueeze3.Index())) { + return false; + } + result.div_node = &div_node; result.node_indices = { where_node.Index(), @@ -423,10 +519,13 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid slice1.Index(), shape1.Index(), unsqueeze2.Index(), - unsqueeze3.Index(), squeeze2.Index(), - slice2.Index(), - shape2.Index()}; + slice2.Index()}; + + if (!use_shared_node) { + result.node_indices.push_back(unsqueeze3.Index()); + result.node_indices.push_back(shape2.Index()); + } DEBUG_LOG("Pass MatchUnidirMaskSubgraph"); return true; @@ -739,14 +838,14 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No } std::vector shape_value; if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[1]), shape_value, true) || - shape_value.size() != 1 || - shape_value[0] != 1) { + shape_value.size() != 1 || + shape_value[0] != 1) { return false; } shape_value.clear(); if (!optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[2]), shape_value, true) || - shape_value.size() != 1 || - shape_value[0] != 1) { + shape_value.size() != 1 || + shape_value[0] != 1) { return false; } @@ 
-894,8 +993,8 @@ bool CheckDistilBertReshapeShape(const Graph& graph, const Node& reshape, int64_ // lazy check: record unqueeze first and then check in the mask path std::vector shape_path{ - {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; + {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(reshape, true, shape_path, edges, logger)) { DEBUG_LOG("Failed to find shape path"); @@ -1138,12 +1237,12 @@ NodeArg* GetOrCreateMaskInt32( | (0,2,1,3) (0,2,3,1) (perm=0,2,1,3) | \ / | [Past]? \ / | | - | \ p_Concat? <------|---------------------{Past_Subgraphj}? + | \ k_Concat? <------|---------------------{Past_Subgraphj}? | \ / | | | qk_MatMul | | | | [B=h] | | | | / | / - | qk_Div p_Concat? <------------------ + | qk_Div v_Concat? <------------------ | | | | {Unidir_Mask_Subgraph} | [Mask]? | | / | @@ -1178,7 +1277,7 @@ After Fusion: --------> Add TODO: replace Gemm_Subgraph by MatMul + Add */ -bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, const logging::Logger& logger) { +bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start FuseGptAttention"); const Node* parent_node = graph_utils::GetInputNode(layer_norm, 0); if (nullptr == parent_node || !graph_utils::IsSupportedOptypeVersionAndDomain(*parent_node, "Add", {7, 13}, kOnnxDomain)) { @@ -1191,7 +1290,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: } MatchGemmResult gemm1_result; - if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, logger) || + if (!MatchGemmSubgraph(graph, *graph.GetNode(add_after_gemm->Index()), 1, gemm1_result, use_shared_node, logger) || !ValidateGemmInitializer(graph, *gemm1_result.gemm, hidden_size, false, logger)) { return 
false; } @@ -1233,7 +1332,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: const Node& v_split = edges[2]->GetNode(); MatchGemmResult gemm0_result; - if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, logger) || + if (!MatchGemmSubgraph(graph, *graph.GetNode(v_split.Index()), 0, gemm0_result, use_shared_node, logger) || !ValidateGemmInitializer(graph, *gemm0_result.gemm, hidden_size, true, logger)) { return false; } @@ -1263,7 +1362,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: } MatchUnidirMaskResult unidir_mask_result; - if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, logger)) { + if (!MatchUnidirMaskSubgraph(graph, *(mask_nodes.has_input_mask ? mask_nodes.add : mask_nodes.softmax), unidir_mask_result, use_shared_node, logger)) { DEBUG_LOG("MatchUnidirMaskSubgraph returns NULL"); return false; } @@ -1365,7 +1464,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: nullptr, kMSDomain); attention_node.AddAttribute("num_heads", num_heads); - attention_node.AddAttribute("unidirectional", (int64_t)1); + attention_node.AddAttribute("unidirectional", static_cast(unidir_mask_result.is_unidirectional)); // Assign provider to this new node. 
attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType()); diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index 9089f8c5ad..687a78c150 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -23,16 +23,17 @@ class FusionGptAttention(Fusion): self.utils = FusionUtils(model) self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32 - def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask=''): + def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional): attention_node_name = self.model.create_node_name('GptAttention') attention_node = helper.make_node('Attention', inputs=[input, gemm.input[1], gemm.input[2], mask, past], outputs=[attention_node_name + "_output", present], name=attention_node_name) attention_node.domain = "com.microsoft" - attention_node.attribute.extend( - [helper.make_attribute("num_heads", self.num_heads), - helper.make_attribute("unidirectional", 1)]) + attention_node.attribute.extend([ + helper.make_attribute("num_heads", self.num_heads), + helper.make_attribute("unidirectional", 1 if is_unidirectional else 0) + ]) matmul_node = helper.make_node('MatMul', inputs=[attention_node_name + "_output", gemm_qkv.input[1]], @@ -115,6 +116,8 @@ class FusionGptAttention(Fusion): logger.debug("Add and LayerNormalization shall have one same input") return + is_unidirectional = True + slice_mask = None input_mask_nodes = None qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) if qk_nodes is not None: @@ -127,6 +130,7 @@ class FusionGptAttention(Fusion): logger.debug("fuse_attention: failed to match unidirectional mask path") return div_mask = mask_nodes[-1] + slice_mask = mask_nodes[3] if 
div_qk != div_mask: logger.debug("fuse_attention: skip since div_qk != div_mask") @@ -162,11 +166,24 @@ class FusionGptAttention(Fusion): logger.debug("fuse_attention: failed to match mask path") return div_mask = mask_nodes[-1] + slice_mask = mask_nodes[2] if div_qk != div_mask: logger.debug("fuse_attention: skip since div_qk != div_mask") return + # Validate that the mask data is either lower triangular (unidirectional) or all ones + mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0])) + if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) + and mask_data.shape[2] == mask_data.shape[3]): + logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW") + return + if np.allclose(mask_data, np.ones_like(mask_data)): + is_unidirectional = False + elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))): + logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones") + return + q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") @@ -219,7 +236,7 @@ class FusionGptAttention(Fusion): self.casted_attention_mask[input_name] = attention_mask_input_name self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0], - reshape_qkv.output[0], attention_mask_input_name) + reshape_qkv.output[0], attention_mask_input_name, is_unidirectional) # we rely on prune_graph() to clean old subgraph nodes: # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv] diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 85c23aae44..dffd88145d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1834,6 +1834,63 @@ TEST_F(GraphTransformationTests, AttentionFusionFloat32Test) { 
ValidateAttention(graph); } +// Test GPT-2 Attention Fusion with past and unidirectional mask +TEST_F(GraphTransformationTests, AttentionFusionWithPastAndUnidirMaskTest) { + auto model_uri = MODEL_FOLDER "fusion/attention_past_unidir.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0); + EXPECT_EQ(op_to_count["Softmax"], 0); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + if (p_node->OpType().compare("Attention") == 0) { + EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 1); + } + } +} + +// Test Attention Fusion with past but no unidirectional mask +TEST_F(GraphTransformationTests, AttentionFusionWithPastAndNoUnidirMaskTest) { + auto model_uri = MODEL_FOLDER "fusion/attention_past_no_unidir.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0); + EXPECT_EQ(op_to_count["Softmax"], 0); + 
EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + Node* p_node = graph.GetNode(node_index); + if (p_node->OpType().compare("Attention") == 0) { + EXPECT_EQ(p_node->GetAttributes().at("unidirectional").i(), 0); + } + } +} + // Test GPT-2 Attention Fusion with float32 mask TEST_F(GraphTransformationTests, AttentionFusionGPTWithPastAndMaskTest) { auto model_uri = MODEL_FOLDER "fusion/gpt2_past_mask_one_layer.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx b/onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx new file mode 100644 index 0000000000..a5ea98d286 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/attention_past_no_unidir.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx b/onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx new file mode 100644 index 0000000000..22a0d1284d Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/attention_past_unidir.onnx differ