diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index 5ce8083cb1..7d5578d385 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -23,15 +23,26 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie // Currently, fusion is only supported for float or double data type. if (!Initializer::IsSupportedDataType(add_B_tensor_proto) || - conv_W_tensor_proto->dims_size() < 4 || - add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 || - conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { + conv_W_tensor_proto->dims_size() < 4) { return Status::OK(); } - // The dimensions of add_B should be equal to 1 except first dimension. - for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) { - if (add_B_tensor_proto->dims(i) != 1) { + int axis; + if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) { + // Test for broadcast add such as 1xCx1x1 for a 2D convolution. + axis = 1; + } else if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) { + // Test for broadcast add such as Cx1x1 for a 2D convolution. + axis = 0; + } else { + return Status::OK(); + } + if (add_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) { + return Status::OK(); + } + // The dimensions of add_B should be equal to 1 except axis dimension. + for (int i = 0; i < add_B_tensor_proto->dims_size(); i++) { + if (i != axis && add_B_tensor_proto->dims(i) != 1) { return Status::OK(); } } @@ -43,7 +54,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() || conv_B_tensor_proto->dims_size() != 1 || - conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { + conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) { return Status::OK(); } diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 7d6379cc66..de6d096ba6 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -24,17 +24,27 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) || !Initializer::IsSupportedDataType(mul_B_tensor_proto) || conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || - conv_W_tensor_proto->dims_size() < 4 || - !(mul_B_tensor_proto->dims_size() == 0 || - (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 && - conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) { + conv_W_tensor_proto->dims_size() < 4) { return Status::OK(); } - // The dimensions of mul_B should be equal to 1 except first dimension. if (mul_B_tensor_proto->dims_size() != 0) { - for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) { - if (mul_B_tensor_proto->dims(i) != 1) { + int axis; + if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) { + // Test for broadcast multiply such as 1xCx1x1 for a 2D convolution. + axis = 1; + } else if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) { + // Test for broadcast multiply such as Cx1x1 for a 2D convolution. + axis = 0; + } else { + return Status::OK(); + } + if (mul_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) { + return Status::OK(); + } + // The dimensions of mul_B should be equal to 1 except axis dimension. + for (int i = 0; i < mul_B_tensor_proto->dims_size(); i++) { + if (i != axis && mul_B_tensor_proto->dims(i) != 1) { return Status::OK(); } } @@ -54,7 +64,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || conv_B_tensor_proto->dims_size() != 1 || - (mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) { + conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) { return Status::OK(); } conv_B = std::make_unique(conv_B_tensor_proto); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 580c1f16c8..73f0418b8d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -266,6 +266,26 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { ASSERT_TRUE(op_to_count["Mul"] == 0); } +TEST(GraphTransformationTests, FuseConvAddMul3D_2) { + string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d-2.onnx"; + + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); + + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); +} + TEST(GraphTransformationTests, MatMulAddFusion_two_input) { string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-conv-add-mul-3d-2.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-conv-add-mul-3d-2.onnx new file mode 100644 index 0000000000..3fe7832c88 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-conv-add-mul-3d-2.onnx differ