diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax.h b/onnxruntime/contrib_ops/cuda/math/bias_softmax.h index 60434319db..c837a05192 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_softmax.h +++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax.h @@ -9,6 +9,10 @@ namespace onnxruntime { namespace contrib { namespace cuda { +// BiasSoftmax follows the OpSet-11 definision of Softmax Op, that is, the input will be coerced to a 2D tensor +// using axis attribute, all dims after axis (included) are in the same batch. This is different from definition +// since OpSet-13. To use BiasSoftmax, during the fusion, if Softmax is OpSet-13 or newer, you can only fuse it +// when axis attribute is the last dim, othewise, the computation result may be wrong. class BiasSoftmax final : public onnxruntime::cuda::CudaKernel { public: BiasSoftmax(const OpKernelInfo& info) : CudaKernel{info} { diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index 91290a0f94..80603cdbd3 100755 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -108,6 +108,12 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s * * In the BERT case scores shape = [batch_size, num_heads, seq_length, seq_length] * and sequence mask shape = [batch_size, 1, 1, seq_length] + * + * NOTE that the axis attribute for Softmax in OpSet-11 and OpSet-13 are different. For OpSet-11, dim ak to dim a(N-1) + * are in same batch. But since OpSet-13, only ak is in a batch. Above fusion logic is for OpSet-11 or before. + * Since OpSet-13, to compute Softmax, we will first transpose the axis dim to the last dim before the real Softmax + * computation if axis is not the last dim. Fusing Add+Softmax to BiasSoftmax would require extra transpose for bias, + * and bring complex checking condition. So since OpSet-13, we will apply the fusion only when axis is the last dim. */ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, NodeArg*& input, NodeArg*& mask, int& new_axis, bool& is_inner_broadcast) { @@ -115,7 +121,8 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node NodeArg* input2 = add_node.MutableInputDefs()[1]; // default axis = -1 if opset >= 13 - int axis = graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11}) ? 1 : -1; + bool is_since_opset_13 = !graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11}); + int axis = is_since_opset_13 ? -1 : 1; auto& softmax_attr = softmax_node.GetAttributes(); if (softmax_attr.find("axis") != softmax_attr.end()) { auto& axis_attr = softmax_attr.at("axis"); @@ -124,9 +131,13 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node int N1 = input1->Shape()->dim_size(); int N2 = input2->Shape()->dim_size(); - new_axis = (int)HandleNegativeAxis(axis, std::max({N1, N2})); - int singlebatch_rank = std::max({N1 - new_axis, N2 - new_axis}); + int rank = std::max({N1, N2}); + new_axis = (int)HandleNegativeAxis(axis, rank); + // The axis attribute for Softmax in OpSet-11 and OpSet-13 are different. + if (is_since_opset_13 && new_axis != rank - 1) return false; + + int singlebatch_rank = rank - new_axis; if (singlebatch_rank > N1 || singlebatch_rank > N2) { return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 515224e7cd..63d40717d9 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -3674,7 +3674,7 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) { }; std::unique_ptr transformer = std::make_unique(); - TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker); } @@ -3691,11 +3691,37 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) { }; std::unique_ptr transformer = std::make_unique(); - TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker); } } +TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OpSet13InValidAxis) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, 3, 3, 3, 2, 3, 3, 3}}); + auto* bias_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* add_out = builder.MakeIntermediate(); + auto* softmax_out = builder.MakeOutput(); + + builder.AddNode("Add", {input_arg, bias_arg}, {add_out}); + builder.AddNode("Softmax", {add_out}, {softmax_out}).AddAttribute("axis", static_cast(6)); + }; + + auto pre_graph_checker = [&](Graph& graph) { + for (auto& node : graph.Nodes()) node.SetExecutionProviderType(kCudaExecutionProvider); + ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1); + ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.BiasSoftmax"], 0); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); +} + static void TestBiasDropoutFusion(const PathString& file_path, const logging::Logger& logger, const int add_count = 0) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, logger).IsOK());