Bugfix for BiasSoftmax Fusion (#12517)

This commit is contained in:
Vincent Wang 2022-08-10 07:20:13 +08:00 committed by GitHub
parent 0d9a02e647
commit 0c6037b5ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 5 deletions

View file

@ -9,6 +9,10 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
// BiasSoftmax follows the OpSet-11 definision of Softmax Op, that is, the input will be coerced to a 2D tensor
// using axis attribute, all dims after axis (included) are in the same batch. This is different from definition
// since OpSet-13. To use BiasSoftmax, during the fusion, if Softmax is OpSet-13 or newer, you can only fuse it
// when axis attribute is the last dim, othewise, the computation result may be wrong.
class BiasSoftmax final : public onnxruntime::cuda::CudaKernel {
public:
BiasSoftmax(const OpKernelInfo& info) : CudaKernel{info} {

View file

@ -108,6 +108,12 @@ bool TryBiasSoftmaxSubgraphMatch(Graph& graph, Node& start, Node*& add, Node*& s
*
* In the BERT case scores shape = [batch_size, num_heads, seq_length, seq_length]
* and sequence mask shape = [batch_size, 1, 1, seq_length]
*
* NOTE that the axis attribute for Softmax in OpSet-11 and OpSet-13 are different. For OpSet-11, dim ak to dim a(N-1)
* are in same batch. But since OpSet-13, only ak is in a batch. Above fusion logic is for OpSet-11 or before.
* Since OpSet-13, to compute Softmax, we will first transpose the axis dim to the last dim before the real Softmax
* computation if axis is not the last dim. Fusing Add+Softmax to BiasSoftmax would require extra transpose for bias,
* and bring complex checking condition. So since OpSet-13, we will apply the fusion only when axis is the last dim.
*/
bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, NodeArg*& input, NodeArg*& mask,
int& new_axis, bool& is_inner_broadcast) {
@ -115,7 +121,8 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
NodeArg* input2 = add_node.MutableInputDefs()[1];
// default axis = -1 if opset >= 13
int axis = graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11}) ? 1 : -1;
bool is_since_opset_13 = !graph_utils::MatchesOpSinceVersion(softmax_node, {1, 11});
int axis = is_since_opset_13 ? -1 : 1;
auto& softmax_attr = softmax_node.GetAttributes();
if (softmax_attr.find("axis") != softmax_attr.end()) {
auto& axis_attr = softmax_attr.at("axis");
@ -124,9 +131,13 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
int N1 = input1->Shape()->dim_size();
int N2 = input2->Shape()->dim_size();
new_axis = (int)HandleNegativeAxis(axis, std::max({N1, N2}));
int singlebatch_rank = std::max({N1 - new_axis, N2 - new_axis});
int rank = std::max({N1, N2});
new_axis = (int)HandleNegativeAxis(axis, rank);
// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
if (is_since_opset_13 && new_axis != rank - 1) return false;
int singlebatch_rank = rank - new_axis;
if (singlebatch_rank > N1 || singlebatch_rank > N2) {
return false;
}

View file

@ -3674,7 +3674,7 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
pre_graph_checker, post_graph_checker);
}
@ -3691,11 +3691,37 @@ TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OuterBroadcast) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
pre_graph_checker, post_graph_checker);
}
}
TEST_F(GraphTransformationTests, BiasSoftmaxFusionTest_OpSet13InValidAxis) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3, 2, 3, 3, 3}});
auto* bias_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* add_out = builder.MakeIntermediate();
auto* softmax_out = builder.MakeOutput();
builder.AddNode("Add", {input_arg, bias_arg}, {add_out});
builder.AddNode("Softmax", {add_out}, {softmax_out}).AddAttribute("axis", static_cast<int64_t>(6));
};
auto pre_graph_checker = [&](Graph& graph) {
for (auto& node : graph.Nodes()) node.SetExecutionProviderType(kCudaExecutionProvider);
ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Softmax"], 1);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.BiasSoftmax"], 0);
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<BiasSoftmaxFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1,
pre_graph_checker, post_graph_checker);
}
static void TestBiasDropoutFusion(const PathString& file_path, const logging::Logger& logger, const int add_count = 0) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, logger).IsOK());