mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Bugfix for BiasSoftmax Fusion (#12517)
This commit is contained in:
parent
0d9a02e647
commit
0c6037b5ab
3 changed files with 46 additions and 5 deletions
|
|
@ -9,6 +9,10 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// BiasSoftmax follows the OpSet-11 definision of Softmax Op, that is, the input will be coerced to a 2D tensor
|
||||
// using axis attribute, all dims after axis (included) are in the same batch. This is different from definition
|
||||
// since OpSet-13. To use BiasSoftmax, during the fusion, if Softmax is OpSet-13 or newer, you can only fuse it
|
||||
// when axis attribute is the last dim, othewise, the computation result may be wrong.
|
||||
class BiasSoftmax final : public onnxruntime::cuda::CudaKernel {
|
||||
public:
|
||||
BiasSoftmax(const OpKernelInfo& info) : CudaKernel{info} {
|
||||
|
|
|
|||
|
|
@ -108,6 +108,12 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
|
|||
*
|
||||
* In the BERT case scores shape = [batch_size, num_heads, seq_length, seq_length]
|
||||
* and sequence mask shape = [batch_size, 1, 1, seq_length]
|
||||
*
|
||||
* NOTE that the axis attribute for Softmax in OpSet-11 and OpSet-13 are different. For OpSet-11, dim ak to dim a(N-1)
|
||||
* are in same batch. But since OpSet-13, only ak is in a batch. Above fusion logic is for OpSet-11 or before.
|
||||
* Since OpSet-13, to compute Softmax, we will first transpose the axis dim to the last dim before the real Softmax
|
||||
* computation if axis is not the last dim. Fusing Add+Softmax to BiasSoftmax would require extra transpose for bias,
|
||||
* and bring complex checking condition. So since OpSet-13, we will apply the fusion only when axis is the last dim.
|
||||
*/
|
||||
bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, NodeArg*& input, NodeArg*& mask,
|
||||
int& new_axis, bool& is_inner_broadcast) {
|
||||
|
|
@ -115,7 +121,8 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
|
|||
NodeArg* input2 = add_node.MutableInputDefs()[1];
|
||||
|
||||
// default axis = -1 if opset >= 13
|
||||
int axis = graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11}) ? 1 : -1;
|
||||
bool is_since_opset_13 = !graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11});
|
||||
int axis = is_since_opset_13 ? -1 : 1;
|
||||
auto& softmax_attr = softmax_node.GetAttributes();
|
||||
if (softmax_attr.find("axis") != softmax_attr.end()) {
|
||||
auto& axis_attr = softmax_attr.at("axis");
|
||||
|
|
@ -124,9 +131,13 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
|
|||
|
||||
int N1 = input1->Shape()->dim_size();
|
||||
int N2 = input2->Shape()->dim_size();
|
||||
new_axis = (int)HandleNegativeAxis(axis, std::max({N1, N2}));
|
||||
int singlebatch_rank = std::max({N1 - new_axis, N2 - new_axis});
|
||||
int rank = std::max({N1, N2});
|
||||
new_axis = (int)HandleNegativeAxis(axis, rank);
|
||||
|
||||
// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
|
||||
if (is_since_opset_13 && new_axis != rank - 1) return false;
|
||||
|
||||
int singlebatch_rank = rank - new_axis;
|
||||
if (singlebatch_rank > N1 || singlebatch_rank > N2) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3674,7 +3674,7 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
|
||||
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
|
|
@ -3691,11 +3691,37 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
|
||||
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OpSet13InValidAxis) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3, 2, 3, 3, 3}});
|
||||
auto* bias_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* add_out = builder.MakeIntermediate();
|
||||
auto* softmax_out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Add", {input_arg, bias_arg}, {add_out});
|
||||
builder.AddNode("Softmax", {add_out}, {softmax_out}).AddAttribute("axis", static_cast<int64_t>(6));
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
for (auto& node : graph.Nodes()) node.SetExecutionProviderType(kCudaExecutionProvider);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.BiasSoftmax"], 0);
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
static void TestBiasDropoutFusion(const PathString& file_path, const logging::Logger& logger, const int add_count = 0) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, logger).IsOK());
|
||||
|
|
|
|||
Loading…
Reference in a new issue