Conv+Activation fusion for CPU (#105)

* Add conv+activation fusion.

* Adding tests

* Adding activation LeakyRelu.

* Refactoring the code to use a fusedConv custom op instead of changing
the original conv op at runtime.

* fix build issue.

* fix build issue.

* In order to reduce binary size:
1. reuse onnx shape inference for conv
2. remove most doc.

* Accommodating PR comments.

* Accommodating PR comments

* Remove unused variables
This commit is contained in:
Du Li 2018-12-12 11:37:09 -08:00 committed by GitHub
parent 7f0e5269bd
commit 2230e5b431
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 537 additions and 155 deletions

View file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "fused_conv.h"
namespace onnxruntime {
namespace contrib {
// Register the CPU kernel for the contrib FusedConv op (com.microsoft domain,
// since opset 1), constrained to float tensors.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedConv,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConv<float>);
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/nn/conv_impl.h"
namespace onnxruntime {
namespace contrib {
// FusedConv is a Conv kernel with an activation folded in: it reads the
// "activation" attribute (and LeakyRelu's "alpha") produced by the
// ConvActivationFusion graph transformer, stores them on the Conv base,
// and otherwise computes exactly like Conv<T>.
template <typename T>
class FusedConv : public Conv<T> {
 public:
  FusedConv(const OpKernelInfo& info) : Conv<T>(info) {
    // Dependent-base members need explicit qualification; this-> works too.
    this->activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    this->alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
  }

  Status Compute(OpKernelContext* context) const override {
    // The base Compute applies activation_/alpha_ after the convolution.
    return Conv<T>::Compute(context);
  }
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -6,7 +6,11 @@
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/range_schema_defs.h"
#include "core/graph/op.h"
#include "onnx/defs/shape_inference.h"
namespace ONNX_NAMESPACE {
void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape);
}
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
@ -28,6 +32,62 @@ void RegisterContribSchemas() {
Sample echo operator.)DOC");
// register schemas for more operators here
// Schema for the contrib FusedConv op: identical to ONNX Conv plus an
// "activation" attribute naming the activation fused into the conv.
// Attribute/input docs are intentionally empty to reduce binary size.
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(R"DOC(
The fused convolution operator schema is the same as Conv besides it includes an attribute
activation.)DOC")
.Attr(
"auto_pad",
"",
AttributeProto::STRING,
std::string("NOTSET"))
.Attr(
"kernel_shape",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"dilations",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"strides", "", AttributeProto::INTS, OPTIONAL)
.Attr("pads",
"",
AttributeProto::INTS, OPTIONAL)
.Attr(
"group",
"",
AttributeProto::INT,
static_cast<int64_t>(1))
// Extra attribute relative to Conv: the op type of the fused activation.
.Attr(
"activation",
"",
AttributeProto::STRING,
OPTIONAL)
.Input(
0,
"X",
"",
"T")
.Input(
1,
"W",
"",
"T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Output(
0,
"Y",
"",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Reuse ONNX's Conv/Pool inference (use_dilation=false, require_kernel_shape=true)
// instead of duplicating it, also to keep binary size down.
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
.SetDomain(kMSDomain)

View file

@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/initializer.h"
#include "core/graph/conv_activation_fusion.h"
#include "core/graph/graph_utils.h"
using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
bool IsFusableActivation(const Node& node) {
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
}
} // namespace
// Fuses each Conv node whose single consumer is a fusable activation
// (Relu/Sigmoid/Tanh/LeakyRelu) into one com.microsoft FusedConv node,
// then removes the original Conv and activation nodes.
// Sets `modified` and re-resolves the graph if anything was fused.
Status ConvActivationFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();

  std::vector<onnxruntime::NodeIndex> removed_nodes;
  for (auto index : order) {
    auto node = graph.GetNode(index);
    // Only fuse a Conv whose output feeds exactly one downstream node.
    if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {
      continue;
    }
    const Node& next_node = *(node->OutputNodesBegin());
    // Skip if the consumer is not a supported activation, or if its output is
    // a graph output (removing it would change the graph's interface).
    if (!IsFusableActivation(next_node) || graph.IsNodeOutputsInGraphOutputs(next_node)) {
      continue;
    }

    Node* conv_node = node;
    const Node& act_node = next_node;
    Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node->Name()), "FusedConv",
                                     "fused Conv " + conv_node->Name() + " with activation " + act_node.OpType(),
                                     conv_node->MutableInputDefs(),
                                     conv_node->MutableOutputDefs(),
                                     &conv_node->GetAttributes(),
                                     "com.microsoft");
    // Record which activation was fused so the FusedConv kernel can apply it.
    // (Previously this stored the literal "string", which the kernel cannot
    // interpret as an activation.)
    fused_conv.AddAttribute("activation", act_node.OpType());
    // Carry over the activation's own attributes (e.g. LeakyRelu's alpha).
    if (act_node.OpType() == "LeakyRelu") {
      const NodeAttributes& attrs = act_node.GetAttributes();  // const ref: avoid copying the map
      for (const auto& attr : attrs) {
        fused_conv.AddAttribute(attr.first, attr.second);
      }
    }

    // Re-point every consumer of the activation's output at FusedConv's output.
    const NodeArg* act_output_def = act_node.OutputDefs()[0];
    NodeArg* fused_conv_output_def = fused_conv.MutableOutputDefs()[0];
    for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
      auto output_node = graph.GetNode((*it).Index());
      if (!output_node) {
        return Status(ONNXRUNTIME, INVALID_ARGUMENT);
      }
      auto& input_defs = output_node->MutableInputDefs();
      for (auto& def : input_defs) {
        if (def == act_output_def) {
          def = fused_conv_output_def;
        }
      }
    }

    removed_nodes.push_back(act_node.Index());
    removed_nodes.push_back(conv_node->Index());
  }

  for (auto i : removed_nodes) {
    graph.RemoveNode(i);
  }
  if (!removed_nodes.empty()) {
    modified = true;
    ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
  }
  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph_transformer.h"
namespace onnxruntime {
// Graph transformer that fuses a Conv node with the activation node that
// consumes its output into a single com.microsoft FusedConv node.
class ConvActivationFusion : public onnxruntime::GraphTransformer {
public:
ConvActivationFusion() noexcept : onnxruntime::GraphTransformer("ConvActivationFusion", "Fusing Activation into Conv") {}
// Scans `graph` for fusable Conv+activation pairs; sets `modified` if any were fused.
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};
} // namespace onnxruntime

View file

@ -4,9 +4,159 @@
#include "core/providers/cpu/nn/conv_impl.h"
namespace onnxruntime {
// Float specialization of Conv: uses the hand-tuned MLAS convolution for the
// common 2-D/3-D kernels and falls back to im2col + GEMM otherwise.
// After the convolution (and bias add), the fused activation named by
// activation_ is applied in-place to the output.
template <>
Status Conv<float>::Compute(OpKernelContext* context) const {
size_t num_inputs = OpKernel::Node().InputDefs().size();
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = context->Input<Tensor>(1);
// B (bias) is optional; only read it when a third input is declared.
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
// NCHW layout: N = batch, C = input channels, M = output channels.
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
std::vector<int64_t> kernel_shape = ComputeKernelShape(W->Shape());
// The kernel must match W's spatial dims (W is [M, C/group, k0, k1, ...]).
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
for (size_t i = 0; i < kernel_shape.size(); ++i) {
if (kernel_shape[i] != W->Shape()[i + 2]) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
}
// Default missing attributes: zero padding, unit dilations/strides.
std::vector<int64_t> pads(pads_);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}
// Compute the output shape ([N, M, spatial...]) and allocate Y.
std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);
AllocatorPtr alloc;
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
const float* Xdata = X->template Data<float>();
float* Ydata = Y->template MutableData<float>();
const size_t kernel_rank = kernel_shape.size();
if (kernel_rank == 2 || kernel_rank == 3) {
// Fast path: MLAS handles 2-D/3-D convolutions (including grouping and bias).
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
MlasConvPrepare(&Parameters,
kernel_rank,
static_cast<size_t>(N),
static_cast<size_t>(group_),
static_cast<size_t>(C / group_),
input_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
output_shape.GetDims().data(),
static_cast<size_t>(M / group_),
&WorkingBufferSize);
// Scratch buffer size is reported by MlasConvPrepare; may legitimately be 0.
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
MlasConv(&Parameters,
Xdata,
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
static_cast<float*>(working_buffer.get()),
Ydata);
//TODO: this will be replaced with Tracy's changes.
// Apply the fused activation over the entire output tensor at once.
fuse_activation(activation_, Ydata, Y->Shape().Size(), alpha_);
} else {
// Generic path: per-image, per-group im2col followed by a GEMM.
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
// Per-group strides into the X/Y/W buffers.
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;
// im2col scratch buffer, reused across images/groups.
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
float* col_buffer_data = static_cast<float*>(col_buffer.get());
TensorShape image_shape = X->Shape().Slice(1);
std::vector<int64_t> col_buffer_shape{kernel_dim};
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
output_shape.GetDims().end());
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
// Unfold this group's input patch into columns...
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
col_buffer_shape.data(),
C * input_image_size,
col_buffer_size,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance());
// ...then the convolution is a (M/group x kernel_dim) * (kernel_dim x out) GEMM.
math::Gemm<float, CPUMathUtil>(
CblasNoTrans,
CblasNoTrans,
M / group_,
output_image_size,
kernel_dim,
1,
W->template Data<float>() + group_id * W_offset,
col_buffer_data,
0,
Ydata + group_id * Y_offset,
&CPUMathUtil::Instance());
}
// Broadcast-add the bias across the spatial dims of each output channel.
if (B != nullptr) {
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
// Apply the fused activation to this image's output before advancing.
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
}
}
return Status::OK();
}
// Register the float CPU Conv kernel (ONNX domain, since opset 1).
ONNX_CPU_OPERATOR_KERNEL(
Conv,
1,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv<float>);
}
} // namespace onnxruntime

View file

@ -10,7 +10,6 @@
#include "core/util/math.h"
namespace onnxruntime {
namespace {
// helper function
template <bool ForceSymmetricAutoPadding>
@ -58,7 +57,6 @@ Status ComputePadAndOutputShape(
}
return Status::OK();
}
} // namespace
// base class used by Conv and ConvTranspose
class ConvBase {
@ -122,21 +120,21 @@ class ConvBase {
if (X->Shape().NumDimensions() != W->Shape().NumDimensions()) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.",
" X: ", X->Shape().ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
" X: ", X->Shape().ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
if (C != W->Shape()[1] * group_) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.",
" C: ", C,
" kernel channels: ", W->Shape()[1],
" group: ", group_);
" C: ", C,
" kernel channels: ", W->Shape()[1],
" group: ", group_);
}
if (M % group_ != 0) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output channels M is not divisible by group.",
" M: ", M,
" group: ", group_);
" M: ", M,
" group: ", group_);
}
return Status::OK();
}
@ -179,6 +177,8 @@ class ConvBase {
std::vector<int64_t> strides_;
std::vector<int64_t> pads_;
std::vector<int64_t> dilations_;
std::string activation_;
float alpha_;
private:
std::vector<int64_t> kernel_shape_; // must use ComputeKernelShape(...), instead of kernel_shape_

View file

@ -23,6 +23,23 @@
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
// Applies the activation named by `activation` in-place over y_data[0, size).
// An empty name is a no-op; unrecognized names abort via NOT_IMPLEMENTED.
// `alpha` is only used by LeakyRelu (slope for negative inputs).
template <typename T>
void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) {
  if (activation.empty()) {
    // No fused activation requested — leave the output untouched.
    return;
  }
  EigenVectorArrayMap<T> y_vec(y_data, size);
  if (activation == "Relu") {
    y_vec = y_vec.cwiseMax(0);
  } else if (activation == "Sigmoid") {
    // Numerically stable form: compute with exp(-|y|) and pick the branch by sign.
    y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
  } else if (activation == "Tanh") {
    y_vec = y_vec.tanh();
  } else if (activation == "LeakyRelu") {
    y_vec = (y_vec >= 0).select(y_vec, static_cast<T>(alpha) * y_vec);
  } else {
    ONNXRUNTIME_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
  }
}
template <typename T>
Status Conv<T>::Compute(OpKernelContext* context) const {
@ -40,15 +57,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
for (size_t i = 0; i < kernel_shape.size(); ++i) {
if (kernel_shape[i] != W->Shape()[i + 2]) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
}
@ -151,6 +168,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
@ -160,146 +178,6 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
}
template <>
Status Conv<float>::Compute(OpKernelContext* context) const {
size_t num_inputs = OpKernel::Node().InputDefs().size();
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = context->Input<Tensor>(1);
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
std::vector<int64_t> kernel_shape = ComputeKernelShape(W->Shape());
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
for (size_t i = 0; i < kernel_shape.size(); ++i) {
if (kernel_shape[i] != W->Shape()[i + 2]) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
}
std::vector<int64_t> pads(pads_);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);
AllocatorPtr alloc;
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
const float* Xdata = X->template Data<float>();
float* Ydata = Y->template MutableData<float>();
const size_t kernel_rank = kernel_shape.size();
if (kernel_rank == 2 || kernel_rank == 3) {
MLAS_CONV_PARAMETERS Parameters;
size_t WorkingBufferSize;
MlasConvPrepare(&Parameters,
kernel_rank,
static_cast<size_t>(N),
static_cast<size_t>(group_),
static_cast<size_t>(C / group_),
input_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
output_shape.GetDims().data(),
static_cast<size_t>(M / group_),
&WorkingBufferSize);
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
MlasConv(&Parameters,
Xdata,
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
static_cast<float*>(working_buffer.get()),
Ydata);
} else {
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
float* col_buffer_data = static_cast<float*>(col_buffer.get());
TensorShape image_shape = X->Shape().Slice(1);
std::vector<int64_t> col_buffer_shape{kernel_dim};
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
output_shape.GetDims().end());
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
col_buffer_shape.data(),
C * input_image_size,
col_buffer_size,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance());
math::Gemm<float, CPUMathUtil>(
CblasNoTrans,
CblasNoTrans,
M / group_,
output_image_size,
kernel_dim,
1,
W->template Data<float>() + group_id * W_offset,
col_buffer_data,
0,
Ydata + group_id * Y_offset,
&CPUMathUtil::Instance());
}
if (B != nullptr) {
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
}
}
return Status::OK();
}
Status Conv<float>::Compute(OpKernelContext* context) const;
} // namespace onnxruntime

View file

@ -10,6 +10,8 @@
#include "core/graph/conv_bn_fusion.h"
#include "core/graph/conv_mul_fusion.h"
#include "core/graph/conv_add_fusion.h"
#include "core/graph/conv_activation_fusion.h"
#include "core/platform/env.h"
#include "test/capturing_sink.h"
#include "test/test_environment.h"
@ -71,6 +73,25 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
ASSERT_TRUE(session_object.Initialize().IsOK());
}
// Loads each fusion/conv_<act>.onnx model, registers ConvActivationFusion and
// checks that session initialization (which runs the transformer) succeeds.
// NOTE(review): "softsign" is not in IsFusableActivation's list, so it
// presumably exercises the not-fused path — confirm against the transformer.
TEST(GraphTransformationTests, FuseConvActivation) {
  SessionOptions so;
  so.session_logid = "GraphTransformationTests.LoadModelToTransform";
  const std::string activations[] = {"relu", "sigmoid", "softsign", "tanh", "leakyrelu"};
  // Iterate by const reference: the original copied each std::string per pass.
  for (const auto& act : activations) {
    InferenceSession session_object{so, &DefaultLoggingManager()};
    std::string model_uri = MODEL_FOLDER + "fusion/conv_" + act + ".onnx";
    ASSERT_TRUE(session_object.Load(model_uri).IsOK());
    std::shared_ptr<Model> p_model;
    ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
    std::unique_ptr<ConvActivationFusion> ConvActivationFusion_transformer = std::make_unique<ConvActivationFusion>();
    session_object.RegisterGraphTransformer(std::move(ConvActivationFusion_transformer));
    ASSERT_TRUE(session_object.Initialize().IsOK());
  }
}
TEST(GraphTransformationTests, FuseConvBNNoBias) {
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-no-bias.onnx";

View file

@ -0,0 +1,27 @@
onnx_conv_tanh:³
*
X
WY"Conv*
kernel_shape@@@@ 
"
YZ" LeakyRelu*
alphaÍÌL> 
test-modelZ
X




Z
W




b
Z




B

View file

@ -0,0 +1,26 @@
onnx_conv_relu:<3A>
*
X
WY"Conv*
kernel_shape@@@@ 
YZ"Relu
test-modelZ
X




Z
W




b
Z




B

View file

@ -0,0 +1,26 @@
onnx_conv_sigmoid: 
*
X
WY"Conv*
kernel_shape@@@@ 

YZ"Sigmoid
test-modelZ
X




Z
W




b
Z




B

View file

@ -0,0 +1,26 @@
onnx_conv_softsign:¡
*
X
WY"Conv*
kernel_shape@@@@ 

YZ"Softsign
test-modelZ
X




Z
W




b
Z




B

View file

@ -0,0 +1,26 @@
onnx_conv_tanh:<3A>
*
X
WY"Conv*
kernel_shape@@@@ 
YZ"Tanh
test-modelZ
X




Z
W




b
Z




B