cleanup fused conv activation handling (#1403)

* cleanup fused conv activation handling

* fix build break

* fix mkldnn build break
This commit is contained in:
Tracy Sharpe 2019-07-14 16:34:16 -07:00 committed by Ke Zhang
parent c139e3ab33
commit d4ce31ea6d
14 changed files with 305 additions and 353 deletions

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/fused_activation.h"
namespace onnxruntime {
// Builds an MLAS_ACTIVATION from the node's fused-activation attributes.
//
// `activation` is first set to MlasIdentityActivation and stays that way when
// the "activation" string attribute cannot be read. Recognized values are
// "Relu", "LeakyRelu" (which also reads "alpha", defaulting to 0.01f), "Tanh"
// and "Sigmoid". Any other value yields an INVALID_ARGUMENT status and leaves
// `activation` at identity.
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) {
  // Identity is the baseline so the output is defined on every return path.
  activation.ActivationKind = MlasIdentityActivation;

  std::string requested;
  const bool has_attr = info.GetAttr<std::string>("activation", &requested).IsOK();
  if (!has_attr) {
    return Status::OK();
  }

  if (requested == "Relu") {
    activation.ActivationKind = MlasReluActivation;
    return Status::OK();
  }

  if (requested == "LeakyRelu") {
    activation.ActivationKind = MlasLeakyReluActivation;
    // Only LeakyRelu consumes a parameter; the other kinds never touch alpha.
    activation.alpha = info.GetAttrOrDefault<float>("alpha", 0.01f);
    return Status::OK();
  }

  if (requested == "Tanh") {
    activation.ActivationKind = MlasTanhActivation;
    return Status::OK();
  }

  if (requested == "Sigmoid") {
    activation.ActivationKind = MlasLogisticActivation;
    return Status::OK();
  }

  return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + requested);
}
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
// Converts the fused-activation attributes of a node ("activation", plus
// "alpha" for LeakyRelu) into `activation`. The result is identity when the
// "activation" attribute cannot be read; an unrecognized activation name is
// reported through the returned Status.
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation);
} // namespace onnxruntime

View file

@ -1,16 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "fused_conv.h"
#include "core/providers/cpu/nn/conv.h"
#include "contrib_ops/cpu/fused_activation.h"
namespace onnxruntime {
namespace contrib {
// CPU kernel for the FusedConv contrib op: a float Conv whose activation_
// member is filled from the node's fused-activation attributes at
// construction time.
class FusedConvFloat final : public Conv<float> {
 public:
  FusedConvFloat(const OpKernelInfo& info) : Conv<float>(info) {
    // Keep the Status instead of collapsing it to a bool, so that a rejected
    // activation name shows up in the enforce failure text.
    const auto status = GetFusedActivationAttr(info, activation_);
    ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
  }
};
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedConv,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConv<float>);
FusedConvFloat);
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,24 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/nn/conv_impl.h"
namespace onnxruntime {
namespace contrib {
// Conv kernel variant that reads the fused activation name and its alpha
// parameter from the node attributes and stores them in the Conv<T> members;
// the computation itself is delegated unchanged to Conv<T>.
template <typename T>
class FusedConv : public Conv<T> {
  using Base = Conv<T>;

 public:
  FusedConv(const OpKernelInfo& info) : Base(info) {
    // Absent attributes fall back to "" (no activation) and 0.01f.
    this->activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    this->alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
  }

  Status Compute(OpKernelContext* context) const override {
    return Base::Compute(context);
  }
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,15 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "fused_gemm.h"
#include "core/providers/cpu/math/gemm.h"
namespace onnxruntime {
namespace contrib {
// Gemm kernel variant that reads the fused activation name and the LeakyRelu
// slope from the node attributes and stores them in the Gemm<T> members.
template <typename T>
class FusedGemm final : public Gemm<T> {
 public:
  FusedGemm(const OpKernelInfo& info) : Gemm<T>(info) {
    // Absent attributes fall back to "" (no activation) and 0.01f.
    this->activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    this->leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
  }
};
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedGemm,
1,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedGemm<float, float, float, float>);
FusedGemm<float>);
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,26 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/math/gemm.h"
namespace onnxruntime {
namespace contrib {
// Gemm kernel variant (four element-type parameters) that reads the fused
// activation name and the LeakyRelu slope from the node attributes and stores
// them in the base Gemm members; Compute is forwarded to the base unchanged.
template <typename T_X,
          typename T_W,
          typename T_B,
          typename T_Y>
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
  using Base = Gemm<T_X, T_W, T_B, T_Y>;

 public:
  FusedGemm(const OpKernelInfo& info) : Base(info) {
    // Absent attributes fall back to "" (no activation) and 0.01f.
    this->activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    this->leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
  }

  Status Compute(OpKernelContext* context) const override {
    return Base::Compute(context);
  }
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -34,7 +34,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
KernelDefBuilder()
.MayInplace(3, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
NchwcConv<float>);
NchwcConv);
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
MaxPool,
@ -70,39 +70,38 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
template <typename T>
Status ReorderInput<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
Tensor* Y = context->Output(0, X_shape);
auto* Y = context->Output(0, X_shape);
MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
return Status::OK();
}
template <typename T>
Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
std::vector<int64_t> Y_shape(X_shape.GetDims());
ORT_ENFORCE(channels_ <= Y_shape[1]);
Y_shape[1] = channels_;
Tensor* Y = context->Output(0, Y_shape);
auto* Y = context->Output(0, Y_shape);
MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
return Status::OK();
}
template <typename T>
Status NchwcConv<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = context->Input<Tensor>(1);
const Tensor* B = context->Input<Tensor>(2);
const Tensor* Sum = context->Input<Tensor>(3);
Status NchwcConv::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const auto* B = context->Input<Tensor>(2);
const auto* Sum = context->Input<Tensor>(3);
ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
const TensorShape& X_shape = X->Shape();
const TensorShape& W_shape = W->Shape();
const auto& X_shape = X->Shape();
const auto& W_shape = W->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
const size_t nchwc_block_size = MlasNchwcGetBlockSize();
@ -131,36 +130,20 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, Y_dims);
T* y_data = Y->template MutableData<T>();
auto* Y = context->Output(0, Y_dims);
auto* y_data = Y->template MutableData<float>();
// Check for the optional Conv/Sum fusion.
if (Sum != nullptr) {
const auto& sum_shape = Sum->Shape();
ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
// If the output was not allocated inplace with the sum tensor, then copy here.
const float* sum_data = Sum->template Data<T>();
const auto* sum_data = Sum->template Data<float>();
if (y_data != sum_data) {
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(T));
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(float));
}
}
MLAS_ACTIVATION Activation;
if (ConvBase::activation_.empty()) {
Activation.ActivationKind = MlasIdentityActivation;
} else if (ConvBase::activation_ == "Relu") {
Activation.ActivationKind = MlasReluActivation;
} else if (ConvBase::activation_ == "LeakyRelu") {
Activation.ActivationKind = MlasLeakyReluActivation;
Activation.alpha = ConvBase::alpha_;
} else if (ConvBase::activation_ == "Tanh") {
Activation.ActivationKind = MlasTanhActivation;
} else if (ConvBase::activation_ == "Sigmoid") {
Activation.ActivationKind = MlasLogisticActivation;
} else {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", ConvBase::activation_);
}
MlasNchwcConv(kernel_shape.size(),
X_shape.GetDims().data(),
kernel_shape.data(),
@ -173,7 +156,7 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
y_data,
&Activation,
&activation_,
Sum == nullptr,
const_cast<concurrency::ThreadPool*>(static_cast<OpKernelContextInternal*>(context)->GetOperatorThreadPool()));
@ -181,9 +164,9 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
}
Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const {
const Tensor* X = context->Input<Tensor>(0);
const auto* X = context->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
@ -193,7 +176,7 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind
std::vector<int64_t> pads = pads_;
std::vector<int64_t> output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_);
Tensor* Y = context->Output(0, output_dims);
auto* Y = context->Output(0, output_dims);
MlasNchwcPool(kind,
2,

View file

@ -5,8 +5,9 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_impl.h"
#include "core/providers/cpu/nn/conv_base.h"
#include "core/providers/cpu/nn/pool.h"
#include "contrib_ops/cpu/fused_activation.h"
namespace onnxruntime {
namespace contrib {
@ -34,15 +35,16 @@ class ReorderOutput : public OpKernel {
int64_t channels_;
};
template <typename T>
class NchwcConv : public Conv<T> {
class NchwcConv : public OpKernel, public ConvBase {
public:
NchwcConv(const OpKernelInfo& info) : Conv<T>(info) {
Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
}
Status Compute(OpKernelContext* context) const override;
private:
MLAS_ACTIVATION activation_;
};
class NchwcPoolBase : public PoolBase {

View file

@ -10,5 +10,5 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
7,
9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gemm<float, float, float, float>);
}
Gemm<float>);
}

View file

@ -11,10 +11,7 @@
namespace onnxruntime {
template <typename T_X,
typename T_W,
typename T_B,
typename T_Y>
template <typename T>
class Gemm : public OpKernel {
public:
Gemm(const OpKernelInfo& info) : OpKernel(info) {
@ -40,75 +37,47 @@ class Gemm : public OpKernel {
int64_t M = helper.M();
int64_t N = helper.N();
int64_t K = helper.K();
auto Y = context->Output(0, TensorShape({M, N}));
auto Y = context->Output(0, {M, N});
// If the input is an empty tensor, return directly since there is nothing to calculate.
if (M == 0 || N == 0)
return Status::OK();
T_Y* y_data = Y->template MutableData<T_Y>();
T* y_data = Y->template MutableData<T>();
//bias
// Todo: we might should move this part into math::gemm to let eigen
// have better chance to further optimize it.
// Broadcast the bias as needed.
if (beta_ != 0) {
auto output_mat = EigenMatrixMapRowMajor<T_Y>(
Y->template MutableData<T_Y>(),
M,
N);
output_mat.setZero();
auto& b_shape = B->Shape();
// if B is (), (1,) or (1, 1), add the scalar
auto output_mat = EigenMatrixMapRowMajor<T>(y_data, M, N);
const auto& b_shape = B->Shape();
const T* b_data = B->template Data<T>();
if (b_shape.Size() == 1) {
output_mat.array() += *(B->template Data<T_B>());
}
// B is (N,)
else if (b_shape.NumDimensions() == 1) {
auto bias_vec = ConstEigenVectorMap<T_B>(
B->template Data<T_B>(),
N);
output_mat.rowwise() += bias_vec.transpose();
} else if (b_shape.NumDimensions() == 2) {
// B is (), (1,) or (1, 1), set the scalar
output_mat.setConstant(*b_data);
} else if (b_shape.NumDimensions() == 1 || b_shape[0] == 1) {
// B is (N,) or (1, N)
output_mat.rowwise() = ConstEigenVectorMap<T>(b_data, N).transpose();
} else if (b_shape[1] == 1) {
// B is (M, 1)
if (b_shape[1] == 1) {
auto bias_vec = ConstEigenVectorMap<T_B>(
B->template Data<T_B>(),
M);
output_mat.colwise() += bias_vec;
}
// B is (1, N)
else if (b_shape[0] == 1) {
auto bias_vec = ConstEigenVectorMap<T_B>(
B->template Data<T_B>(),
N);
output_mat.rowwise() += bias_vec.transpose();
}
output_mat.colwise() = ConstEigenVectorMap<T>(b_data, M);
} else {
// B is (M, N), no broadcast needed.
else {
auto bias_mat = ConstEigenMatrixMapRowMajor<T_B>(
B->template Data<T_B>(),
M,
N);
output_mat += bias_mat;
}
output_mat = ConstEigenMatrixMapRowMajor<T>(b_data, M, N);
}
}
// W * x
math::Gemm<T_X, CPUMathUtil>(
math::Gemm<T, CPUMathUtil>(
trans_A_,
trans_B_,
M,
N,
K,
helper.K(),
alpha_,
X->template Data<T_X>(),
W->template Data<T_W>(),
X->template Data<T>(),
W->template Data<T>(),
beta_,
y_data,
&CPUMathUtil::Instance());
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);
FuseActivation<T>(activation_, y_data, M * N, leaky_relu_alpha_);
return Status::OK();
}
@ -119,7 +88,7 @@ class Gemm : public OpKernel {
float alpha_;
float beta_;
protected:
protected:
// For fused gemm + activation
std::string activation_;
float leaky_relu_alpha_;

View file

@ -1,13 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
#include "core/providers/cpu/nn/conv.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/providers/cpu/nn/conv_impl.h"
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
template <>
// Generic convolution for element type T.
//
// For each image and each group it fills a scratch column buffer with
// math::Im2col (2-D kernels) or math::Im2colNd (any other kernel rank), then
// multiplies that group's slice of W by the column buffer with math::Gemm,
// writing into the matching slice of Y. After all groups of an image are
// done, the optional bias B is added.
//
// Returns the failing Status from shape validation, kernel-shape computation,
// output-shape inference or the temp allocator lookup; otherwise Status::OK().
template <typename T>
Status Conv<T>::Compute(OpKernelContext* context) const {
// The bias input is only fetched when the node declares exactly three inputs.
size_t num_inputs = OpKernel::Node().InputDefs().size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
// NOTE(review): these shape entries are read before ValidateInputShape runs below.
// N is the image count iterated by the outer loop; M is the leading dimension of W
// and becomes the second output dimension.
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
std::vector<int64_t> kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
bool Is2DKernel = kernel_shape.size() == 2;
// Work on local copies of the attribute vectors so that empty ones can be
// filled with defaults (pads 0, dilations 1, strides 1) without modifying the members.
std::vector<int64_t> pads(pads_);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}
// Y_dims starts as {N, M}; InferOutputShape receives &pads and &Y_dims and
// presumably appends the spatial output dims and finalizes pads (TODO confirm).
std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);
// Element counts used as per-group strides into X, Y and W, plus the
// dimensions of the column buffer (kernel_dim x output_image_size).
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;
// Scratch column buffer from the temp-space allocator, held by col_buffer
// (a BufferUniquePtr constructed with a BufferDeleter bound to alloc).
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
T* col_buffer_data = static_cast<T*>(col_buffer.get());
const T* Xdata = X->template Data<T>();
T* Ydata = Y->template MutableData<T>();
// Shapes handed to Im2colNd: X without its leading dimension, and
// {kernel_dim, output spatial dims...} for the column buffer.
TensorShape image_shape = X->Shape().Slice(1);
std::vector<int64_t> col_buffer_shape{kernel_dim};
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
output_shape.GetDims().end());
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
if (Is2DKernel) {
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
Xdata + group_id * X_offset,
C / group_,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data,
&CPUMathUtil::Instance());
} else {
math::Im2colNd<T, CPUMathUtil, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
col_buffer_shape.data(),
C * input_image_size,
col_buffer_size,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance());
}
// Group output = group weights * column buffer:
// (M / group_) x kernel_dim times kernel_dim x output_image_size,
// with scale factors 1 and 0, so the Y slice is overwritten rather than accumulated.
math::Gemm<T, CPUMathUtil>(
CblasNoTrans,
CblasNoTrans,
M / group_,
output_image_size,
kernel_dim,
1,
W->template Data<T>() + group_id * W_offset,
col_buffer_data,
0,
Ydata + group_id * Y_offset,
&CPUMathUtil::Instance());
}
// Bias: view this image's output as an output_image_size x M Eigen map and
// add the length-M vector B to every row.
if (B != nullptr) {
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_image_size, M);
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
// Step both pointers past all groups of the image just processed.
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
}
return Status::OK();
}
Status Conv<float>::Compute(OpKernelContext* context) const {
size_t num_inputs = OpKernel::Node().InputDefs().size();
const auto* X = context->Input<Tensor>(0);
@ -45,27 +180,12 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
const auto* Xdata = X->template Data<float>();
const auto* Bdata = B != nullptr ? B->template Data<float>() : nullptr;
auto* Ydata = Y->template MutableData<float>();
const size_t kernel_rank = kernel_shape.size();
if (kernel_rank == 2 || kernel_rank == 3) {
MLAS_ACTIVATION Activation;
if (activation_.empty()) {
Activation.ActivationKind = MlasIdentityActivation;
} else if (activation_ == "Relu") {
Activation.ActivationKind = MlasReluActivation;
} else if (activation_ == "LeakyRelu") {
Activation.ActivationKind = MlasLeakyReluActivation;
Activation.alpha = alpha_;
} else if (activation_ == "Tanh") {
Activation.ActivationKind = MlasTanhActivation;
} else if (activation_ == "Sigmoid") {
Activation.ActivationKind = MlasLogisticActivation;
} else {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_);
}
// Get access to the internal threadpool
// Temporarily derive concurrency parameters without access to session state
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
@ -85,7 +205,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
strides.data(),
output_shape.GetDims().data(),
static_cast<size_t>(M / group_),
&Activation,
&activation_,
&WorkingBufferSize,
const_cast<concurrency::ThreadPool*>(thread_pool));
@ -95,7 +215,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
MlasConv(&Parameters,
Xdata,
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
Bdata,
static_cast<float*>(working_buffer.get()),
Ydata,
const_cast<concurrency::ThreadPool*>(thread_pool));
@ -147,13 +267,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
&CPUMathUtil::Instance());
}
if (B != nullptr) {
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
MlasActivation(&activation_, Ydata, Bdata, M, Ydata, output_image_size, output_image_size);
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
@ -168,4 +282,5 @@ ONNX_CPU_OPERATOR_KERNEL(
1,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv<float>);
} // namespace onnxruntime

View file

@ -1,23 +1,10 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/nn/conv_base.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
@ -30,4 +17,17 @@ class Conv : public OpKernel, public ConvBase {
Status Compute(OpKernelContext* context) const override;
};
template <>
class Conv<float> : public OpKernel, public ConvBase {
public:
Conv<float>(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
activation_.ActivationKind = MlasIdentityActivation;
}
Status Compute(OpKernelContext* context) const override;
protected:
MLAS_ACTIVATION activation_;
};
} // namespace onnxruntime

View file

@ -1,153 +0,0 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
#pragma once
#include "core/providers/cpu/nn/conv.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
// Generic convolution for element type T.
//
// For each image and each group it fills a scratch column buffer with
// math::Im2col (2-D kernels) or math::Im2colNd (any other kernel rank), then
// multiplies that group's slice of W by the column buffer with math::Gemm,
// writing into the matching slice of Y. After all groups of an image are
// done, the optional bias B is added and FuseActivation is applied to that
// image's output.
//
// Returns the failing Status from shape validation, kernel-shape computation,
// output-shape inference or the temp allocator lookup; otherwise Status::OK().
template <typename T>
Status Conv<T>::Compute(OpKernelContext* context) const {
// The bias input is only fetched when the node declares exactly three inputs.
size_t num_inputs = OpKernel::Node().InputDefs().size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
// NOTE(review): these shape entries are read before ValidateInputShape runs below.
// N is the image count iterated by the outer loop; M is the leading dimension of W
// and becomes the second output dimension.
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
std::vector<int64_t> kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
bool Is2DKernel = kernel_shape.size() == 2;
// Work on local copies of the attribute vectors so that empty ones can be
// filled with defaults (pads 0, dilations 1, strides 1) without modifying the members.
std::vector<int64_t> pads(pads_);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}
// Y_dims starts as {N, M}; InferOutputShape receives &pads and &Y_dims and
// presumably appends the spatial output dims and finalizes pads (TODO confirm).
std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);
// Element counts used as per-group strides into X, Y and W, plus the
// dimensions of the column buffer (kernel_dim x output_image_size).
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;
// Scratch column buffer from the temp-space allocator, held by col_buffer
// (a BufferUniquePtr constructed with a BufferDeleter bound to alloc).
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size);
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
T* col_buffer_data = static_cast<T*>(col_buffer.get());
const T* Xdata = X->template Data<T>();
T* Ydata = Y->template MutableData<T>();
// Shapes handed to Im2colNd: X without its leading dimension, and
// {kernel_dim, output spatial dims...} for the column buffer.
TensorShape image_shape = X->Shape().Slice(1);
std::vector<int64_t> col_buffer_shape{kernel_dim};
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
output_shape.GetDims().end());
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
if (Is2DKernel) {
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
Xdata + group_id * X_offset,
C / group_,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data,
&CPUMathUtil::Instance());
} else {
math::Im2colNd<T, CPUMathUtil, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
col_buffer_shape.data(),
C * input_image_size,
col_buffer_size,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data,
&CPUMathUtil::Instance());
}
// Group output = group weights * column buffer:
// (M / group_) x kernel_dim times kernel_dim x output_image_size,
// with scale factors 1 and 0, so the Y slice is overwritten rather than accumulated.
math::Gemm<T, CPUMathUtil>(
CblasNoTrans,
CblasNoTrans,
M / group_,
output_image_size,
kernel_dim,
1,
W->template Data<T>() + group_id * W_offset,
col_buffer_data,
0,
Ydata + group_id * Y_offset,
&CPUMathUtil::Instance());
}
// Bias: view this image's output as an output_image_size x M Eigen map and
// add the length-M vector B to every row.
if (B != nullptr) {
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_image_size, M);
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
// Apply the activation named by activation_ (parameter alpha_) across the
// Y_offset * group_ output elements of the image just computed.
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
// Step both pointers past all groups of the image just processed.
Xdata += X_offset * group_;
Ydata += Y_offset * group_;
}
return Status::OK();
}
// Declares the explicit specialization of Compute for float so that the
// generic template body above is not instantiated for Conv<float>; the
// definition is provided outside this header.
template <>
Status Conv<float>::Compute(OpKernelContext* context) const;
} // namespace onnxruntime

View file

@ -157,7 +157,7 @@ TEST(GemmOpTest, GemmScalarBroadcast) {
test.Run();
}
TEST(MathOpTest, Gemm2DBroadcast) {
TEST(GemmOpTest, Gemm2DBroadcast_1) {
OpTester test("Gemm");
test.AddAttribute("transA", (int64_t)0);
@ -176,6 +176,26 @@ TEST(MathOpTest, Gemm2DBroadcast) {
test.Run();
}
// Gemm with a bias C of shape {1, 3}: the single bias row has to be applied
// to both rows of the {2, 3} output.
TEST(GemmOpTest, Gemm2DBroadcast_2) {
  OpTester tester("Gemm");

  tester.AddAttribute("transA", static_cast<int64_t>(0));
  tester.AddAttribute("transB", static_cast<int64_t>(0));
  tester.AddAttribute("alpha", 1.0f);
  tester.AddAttribute("beta", 1.0f);

  // Same as GemmBroadcast, but adding the unnecessary second dimension.
  tester.AddInput<float>("A", {2, 4},
                         {1.0f, 2.0f, 3.0f, 4.0f,
                          -1.0f, -2.0f, -3.0f, -4.0f});
  tester.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
  tester.AddInput<float>("C", {1, 3}, std::vector<float>{1.0f, 2.0f, 3.0f});

  // Each row of A * B is the row sum of A (10 and -10) repeated three times;
  // adding C = {1, 2, 3} gives the values below.
  tester.AddOutput<float>("Y", {2, 3},
                          {11.0f, 12.0f, 13.0f,
                           -9.0f, -8.0f, -7.0f});
  tester.Run();
}
TEST(GemmOpTest, GemmFalseBroadcast) {
OpTester test("Gemm");