mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (#22762)
### Description Fixes #22512: MatMul, Add could be fused into a single Gemm even when the tensors have more than 2 dimensions. This PR excludes those cases. ### Motivation and Context ORT crashes on valid models because of that unexpected fusion.
This commit is contained in:
parent
c5276ac448
commit
1f3b675453
3 changed files with 46 additions and 0 deletions
|
|
@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
|
|||
return false;
|
||||
}
|
||||
|
||||
// Checks the first input of MatMul has 2 dimensions.
|
||||
// The test for the second input is done in method Apply as it accesses the constant.
|
||||
if (node.InputDefs()[0] == nullptr) {
|
||||
// This should never happen but just in case.
|
||||
return false;
|
||||
}
|
||||
auto shape_a = node.InputDefs()[0]->Shape();
|
||||
if (shape_a == nullptr) {
|
||||
// We cannot determine the rank. It is better to avoid fusing.
|
||||
return false;
|
||||
}
|
||||
if (shape_a->dim_size() != 2) {
|
||||
// Gemm only supports 2D tensors.
|
||||
return false;
|
||||
}
|
||||
|
||||
// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
|
||||
const auto& output_defs = batch_norm_node->OutputDefs();
|
||||
if (output_defs.size() > 1) {
|
||||
|
|
@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
|
|||
bias_tensor->dims_size() != 1 ||
|
||||
mean_tensor->dims_size() != 1 ||
|
||||
var_tensor->dims_size() != 1 ||
|
||||
matmul_b_tensor->dims_size() != 2 ||
|
||||
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
|
|
|
|||
|
|
@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
|
|||
}
|
||||
}
|
||||
|
||||
// Negative test for MatmulBNFusion: loads the "dont-fuse" model, runs the
// rule at Level1 and checks that the graph still contains exactly one MatMul
// and one BatchNormalization node and no Gemm node, i.e. the fusion was not
// applied.
// NOTE(review): the model presumably feeds MatMul with a non-2D tensor (see
// the commit description); confirm against the .onnx test data.
TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";

  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Register only the rule under test so that any Gemm node appearing in the
  // graph could only have been produced by MatmulBNFusion.
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The original nodes must survive and no fused node may be created.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
  ASSERT_EQ(op_to_count["MatMul"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 0);
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue