Fuse elementwise add/multiply into convolution (#1028)

* fix elementwise fusions

* update comments

* test case
This commit is contained in:
Tracy Sharpe 2019-05-15 20:54:21 -07:00 committed by GitHub
parent d14e65a224
commit 2e19b14e4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 15 deletions

View file

@ -23,15 +23,26 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
// Currently, fusion is only supported for float or double data type.
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
conv_W_tensor_proto->dims_size() < 4 ||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
conv_W_tensor_proto->dims_size() < 4) {
return Status::OK();
}
// The dimensions of add_B should be equal to 1 except first dimension.
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
if (add_B_tensor_proto->dims(i) != 1) {
int axis;
if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) {
// Test for broadcast add such as 1xCx1x1 for a 2D convolution.
axis = 1;
} else if (add_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) {
// Test for broadcast add such as Cx1x1 for a 2D convolution.
axis = 0;
} else {
return Status::OK();
}
if (add_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) {
return Status::OK();
}
// The dimensions of add_B should be equal to 1 except axis dimension.
for (int i = 0; i < add_B_tensor_proto->dims_size(); i++) {
if (i != axis && add_B_tensor_proto->dims(i) != 1) {
return Status::OK();
}
}
@ -43,7 +54,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 ||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) {
return Status::OK();
}

View file

@ -24,17 +24,27 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_W_tensor_proto->dims_size() < 4 ||
!(mul_B_tensor_proto->dims_size() == 0 ||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
conv_W_tensor_proto->dims_size() < 4) {
return Status::OK();
}
// The dimensions of mul_B should be equal to 1 except first dimension.
if (mul_B_tensor_proto->dims_size() != 0) {
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
if (mul_B_tensor_proto->dims(i) != 1) {
int axis;
if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size()) {
// Test for broadcast multiply such as 1xCx1x1 for a 2D convolution.
axis = 1;
} else if (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1) {
// Test for broadcast multiply such as Cx1x1 for a 2D convolution.
axis = 0;
} else {
return Status::OK();
}
if (mul_B_tensor_proto->dims(axis) != conv_W_tensor_proto->dims(0)) {
return Status::OK();
}
// The dimensions of mul_B should be equal to 1 except axis dimension.
for (int i = 0; i < mul_B_tensor_proto->dims_size(); i++) {
if (i != axis && mul_B_tensor_proto->dims(i) != 1) {
return Status::OK();
}
}
@ -54,7 +64,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 ||
(mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
conv_B_tensor_proto->dims(0) != conv_W_tensor_proto->dims(0)) {
return Status::OK();
}
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);

View file

@ -266,6 +266,26 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
ASSERT_TRUE(op_to_count["Mul"] == 0);
}
TEST(GraphTransformationTests, FuseConvAddMul3D_2) {
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d-2.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
}
TEST(GraphTransformationTests, MatMulAddFusion_two_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx";