Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (#22762)

### Description
Fixes #22512, MatMul, Add can be fused into a single Gemm even if
tensors dimensions are > 2. The PR excludes that cases.



### Motivation and Context
ORT crashes on valid models due to that unexpected fusion.
This commit is contained in:
Xavier Dupré 2024-11-11 19:48:25 +01:00 committed by GitHub
parent c5276ac448
commit 1f3b675453
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 46 additions and 0 deletions

View file

@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
return false;
}
// Checks the first input of MatMul has 2 dimensions.
// The test for the second input is done in method Apply as it accesses the constant.
if (node.InputDefs()[0] == nullptr) {
// This should never happen but just in case.
return false;
}
auto shape_a = node.InputDefs()[0]->Shape();
if (shape_a == nullptr) {
// We cannot shape the rank. It is better to avoid fusing.
return false;
}
if (shape_a->dim_size() != 2) {
// Gemm only supports 2D tensors.
return false;
}
// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
const auto& output_defs = batch_norm_node->OutputDefs();
if (output_defs.size() > 1) {
@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
bias_tensor->dims_size() != 1 ||
mean_tensor->dims_size() != 1 ||
var_tensor->dims_size() != 1 ||
matmul_b_tensor->dims_size() != 2 ||
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||

View file

@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
}
}
TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
std::string expected_output_name;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "BatchNormalization") {
expected_output_name = node.OutputDefs()[0]->Name();
}
}
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["BatchNormalization"], 1);
ASSERT_EQ(op_to_count["MatMul"], 1);
ASSERT_EQ(op_to_count["Gemm"], 0);
}
TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";