mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Fuse elementwise add/multiply into convolution (#1028)
* fix elementwise fusions * update comments * test case
This commit is contained in:
parent
d14e65a224
commit
2e19b14e4e
4 changed files with 56 additions and 15 deletions
|
|
@ -23,15 +23,26 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
|
|||
|
||||
// Currently, fusion is only supported for float or double data type.
|
||||
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
|
||||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
conv_W_tensor_proto->dims_size() < 4) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The dimensions of add_B should be equal to 1 except first dimension.
|
||||
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
|
||||
if (add_B_tensor_proto->dims(i) != 1) {
|
||||
int axis;
|
||||
if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) {
|
||||
// Test for broadcast add such as 1xCx1x1 for a 2D convolution.
|
||||
axis = 1;
|
||||
} else if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) {
|
||||
// Test for broadcast add such as Cx1x1 for a 2D convolution.
|
||||
axis = 0;
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
if (add_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
// The dimensions of add_B should be equal to 1 except axis dimension.
|
||||
for (int i = 0; i < add_B_tensor_proto->dims_size(); i++) {
|
||||
if (i != axis && add_B_tensor_proto->dims(i) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
@ -43,7 +54,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
|
|||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,17 +24,27 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
|
|||
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
!(mul_B_tensor_proto->dims_size() == 0 ||
|
||||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
|
||||
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
|
||||
conv_W_tensor_proto->dims_size() < 4) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The dimensions of mul_B should be equal to 1 except first dimension.
|
||||
if (mul_B_tensor_proto->dims_size() != 0) {
|
||||
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
|
||||
if (mul_B_tensor_proto->dims(i) != 1) {
|
||||
int axis;
|
||||
if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) {
|
||||
// Test for broadcast multiply such as 1xCx1x1 for a 2D convolution.
|
||||
axis = 1;
|
||||
} else if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) {
|
||||
// Test for broadcast multiply such as Cx1x1 for a 2D convolution.
|
||||
axis = 0;
|
||||
} else {
|
||||
return Status::OK();
|
||||
}
|
||||
if (mul_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
// The dimensions of mul_B should be equal to 1 except axis dimension.
|
||||
for (int i = 0; i < mul_B_tensor_proto->dims_size(); i++) {
|
||||
if (i != axis && mul_B_tensor_proto->dims(i) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
@ -54,7 +64,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
|
|||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
(mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
|
||||
conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
|
|
|
|||
|
|
@ -266,6 +266,26 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
|
|||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
}
|
||||
|
||||
// Checks that applying the ConvAddFusion and ConvMulFusion rewrite rules to the
// "fuse-conv-add-mul-3d-2.onnx" test model leaves no Add or Mul nodes in the graph.
// NOTE(review): the model file is binary and not visible here; presumably it exercises
// the alternate broadcast shape for the Add/Mul initializers - confirm against the model.
TEST(GraphTransformationTests, FuseConvAddMul3D_2) {
|
||||
// Path of the test model, relative to MODEL_FOLDER.
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d-2.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
// The test aborts here if the model cannot be loaded.
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
// NOTE(review): the meaning of the constructor argument 5 is not visible in this
// snippet (presumably a cap on transformation steps) - confirm.
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
// Both fusion rules are registered on one rule-based transformer.
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
|
||||
// The transformer is registered at Level2, the same level applied below.
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
||||
// Count the remaining operators by op type; both Add and Mul are expected to be gone.
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, MatMulAddFusion_two_input) {
|
||||
string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx";
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/fuse-conv-add-mul-3d-2.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-conv-add-mul-3d-2.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue