diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index d743a038cb..b88f2d6a46 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -268,6 +268,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, int64_t hidden_size, int64_t num_heads, int64_t head_size, + const float mask_filter_value, const logging::Logger& logger) { InlinedVector> pivot_nodes; if (edges.size() == 2) { @@ -386,6 +387,7 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, nullptr, kMSDomain); attention_node.AddAttribute("num_heads", num_heads); + attention_node.AddAttribute("mask_filter_value", mask_filter_value); // Assign provider to this new node. attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType()); @@ -437,7 +439,7 @@ static bool FuseSubGraphQK(Node& layer_norm, std::vector nodes_to_remove; if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, - num_heads, head_size, logger)) { + num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -528,7 +530,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm, std::vector nodes_to_remove; if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, - num_heads, head_size, logger)) { + num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -580,7 +582,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm, | qk_MatMul | | | | [B=2] | ([A=1.0] mask_Cast(to=1)) | | / | \ / - | qk_Div | mask_Sub [B=-10000.0] + | qk_Div | mask_Sub [B=-10000.0 or value of mask_filter_value] | \ | \ / | mask_Add <-------- /---------------------mask_Mul | | / diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index e7f3737088..8e312495da 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -265,7 +265,8 @@ bool ValidateGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidde struct MatchUnidirMaskResult { const Node* div_node = nullptr; // the root node (Div) of the subgraph - bool is_unidirectional = false; // whether the mask is unidirectional. + bool is_unidirectional = false; // whether the mask is unidirectional. + float mask_filter_value = -10000.0f; // the value to filter out the mask. std::vector node_indices; // id of all nodes in the subgraph for removing later. }; @@ -375,7 +376,7 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir | (*, -2, -1, 0) (axes=0) | Cast(9) +----> Shape --> Slice ---------> Squeeze-------+ | | :shape2 :slice2 :squeeze2 v condition - +----------------------------------------------------------------------------------------->Where( ,*,-10000)--->[Add] + +----------------------------------------------------------------------------------------->Where( ,*,-10000 or value of mask_filter_value)--->[Add] When use_shared_node is true, shape1 and shape2 is one node, and also unsqueeze2 and unsqueeze3 is same. */ @@ -394,8 +395,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid const Node& where_node = edges[0]->GetNode(); const Node& div_node = edges[1]->GetNode(); - constexpr float expected_value = -10000.0f; - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where_node.InputDefs()[2]), expected_value, true)) { + if (!optimizer_utils::GetScalarInitializerValue(graph, *(where_node.InputDefs()[2]), result.mask_filter_value, true)) { return false; } @@ -550,6 +550,7 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid struct AttentionMaskNodes { const Node* softmax; bool has_input_mask; // When it is false, the following nodes will be NULL. + float mask_filter_value = -10000.0f; const Node* add; const Node* mul; @@ -566,6 +567,7 @@ struct AttentionMaskNodesDistilBert { const Node* reshape; const Node* equal; const Node* shape; + float mask_filter_value = -10000.0f; }; void SetMaskNodesToRemove(const Graph& graph, AttentionMaskNodes& mask_nodes, std::vector& nodes_to_remove) { @@ -598,10 +600,10 @@ void SetMaskNodesToRemove(const Graph&, AttentionMaskNodesDistilBert& mask_nodes } /** Match Input Mask subgraph: - {UnidirMask Subgraph} - | - (optional) v -[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0) --> Add( ,*)--->SoftMax -->[MatMul] + { UnidirMask Subgraph} + | + (optional) v +[Attention_mask] --> Unsqueeze (axes=1) --> Unsqueeze (axes=2) --> Cast ---->Sub(1,*) --> Mul(*, -10000.0 or value of mask_filter_value) --> Add( ,*)--->SoftMax -->[MatMul] When is_input_mask_optional is true, this function also matches the following subgraph: {UnidirMask Subgraph [Where]} --> Softmax --> [MatMul] @@ -710,7 +712,7 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& qkv_matmul, Attentio return false; } - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(mask_mul.InputDefs()[1]), float(-10000), false)) { + if (!optimizer_utils::GetScalarInitializerValue(graph, *(mask_mul.InputDefs()[1]), result.mask_filter_value, false)) { DEBUG_LOG("mask_mul const input not matched"); return false; } @@ -1461,7 +1463,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: opt_k_transpose = k_concat; InlinedVector perm; - if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm) + if (!(graph_utils::GetRepeatedNodeAttributeValues(*opt_k_transpose, "perm", perm) && perm.size() == 4 && perm[0] == 0 && perm[1] == 1 && perm[2] == 3 && perm[3] == 2)) { DEBUG_LOG("opt_k_transpose perm attribute not matched"); return false; @@ -1541,6 +1543,11 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: kMSDomain); attention_node.AddAttribute("num_heads", num_heads); attention_node.AddAttribute("unidirectional", static_cast(unidir_mask_result.is_unidirectional)); + if (mask_nodes.mask_filter_value != -10000.0f or unidir_mask_result.mask_filter_value != -10000.0f) { + float mask_filter_value = mask_nodes.mask_filter_value != -10000.0f ? + mask_nodes.mask_filter_value : unidir_mask_result.mask_filter_value; + attention_node.AddAttribute("mask_filter_value", mask_filter_value); + } // Assign provider to this new node. attention_node.SetExecutionProviderType(layer_norm.GetExecutionProviderType()); diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 1066dfbebc..fe40188f05 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -363,6 +363,34 @@ bool IsScalar(const NodeArg& input_arg) { return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); } +template +bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value, + bool is_constant) { + if (!IsScalar(input_arg)) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (is_constant) { + tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + } else if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { + return false; + } + + if (tensor_proto == nullptr) { + return false; + } + + Initializer init_const{*tensor_proto, graph.ModelPath()}; + const T* val = init_const.data(); + value = *val; + + return true; +} + +template bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, float& value, + bool is_constant); + #endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace optimizer_utils diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index cbb9764faf..caab7ca1d9 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -103,6 +103,10 @@ bool IsSupportedDataType(const Node& node, const T& supported_data_types) { bool IsOperationDeterministic(const std::string& domain, const std::string& op); +template +bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntime::NodeArg& input_arg, T& value, + bool is_constant); + #endif // !#if !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)