mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Conv+Activation fusion for CPU (#105)
* Add conv+activation fusion. * Adding tests * Adding activation LeakyRelu. * Refactoring the code to use a fusedConv custom op instead of changing the original conv op at runtime. * fix build issue. * fix build issue. * In order to reduce binary size: 1. reuse onnx shape inference for conv 2. remove most doc. * Accomodating PR comments. * Accomodating PR comments * Remove unused variables
This commit is contained in:
parent
7f0e5269bd
commit
2230e5b431
14 changed files with 537 additions and 155 deletions
16
onnxruntime/contrib_ops/cpu/fused_conv.cc
Normal file
16
onnxruntime/contrib_ops/cpu/fused_conv.cc
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fused_conv.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
FusedConv,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedConv<float>);
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/contrib_ops/cpu/fused_conv.h
Normal file
24
onnxruntime/contrib_ops/cpu/fused_conv.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class FusedConv : public Conv<T> {
|
||||
public:
|
||||
FusedConv(const OpKernelInfo& info) : Conv<T>(info) {
|
||||
Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
return Conv<T>::Compute(context);
|
||||
}
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,7 +6,11 @@
|
|||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/range_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape);
|
||||
}
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
using ::ONNX_NAMESPACE::AttributeProto;
|
||||
|
|
@ -28,6 +32,62 @@ void RegisterContribSchemas() {
|
|||
Sample echo operator.)DOC");
|
||||
|
||||
// register schemas for more operators here
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(
|
||||
The fused convolution operator schema is the same as Conv besides it includes an attribute
|
||||
activation.)DOC")
|
||||
.Attr(
|
||||
"auto_pad",
|
||||
"",
|
||||
AttributeProto::STRING,
|
||||
std::string("NOTSET"))
|
||||
.Attr(
|
||||
"kernel_shape",
|
||||
"",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"dilations",
|
||||
"",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"strides", "", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("pads",
|
||||
"",
|
||||
AttributeProto::INTS, OPTIONAL)
|
||||
.Attr(
|
||||
"group",
|
||||
"",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Attr(
|
||||
"activation",
|
||||
"",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL)
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"",
|
||||
"T")
|
||||
.Input(
|
||||
1,
|
||||
"W",
|
||||
"",
|
||||
"T")
|
||||
.Input(2, "B", "", "T", OpSchema::Optional)
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"",
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
86
onnxruntime/core/graph/conv_activation_fusion.cc
Normal file
86
onnxruntime/core/graph/conv_activation_fusion.cc
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/initializer.h"
|
||||
#include "core/graph/conv_activation_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
using namespace onnx;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
bool IsFusableActivation(const Node& node) {
|
||||
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status ConvActivationFusion::Apply(Graph& graph, bool& modified) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto index : order) {
|
||||
auto node = graph.GetNode(index);
|
||||
if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
const Node& next_node = *(node->OutputNodesBegin());
|
||||
if (!IsFusableActivation(next_node) || graph.IsNodeOutputsInGraphOutputs(next_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* conv_node = node;
|
||||
const Node& act_node = next_node;
|
||||
std::vector<NodeArg> input_args, output_args;
|
||||
|
||||
Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node->Name()), "FusedConv",
|
||||
"fused Conv " + conv_node->Name() + "with activation " + act_node.OpType(),
|
||||
conv_node->MutableInputDefs(),
|
||||
conv_node->MutableOutputDefs(),
|
||||
&conv_node->GetAttributes(),
|
||||
"com.microsoft");
|
||||
|
||||
//Add a new attribute to specify the activation type
|
||||
fused_conv.AddAttribute("activation", "string");
|
||||
|
||||
//Add optional attributes for activations
|
||||
if (act_node.OpType() == "LeakyRelu") {
|
||||
const NodeAttributes attrs = act_node.GetAttributes();
|
||||
for (auto it = attrs.begin(); it != attrs.end(); ++it) {
|
||||
fused_conv.AddAttribute(it->first, it->second);
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the input of the node following activation node
|
||||
const NodeArg* act_output_def = act_node.OutputDefs()[0];
|
||||
NodeArg* fused_conv_output_def = fused_conv.MutableOutputDefs()[0];
|
||||
for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
|
||||
auto output_node = graph.GetNode((*it).Index());
|
||||
if (!output_node) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto& input_defs = output_node->MutableInputDefs();
|
||||
for (auto& def : input_defs) {
|
||||
if (def == act_output_def) {
|
||||
def = fused_conv_output_def;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removed_nodes.push_back(act_node.Index());
|
||||
removed_nodes.push_back(conv_node->Index());
|
||||
}
|
||||
|
||||
for (auto i : removed_nodes) {
|
||||
graph.RemoveNode(i);
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
16
onnxruntime/core/graph/conv_activation_fusion.h
Normal file
16
onnxruntime/core/graph/conv_activation_fusion.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConvActivationFusion : public onnxruntime::GraphTransformer {
|
||||
public:
|
||||
ConvActivationFusion() noexcept : onnxruntime::GraphTransformer("ConvActivationFusion", "Fusing Activation into Conv") {}
|
||||
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -4,9 +4,159 @@
|
|||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <>
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t C = X->Shape()[1];
|
||||
const int64_t M = W->Shape()[0];
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
|
||||
|
||||
std::vector<int64_t> kernel_shape = ComputeKernelShape(W->Shape());
|
||||
|
||||
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_shape.size(); ++i) {
|
||||
if (kernel_shape[i] != W->Shape()[i + 2]) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> pads(pads_);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_shape.size() * 2, 0);
|
||||
}
|
||||
std::vector<int64_t> dilations(dilations_);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
std::vector<int64_t> strides(strides_);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Y_dims;
|
||||
Y_dims.insert(Y_dims.begin(), {N, M});
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
|
||||
Tensor* Y = context->Output(0, TensorShape(Y_dims));
|
||||
TensorShape output_shape = Y->Shape().Slice(2);
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
const float* Xdata = X->template Data<float>();
|
||||
float* Ydata = Y->template MutableData<float>();
|
||||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
MlasConvPrepare(&Parameters,
|
||||
kernel_rank,
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(group_),
|
||||
static_cast<size_t>(C / group_),
|
||||
input_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
output_shape.GetDims().data(),
|
||||
static_cast<size_t>(M / group_),
|
||||
&WorkingBufferSize);
|
||||
|
||||
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
|
||||
|
||||
MlasConv(&Parameters,
|
||||
Xdata,
|
||||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
static_cast<float*>(working_buffer.get()),
|
||||
Ydata);
|
||||
|
||||
//TODO: this will be replaced with Tracy's changes.
|
||||
fuse_activation(activation_, Ydata, Y->Shape().Size(), alpha_);
|
||||
|
||||
} else {
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
const int64_t X_offset = C / group_ * input_image_size;
|
||||
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
|
||||
const int64_t W_offset = W->Shape().Size() / group_;
|
||||
const int64_t kernel_dim = C / group_ * kernel_size;
|
||||
const int64_t col_buffer_size = kernel_dim * output_image_size;
|
||||
|
||||
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
float* col_buffer_data = static_cast<float*>(col_buffer.get());
|
||||
|
||||
TensorShape image_shape = X->Shape().Slice(1);
|
||||
std::vector<int64_t> col_buffer_shape{kernel_dim};
|
||||
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
|
||||
output_shape.GetDims().end());
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>(
|
||||
Xdata + group_id * X_offset,
|
||||
image_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
C * input_image_size,
|
||||
col_buffer_size,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M / group_,
|
||||
output_image_size,
|
||||
kernel_dim,
|
||||
1,
|
||||
W->template Data<float>() + group_id * W_offset,
|
||||
col_buffer_data,
|
||||
0,
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
|
||||
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Conv,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Conv<float>);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@
|
|||
#include "core/util/math.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace {
|
||||
|
||||
// helper function
|
||||
template <bool ForceSymmetricAutoPadding>
|
||||
|
|
@ -58,7 +57,6 @@ Status ComputePadAndOutputShape(
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// base class used by Conv and ConvTranspose
|
||||
class ConvBase {
|
||||
|
|
@ -122,21 +120,21 @@ class ConvBase {
|
|||
|
||||
if (X->Shape().NumDimensions() != W->Shape().NumDimensions()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.",
|
||||
" X: ", X->Shape().ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
" X: ", X->Shape().ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
if (C != W->Shape()[1] * group_) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.",
|
||||
" C: ", C,
|
||||
" kernel channels: ", W->Shape()[1],
|
||||
" group: ", group_);
|
||||
" C: ", C,
|
||||
" kernel channels: ", W->Shape()[1],
|
||||
" group: ", group_);
|
||||
}
|
||||
|
||||
if (M % group_ != 0) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output channels M is not divisible by group.",
|
||||
" M: ", M,
|
||||
" group: ", group_);
|
||||
" M: ", M,
|
||||
" group: ", group_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -179,6 +177,8 @@ class ConvBase {
|
|||
std::vector<int64_t> strides_;
|
||||
std::vector<int64_t> pads_;
|
||||
std::vector<int64_t> dilations_;
|
||||
std::string activation_;
|
||||
float alpha_;
|
||||
|
||||
private:
|
||||
std::vector<int64_t> kernel_shape_; // must use ComputeKernelShape(...), instead of kernel_shape_
|
||||
|
|
|
|||
|
|
@ -23,6 +23,23 @@
|
|||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
template <typename T>
|
||||
void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) {
|
||||
EigenVectorArrayMap<T> y_vec(y_data, size);
|
||||
if (activation.empty()) {
|
||||
return;
|
||||
} else if (activation == "Relu") {
|
||||
y_vec = y_vec.cwiseMax(0);
|
||||
} else if (activation == "Sigmoid") {
|
||||
y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
|
||||
} else if (activation == "Tanh") {
|
||||
y_vec = y_vec.tanh();
|
||||
} else if (activation == "LeakyRelu") {
|
||||
y_vec = (y_vec >= 0).select(y_vec, (T)alpha * y_vec);
|
||||
} else {
|
||||
ONNXRUNTIME_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Conv<T>::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -40,15 +57,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_shape.size(); ++i) {
|
||||
if (kernel_shape[i] != W->Shape()[i + 2]) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -151,6 +168,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
|
|
@ -160,146 +178,6 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
template <>
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t C = X->Shape()[1];
|
||||
const int64_t M = W->Shape()[0];
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
|
||||
|
||||
std::vector<int64_t> kernel_shape = ComputeKernelShape(W->Shape());
|
||||
|
||||
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_shape.size(); ++i) {
|
||||
if (kernel_shape[i] != W->Shape()[i + 2]) {
|
||||
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
|
||||
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
|
||||
" W: ", W->Shape().ToString().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> pads(pads_);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_shape.size() * 2, 0);
|
||||
}
|
||||
std::vector<int64_t> dilations(dilations_);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
std::vector<int64_t> strides(strides_);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Y_dims;
|
||||
Y_dims.insert(Y_dims.begin(), {N, M});
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
|
||||
Tensor* Y = context->Output(0, TensorShape(Y_dims));
|
||||
TensorShape output_shape = Y->Shape().Slice(2);
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
const float* Xdata = X->template Data<float>();
|
||||
float* Ydata = Y->template MutableData<float>();
|
||||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
MLAS_CONV_PARAMETERS Parameters;
|
||||
size_t WorkingBufferSize;
|
||||
MlasConvPrepare(&Parameters,
|
||||
kernel_rank,
|
||||
static_cast<size_t>(N),
|
||||
static_cast<size_t>(group_),
|
||||
static_cast<size_t>(C / group_),
|
||||
input_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
output_shape.GetDims().data(),
|
||||
static_cast<size_t>(M / group_),
|
||||
&WorkingBufferSize);
|
||||
|
||||
auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
|
||||
BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
|
||||
|
||||
MlasConv(&Parameters,
|
||||
Xdata,
|
||||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
static_cast<float*>(working_buffer.get()),
|
||||
Ydata);
|
||||
} else {
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
const int64_t X_offset = C / group_ * input_image_size;
|
||||
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
|
||||
const int64_t W_offset = W->Shape().Size() / group_;
|
||||
const int64_t kernel_dim = C / group_ * kernel_size;
|
||||
const int64_t col_buffer_size = kernel_dim * output_image_size;
|
||||
|
||||
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
float* col_buffer_data = static_cast<float*>(col_buffer.get());
|
||||
|
||||
TensorShape image_shape = X->Shape().Slice(1);
|
||||
std::vector<int64_t> col_buffer_shape{kernel_dim};
|
||||
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
|
||||
output_shape.GetDims().end());
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>(
|
||||
Xdata + group_id * X_offset,
|
||||
image_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
C * input_image_size,
|
||||
col_buffer_size,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
math::Gemm<float, CPUMathUtil>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M / group_,
|
||||
output_image_size,
|
||||
kernel_dim,
|
||||
1,
|
||||
W->template Data<float>() + group_id * W_offset,
|
||||
col_buffer_data,
|
||||
0,
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
|
||||
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const;
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
#include "core/graph/conv_bn_fusion.h"
|
||||
#include "core/graph/conv_mul_fusion.h"
|
||||
#include "core/graph/conv_add_fusion.h"
|
||||
#include "core/graph/conv_activation_fusion.h"
|
||||
#include "core/platform/env.h"
|
||||
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/test_environment.h"
|
||||
|
|
@ -71,6 +73,25 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
|
|||
ASSERT_TRUE(session_object.Initialize().IsOK());
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, FuseConvActivation) {
|
||||
SessionOptions so;
|
||||
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
|
||||
std::string activations[] = {"relu", "sigmoid", "softsign", "tanh", "leakyrelu"};
|
||||
|
||||
for (std::string act : activations) {
|
||||
InferenceSession session_object{so, &DefaultLoggingManager()};
|
||||
std::string model_uri = MODEL_FOLDER + "fusion/conv_" + act + ".onnx";
|
||||
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
|
||||
std::unique_ptr<ConvActivationFusion> ConvActivationFusion_transformer = std::make_unique<ConvActivationFusion>();
|
||||
session_object.RegisterGraphTransformer(std::move(ConvActivationFusion_transformer));
|
||||
|
||||
ASSERT_TRUE(session_object.Initialize().IsOK());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, FuseConvBNNoBias) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-no-bias.onnx";
|
||||
|
||||
|
|
|
|||
27
onnxruntime/test/testdata/transform/fusion/conv_leakyrelu.onnx
vendored
Normal file
27
onnxruntime/test/testdata/transform/fusion/conv_leakyrelu.onnx
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
onnx_conv_tanh:³
|
||||
*
|
||||
X
|
||||
WY"Conv*
|
||||
kernel_shape@@@@
|
||||
"
|
||||
YZ" LeakyRelu*
|
||||
alphaÍÌL>
|
||||
test-modelZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Z
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
26
onnxruntime/test/testdata/transform/fusion/conv_relu.onnx
vendored
Normal file
26
onnxruntime/test/testdata/transform/fusion/conv_relu.onnx
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
onnx_conv_relu:<3A>
|
||||
*
|
||||
X
|
||||
WY"Conv*
|
||||
kernel_shape@@@@
|
||||
|
||||
YZ"Relu
|
||||
test-modelZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Z
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
26
onnxruntime/test/testdata/transform/fusion/conv_sigmoid.onnx
vendored
Normal file
26
onnxruntime/test/testdata/transform/fusion/conv_sigmoid.onnx
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
onnx_conv_sigmoid:
|
||||
*
|
||||
X
|
||||
WY"Conv*
|
||||
kernel_shape@@@@
|
||||
|
||||
YZ"Sigmoid
|
||||
test-modelZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Z
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
26
onnxruntime/test/testdata/transform/fusion/conv_softsign.onnx
vendored
Normal file
26
onnxruntime/test/testdata/transform/fusion/conv_softsign.onnx
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
onnx_conv_softsign:¡
|
||||
*
|
||||
X
|
||||
WY"Conv*
|
||||
kernel_shape@@@@
|
||||
|
||||
YZ"Softsign
|
||||
test-modelZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Z
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
26
onnxruntime/test/testdata/transform/fusion/conv_tanh.onnx
vendored
Normal file
26
onnxruntime/test/testdata/transform/fusion/conv_tanh.onnx
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
onnx_conv_tanh:<3A>
|
||||
*
|
||||
X
|
||||
WY"Conv*
|
||||
kernel_shape@@@@
|
||||
|
||||
YZ"Tanh
|
||||
test-modelZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Z
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
Loading…
Reference in a new issue