fuse Conv+Add+activation for CPU from different op-branch (#10987)

* Fuse Conv, Add and activation ops coming from two branches
* simplify code

Co-authored-by: Jicheng Wen <jicwen@microsoft.com>
This commit is contained in:
wejoncy 2022-04-01 09:25:17 +08:00 committed by GitHub
parent 79e4ed8064
commit 11a4ca741d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 457 additions and 32 deletions

View file

@ -739,6 +739,7 @@ struct MLAS_CONV_PARAMETERS {
size_t InputSize;
size_t OutputSize;
size_t K;
float Beta;
MLAS_CONV_ALGORITHM Algorithm;
ptrdiff_t ThreadCount;
union {
@ -752,25 +753,23 @@ struct MLAS_CONV_PARAMETERS {
} u;
};
void
MLASCALL
MlasConvPrepare(
MLAS_CONV_PARAMETERS* Parameters,
size_t Dimensions,
size_t BatchCount,
size_t GroupCount,
size_t InputChannels,
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t FilterCount,
const MLAS_ACTIVATION* Activation,
size_t* WorkingBufferSize,
MLAS_THREADPOOL* ThreadPool
);
//
// Prepares a convolution: fills in *Parameters from the supplied shapes and
// reports the scratch space needed through *WorkingBufferSize.
//
// Beta is stored in Parameters->Beta and is later forwarded as the beta
// argument of the GEMM calls issued by the convolution. Callers that want the
// previous behavior pass 0.0f; a caller that has preloaded the output buffer
// with a tensor to be summed into the result passes 1.0f.
// NOTE(review): the exact units of *WorkingBufferSize (elements vs. bytes)
// are not visible from this declaration - confirm against the definition.
//
void MLASCALL
MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
size_t Dimensions,
size_t BatchCount,
size_t GroupCount,
size_t InputChannels,
const int64_t* InputShape,
const int64_t* KernelShape,
const int64_t* DilationShape,
const int64_t* Padding,
const int64_t* StrideShape,
const int64_t* OutputShape,
size_t FilterCount,
const MLAS_ACTIVATION* Activation,
size_t* WorkingBufferSize,
float Beta,
MLAS_THREADPOOL* ThreadPool);
void
MLASCALL

View file

@ -571,7 +571,7 @@ Return Value:
//
size_t CountK;
float beta = 0.0f;
float beta = Parameters->Beta;
float* SegmentOutput = Output + SegmentStartN + n;
for (size_t k = 0; k < K; k += CountK) {
@ -934,9 +934,9 @@ Return Value:
// Invoke the threaded GEMM directly with the input tensor.
//
MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount,
OutputSize, K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb, 0.0f,
Output, OutputSize, ThreadPool);
MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize,
K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb,
Parameters->Beta, Output, OutputSize, ThreadPool);
//
// Apply the activation with optional bias.
@ -962,7 +962,8 @@ Return Value:
}
MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, OutputSize, K, 1.0f, filter,
K, WorkingBuffer, OutputSize, 0.0f, Output, OutputSize, ThreadPool);
K, WorkingBuffer, OutputSize, Parameters->Beta, Output, OutputSize,
ThreadPool);
//
// Apply the activation with optional bias.
@ -1038,6 +1039,7 @@ MlasConvPrepare(
size_t FilterCount,
const MLAS_ACTIVATION* Activation,
size_t* WorkingBufferSize,
float Beta,
MLAS_THREADPOOL* ThreadPool
)
/*++
@ -1100,6 +1102,7 @@ Return Value:
Parameters->GroupCount = GroupCount;
Parameters->InputChannels = InputChannels;
Parameters->FilterCount = FilterCount;
Parameters->Beta = Beta;
size_t InputSize = 1;
size_t OutputSize = 1;

View file

@ -0,0 +1,285 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <deque>
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/conv_add_act_fusion.h"
#include "core/mlas/inc/mlas.h"
#include "core/graph/node_attr_utils.h"
#include "core/optimizer/utils.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
namespace selectors {
// Returns true only when `node_arg` exists, carries type information, and its
// tensor element type equals `data_type`. Any missing piece yields false.
bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
  const auto* type_proto = node_arg.Exists() ? node_arg.TypeAsProto() : nullptr;
  if (type_proto == nullptr) {
    return false;
  }

  int32_t element_type = 0;
  return utils::TryGetElementDataType(*type_proto, element_type) && element_type == data_type;
}
// Returns the single node consuming `node`'s output, or nullptr when
// optimizer_utils::CheckOutputEdges rejects `node` for an expected output-edge count of 1.
const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& node) {
  const bool has_single_consumer = optimizer_utils::CheckOutputEdges(graph_viewer.GetGraph(), node, 1);
  return has_single_consumer ? &*node.OutputNodesBegin() : nullptr;
}
class ConvAddActivation : public NodeSelector {
public:
ConvAddActivation() = default;
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override {
const std::string_view node_ep = node.GetExecutionProviderType();
if (node_ep != kCpuExecutionProvider || !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return std::nullopt;
}
// we can't assign `conv_node` as the producer-node, even it is, because we have to make sure
// 1. Its type is 'conv', 2. it has to satisfy the other requirements,like shape, please refer to SelectConvProducer for more info
const Node* conv_node = nullptr;
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
if (!add_node) {
return std::nullopt;
}
// Let's support addition first, leave any-element-wise-op fusion in the future.
// what we want to here is that:
// 1 find the Add node, 2 find it's producer node and make sure it's a conv node
// 3 find the next node and check if it's a activation node, if yes, we will fuse conv+add+activation or conv+add
//
if (graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {7, 13, 14})) {
conv_node = SelectProducerConv(*add_node);
}
if (!conv_node) {
return std::nullopt;
}
// GetLoneConsumerNode will ensure outputedge_count is 1
const auto* act_node = GetLoneConsumerNode(graph_viewer, *add_node);
// even the next node is not a activation node, it's also fine.
if (!act_node) {
// we can't fuse add-activation when add_node has multiple consumer nodes
act_node = nullptr;
} else if (SelectActivation(graph_viewer, *act_node)) {
// this branch is deliberately empty as we want to keep 'act_node' as remains.
} else {
act_node = nullptr;
}
NodesToOptimizeIndicesBuilder builder{};
builder.target_node = conv_node->Index();
builder.output_nodes = {add_node->Index()};
if (act_node) {
builder.output_nodes.push_back(act_node->Index());
}
return builder.Build();
}
static bool SelectActivation(const GraphViewer& graph_viewer, const Node& activation_node) {
auto is_supported_cpu_ep_activation = [&graph_viewer](const Node& activation_node) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Relu", {6, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Sigmoid", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Tanh", {6, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "LeakyRelu", {6})) {
return true;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Clip", {6, 11, 12, 13})) {
float min, max;
if (!optimizer_utils::GetClipConstantMinMax(graph_viewer.GetGraph(), activation_node, min, max)) {
return false;
}
return true;
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "HardSigmoid", {6})) {
return true;
}
return false;
};
return is_supported_cpu_ep_activation(activation_node);
}
const Node* SelectProducerConv(const Node& node) const {
InlinedVector<const Node*> inputs_node;
constexpr int32_t kTensorDims = 4; // NCHW
const auto& input_defs = node.InputDefs();
for (auto producer_node_ptr = node.InputNodesBegin(); producer_node_ptr != node.InputNodesEnd(); ++producer_node_ptr) {
const Node* producer_node = dynamic_cast<const Node*>(&(*producer_node_ptr));
inputs_node.push_back(producer_node);
}
size_t input_defs_count = input_defs.size();
if (input_defs_count != 2 || inputs_node.size() > input_defs_count) {
return nullptr;
}
// Test if all of inputs have an equal shape.
auto* input_0_shape = input_defs[0]->Shape();
// Check if ONNX shape inferencing has computed a precise dimension value.
if ((input_0_shape == nullptr) || (input_0_shape->dim_size() != kTensorDims)) {
return nullptr;
}
for (int i = 0; i < kTensorDims; i++) {
auto& input_0_dim = input_0_shape->dim(i);
// even though zero-dim is valid, but we don't support here
if (!utils::HasDimValue(input_0_dim) || (input_0_dim.dim_value() == 0)) {
if (!utils::HasDimParam(input_0_dim)) {
return nullptr;
}
}
}
// we can't fuse them if shape is not matched, it will happens when broadcast-Add
for (size_t n = 1; n < input_defs_count; n++) {
auto* input_n_shape = input_defs[n]->Shape();
if (input_n_shape == nullptr || (input_n_shape->dim_size() != kTensorDims)) {
return nullptr;
}
for (int i = 0; i < kTensorDims; i++) {
auto& input_0_dim = input_0_shape->dim(i);
auto& input_n_dim = input_n_shape->dim(i);
if (!utils::HasDimValue(input_n_dim) || (input_0_dim.dim_value() != input_n_dim.dim_value())) {
if (!utils::HasDimParam(input_0_dim) || !utils::HasDimParam(input_n_dim) || (input_0_dim.dim_param() != input_n_dim.dim_param())) {
return nullptr;
}
}
}
}
// If one of the inputs to the Add node is a convolution, then
// attempt to fuse the addition into the convolution itself.
for (size_t n = 0; (n < inputs_node.size()) && inputs_node[n]; n++) {
const auto& producer_input_defs = inputs_node[n]->InputDefs();
const auto& producer_input_args_count = inputs_node[n]->InputArgCount();
size_t pre_input_defs_count = producer_input_defs.size();
// Check if this is a single use convolution that hasn't already
// been fused with another Add/Sum node. The Add/Sum can also only be
// fused if the convolution isn't itself fused with an activation.
if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) && (producer_input_args_count.size() < 4) &&
(graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && (inputs_node[n]->GetOutputEdgesCount() == 1)) {
if (pre_input_defs_count < 3) {
// The optional bias parameter is empty so set to an empty string.
// TODO, add a new null arguments for bias
continue;
}
return inputs_node[n];
}
}
return nullptr;
}
};
} // namespace selectors
namespace actions {
using NTO = NodesToOptimize;
// Action that replaces Conv + Add [+ activation] with a single com.microsoft FusedConv node.
// Selected nodes layout: target = Conv, output 0 = Add, output 1 = activation (optional).
class FuseConvAddActivation : public ReplaceWithNew {
 private:
  std::string OpType(const RuntimeState&) const override { return "FusedConv"; }

  std::string Domain(const RuntimeState&) const override { return kMSDomain; }

  // Builds the "activation" / "activation_params" attributes for the fused node.
  // Returns an empty set when only Conv + Add were selected.
  NodeAttributes ExtraAttributes(const RuntimeState& state) const override {
    NodeAttributes extra_fused_conv_attributes;

    // Conv + Add only: there is no activation to describe.
    if (state.selected_nodes.num_outputs == 1) {
      return extra_fused_conv_attributes;
    }

    // The activation, when present, is the last selected output node.
    // Validate the pointer before it is dereferenced.
    const auto* activation = state.selected_nodes.Output(state.selected_nodes.num_outputs - 1);
    ORT_ENFORCE(activation != nullptr, "Expected activation node.");

    const auto& activation_op_type = activation->OpType();
    if (activation_op_type == "Add") {
      return extra_fused_conv_attributes;
    }

    utils::SetNodeAttribute(utils::MakeAttribute("activation", activation_op_type), extra_fused_conv_attributes);

    InlinedVector<float> activation_params;
    if (activation_op_type == "LeakyRelu") {
      // "alpha" is optional in the ONNX LeakyRelu spec; fall back to its default (0.01).
      const auto* alpha_attr = graph_utils::GetNodeAttribute(*activation, "alpha");
      activation_params.push_back(alpha_attr == nullptr ? 0.01f : alpha_attr->f());
    } else if (activation_op_type == "Clip") {
      float min, max;
      ORT_ENFORCE(optimizer_utils::GetClipConstantMinMax(state.graph, *activation, min, max),
                  "Failed to get Clip min/max constants.");
      activation_params.push_back(min);
      activation_params.push_back(max);
    } else if (activation_op_type == "HardSigmoid") {
      // Both attributes are optional; 0.2 / 0.5 are the ONNX HardSigmoid defaults.
      const auto* alpha_attr = graph_utils::GetNodeAttribute(*activation, "alpha");
      const auto* beta_attr = graph_utils::GetNodeAttribute(*activation, "beta");
      activation_params.push_back(alpha_attr == nullptr ? 0.2f : alpha_attr->f());
      activation_params.push_back(beta_attr == nullptr ? 0.5f : beta_attr->f());
    }

    if (!activation_params.empty()) {
      utils::SetNodeAttribute(utils::MakeAttribute("activation_params", activation_params),
                              extra_fused_conv_attributes);
    }

    return extra_fused_conv_attributes;
  }

  // Describes how inputs/outputs migrate to the fused node:
  //   inputs  = all Conv inputs, followed by the Add operand that is not the Conv output
  //   outputs = outputs of the last fused node (the activation if selected, else the Add)
  std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState& state) const override {
    const auto& conv = state.selected_nodes.Target();
    ORT_ENFORCE(conv.GetOutputEdgesCount() == 1 && conv.OutputNodesBegin()->OpType() == "Add",
                "Expected Conv then Add.");

    // Add has exactly two inputs; the one to keep is the one not fed by the Conv.
    const auto add_input_idx = 1 - conv.OutputEdgesBegin()->GetDstArgIndex();

    const auto conv_location = NTO::NodeLocation{NTO::NodeType::kTarget, 0};
    const auto add_location = NTO::NodeLocation{NTO::NodeType::kOutput, 0};

    // Conv+Add+activation when two output nodes were selected, Conv+Add otherwise.
    const bool has_activation = state.selected_nodes.num_outputs == 2;
    const auto last_node_location = NTO::NodeLocation{NTO::NodeType::kOutput, has_activation ? 1 : 0};

    return {
        MoveAll(conv_location, ArgType::kInput),                                       // move all inputs from conv
        MoveAndAppend(add_location, ArgType::kInput, add_input_idx, ArgType::kInput),  // append the other add input
        MoveAll(last_node_location, ArgType::kOutput),                                 // move all outputs from the last fused node
    };
  }
};
} // namespace actions
// Registers the Conv(+Add)(+activation) selector/action pair under the name
// "ConvAddAct"; it is keyed on Conv nodes of opset versions 1 and 11.
void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) {
  registry.RegisterSelectorAndAction("ConvAddAct",
                                     {{"Conv", {1, 11}}},
                                     std::make_unique<selectors::ConvAddActivation>(),
                                     std::make_unique<actions::FuseConvAddActivation>());
}
// Creates a registry holding only the Conv+Add+activation fusion rule.
SelectorActionRegistry CreateSelectorActionRegistry() {
  SelectorActionRegistry fusion_registry{};
  RegisterConvAddActivationFusionRules(fusion_registry);
  return fusion_registry;
}
} // namespace
// Constructs the transformer under the name "ConvAddActivationFusion" with a
// freshly created selector/action registry (see CreateSelectorActionRegistry),
// forwarding `apply_context` and `compatible_execution_providers` unchanged to
// the SelectorActionTransformer base.
ConvAddActivationFusion::ConvAddActivationFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers,
const SatApplyContextVariant& apply_context)
: SelectorActionTransformer{
"ConvAddActivationFusion", CreateSelectorActionRegistry(), apply_context, compatible_execution_providers} {
}
} // namespace onnxruntime

View file

@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/selectors_actions/selector_action_transformer.h"
namespace onnxruntime {
/**
@Class ConvAddActivationFusion

Transformer that optimizes the graph by fusing a Conv node with a following Add node and,
optionally, an activation node into a single node, operating on NCHW nodes. It is a more
general version of the Conv-Add-Relu fusion.

The Conv and the other operand of the Add come from different branches of the graph.
Fusing them is considered acceptable because the graph is assumed to be executed by the
sequential executor, in which case the order in which the branches run does not matter.
*/
class ConvAddActivationFusion : public SelectorActionTransformer {
public:
// Both arguments are optional and default to empty values.
ConvAddActivationFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const SatApplyContextVariant& apply_context = {});
};
} // namespace onnxruntime

View file

@ -12,6 +12,7 @@
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/optimizer/conv_add_act_fusion.h"
#if !defined(ORT_MINIMAL_BUILD)
@ -161,7 +162,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
#ifndef DISABLE_CONTRIB_OPS
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
#endif
switch (level) {
case TransformerLevel::Level1: {
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
@ -198,7 +201,6 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const bool enable_gelu_approximation =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1";
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
const InlinedHashSet<std::string_view> cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider};
const InlinedHashSet<std::string_view> cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider,
@ -263,6 +265,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
}
auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
// The NCHWc transformer should have a higher priority than this one, because it performs similar
// fusions and also targets the CPU. However, the NCHWc transformer reorders the layout to NCHWc, which is only
// available on x86-64 CPUs and not on edge CPUs such as ARM, whereas this transformer can be used by the OpenCL EP and the CPU EP.
// So the NCHWc transformer is preferred when ORT runs on an x86-64 CPU; otherwise ConvAddActivationFusion is enabled.
// PR #6351 implemented a similar fusion pattern, but only for CUDA and only for Conv-Add-Relu, while this one can fuse more activations.
transformers.emplace_back(std::make_unique<ConvAddActivationFusion>(cpu_ep));
#endif
} break;

View file

@ -154,9 +154,10 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
Status Conv<float>::Compute(OpKernelContext* context) const {
size_t num_inputs = OpKernel::Node().InputDefs().size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = context->Input<Tensor>(1);
const Tensor* B = num_inputs >= 3 ? context->Input<Tensor>(2) : nullptr;
const Tensor* Sum = num_inputs >= 4 ? context->Input<Tensor>(3) : nullptr;
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
@ -195,7 +196,18 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const auto* Xdata = X->template Data<float>();
const auto* Bdata = B != nullptr ? B->template Data<float>() : nullptr;
auto* Ydata = Y->template MutableData<float>();
// Check for the optional Conv/Sum fusion.
float Beta = 0.0f;
if (Sum != nullptr) {
const auto& sum_shape = Sum->Shape();
ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
// If the output was not allocated inplace with the sum tensor, then copy here.
const auto* sum_data = Sum->template Data<float>();
if (Ydata != sum_data) {
memcpy(Ydata, sum_data, sum_shape.Size() * sizeof(float));
}
Beta = 1.0f;
}
const size_t kernel_rank = kernel_shape.size();
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
@ -216,6 +228,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
static_cast<size_t>(M / conv_attrs_.group),
&activation_,
&WorkingBufferSize,
Beta,
thread_pool);
auto* working_data = WorkingBufferSize > 0 ? alloc->Alloc(SafeInt<size_t>(sizeof(float)) * WorkingBufferSize)
@ -266,7 +279,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
1,
W->template Data<float>() + group_id * W_offset,
col_buffer_data,
0,
Beta,
Ydata + group_id * Y_offset,
thread_pool);
}

View file

@ -29,7 +29,7 @@ class Conv<float> : public OpKernel {
}
Status Compute(OpKernelContext* context) const override;
protected:
MLAS_ACTIVATION activation_;

View file

@ -205,6 +205,31 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
}
#endif
// Fused convolution with bias (B), an extra tensor added to the result (Z, the
// fourth input) and a Relu activation: Y = Relu(Conv(X, W) + B + Z).
// With a 3x3 input of 1..9 and all-ones 2x2 kernels, each output channel's raw
// convolution is {12, 16, 24, 28}; adding the per-channel bias (+1 / -1) and Z
// gives the expected values below, none of which is clipped by Relu.
// NOTE(review): the meaning of the trailing `providers_except_cpu` argument
// depends on TestConvOp, which is not visible here - confirm it restricts the
// run to the CPU provider as the test name suggests.
TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
1, // group
vector<int64_t>{2, 2}, // kernel_shape
vector<int64_t>{0, 0, 0, 0}, // pads
vector<int64_t>{1, 1}, // strides
"Relu" // activation
};
vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
vector<int64_t> X_shape = {1, 1, 3, 3};
vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
vector<int64_t> W_shape = {2, 1, 2, 2};
vector<int64_t> Y_shape = {1, 2, 2, 2};
vector<float> B = {1.0f, -1.0f};
vector<int64_t> B_shape = {2};
// Z has the same shape as the output and is added element-wise.
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
vector<int64_t> Z_shape = {1, 2, 2, 2};
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
}
#endif
} // namespace test

View file

@ -109,6 +109,7 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
static_cast<size_t>(output_channels_per_group),
&activation,
&WorkingBufferSize,
0.0f,
nullptr);
auto X = RandomVectorUniform(x_shape, -2.0, 2.0);

View file

@ -57,6 +57,7 @@ class MlasConv2DTest : public MlasTestBase {
FilterCount,
&Activation,
&WorkingBufferSize,
0.0f,
threadpool_);
MlasConv(&Parameters,

View file

@ -28,6 +28,7 @@
#include "core/optimizer/concat_slice_elimination.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/conv_add_act_fusion.h"
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
@ -751,6 +752,70 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) {
#endif
#if !defined(DISABLE_CONTRIB_OPS)
// Conv->Add->Relu on the CPU EP is transformed into a single FusedConv:
// both the Add and the Relu must be gone after the transformer has run.
TEST_F(GraphTransformationTests, FuseCpuConvAddRelu) {
  auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // The fusion only applies to nodes assigned to the CPU execution provider.
  for (auto& node : graph.Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }

  // Sanity-check the model before optimizing. ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);
  ASSERT_EQ(op_to_count["Relu"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(),
                                                     TransformerLevel::Level3));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);   // Add removed from graph
  ASSERT_EQ(op_to_count["Relu"], 0);  // Relu removed from graph
}
// Conv->Add->Relu is only partly fused (Conv+Add) here: an Identity node also depends on the
// Add output, so the Relu is not the Add's only consumer and has to stay in the graph.
TEST_F(GraphTransformationTests, FuseCpuConvAddReluIdentity) {
  auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // The fusion only applies to nodes assigned to the CPU execution provider.
  for (auto& node : graph.Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }

  // Sanity-check the model before optimizing. ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);
  ASSERT_EQ(op_to_count["Relu"], 1);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(),
                                                     TransformerLevel::Level3));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);       // Add removed
  ASSERT_EQ(op_to_count["Relu"], 1);      // Relu remains
  ASSERT_EQ(op_to_count["Identity"], 1);  // Identity remains
}
// Conv->Add (without a trailing activation) on the CPU EP is transformed into a FusedConv:
// the Add must be gone after the transformer has run.
TEST_F(GraphTransformationTests, FuseCpuConvAdd) {
  auto model_uri = MODEL_FOLDER "fusion/conv_add.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // The fusion only applies to nodes assigned to the CPU execution provider.
  for (auto& node : graph.Nodes()) {
    node.SetExecutionProviderType(kCpuExecutionProvider);
  }

  // Sanity-check the model before optimizing. ASSERT_EQ reports the actual count on failure.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(),
                                                     TransformerLevel::Level3));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);  // Add removed
}
#endif
#if !defined(DISABLE_CONTRIB_OPS)
TEST_F(GraphTransformationTests, FuseConvActivation) {
#ifdef USE_CUDA