mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
cleanup fused conv activation handling (#1403)
* cleanup fused conv activation handling * fix build break * fix mkldnn build break
This commit is contained in:
parent
c139e3ab33
commit
d4ce31ea6d
14 changed files with 305 additions and 353 deletions
31
onnxruntime/contrib_ops/cpu/fused_activation.cc
Normal file
31
onnxruntime/contrib_ops/cpu/fused_activation.cc
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/fused_activation.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) {
|
||||
// Convert the activation parameters from the node into a MLAS_ACTIVATION.
|
||||
activation.ActivationKind = MlasIdentityActivation;
|
||||
|
||||
std::string activation_type;
|
||||
if (info.GetAttr<std::string>("activation", &activation_type).IsOK()) {
|
||||
if (activation_type == "Relu") {
|
||||
activation.ActivationKind = MlasReluActivation;
|
||||
} else if (activation_type == "LeakyRelu") {
|
||||
activation.ActivationKind = MlasLeakyReluActivation;
|
||||
activation.alpha = info.GetAttrOrDefault<float>("alpha", 0.01f);
|
||||
} else if (activation_type == "Tanh") {
|
||||
activation.ActivationKind = MlasTanhActivation;
|
||||
} else if (activation_type == "Sigmoid") {
|
||||
activation.ActivationKind = MlasLogisticActivation;
|
||||
} else {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + activation_type);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
14
onnxruntime/contrib_ops/cpu/fused_activation.h
Normal file
14
onnxruntime/contrib_ops/cpu/fused_activation.h
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,16 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fused_conv.h"
|
||||
#include "core/providers/cpu/nn/conv.h"
|
||||
#include "contrib_ops/cpu/fused_activation.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class FusedConvFloat final : public Conv<float> {
|
||||
public:
|
||||
FusedConvFloat(const OpKernelInfo& info) : Conv<float>(info) {
|
||||
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
|
||||
}
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
FusedConv,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedConv<float>);
|
||||
FusedConvFloat);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class FusedConv : public Conv<T> {
|
||||
public:
|
||||
FusedConv(const OpKernelInfo& info) : Conv<T>(info) {
|
||||
Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
return Conv<T>::Compute(context);
|
||||
}
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,15 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fused_gemm.h"
|
||||
#include "core/providers/cpu/math/gemm.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class FusedGemm final : public Gemm<T> {
|
||||
public:
|
||||
FusedGemm(const OpKernelInfo& info) : Gemm<T>(info) {
|
||||
Gemm<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Gemm<T>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
|
||||
}
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
FusedGemm,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedGemm<float, float, float, float>);
|
||||
FusedGemm<float>);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/math/gemm.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
template <typename T_X,
|
||||
typename T_W,
|
||||
typename T_B,
|
||||
typename T_Y>
|
||||
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
|
||||
public:
|
||||
FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) {
|
||||
Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
return Gemm<T_X, T_W, T_B, T_Y>::Compute(context);
|
||||
}
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -34,7 +34,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
|||
KernelDefBuilder()
|
||||
.MayInplace(3, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcConv<float>);
|
||||
NchwcConv);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
MaxPool,
|
||||
|
|
@ -70,39 +70,38 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
|||
|
||||
template <typename T>
|
||||
Status ReorderInput<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
|
||||
Tensor* Y = context->Output(0, X_shape);
|
||||
auto* Y = context->Output(0, X_shape);
|
||||
MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
std::vector<int64_t> Y_shape(X_shape.GetDims());
|
||||
ORT_ENFORCE(channels_ <= Y_shape[1]);
|
||||
Y_shape[1] = channels_;
|
||||
Tensor* Y = context->Output(0, Y_shape);
|
||||
auto* Y = context->Output(0, Y_shape);
|
||||
MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status NchwcConv<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = context->Input<Tensor>(2);
|
||||
const Tensor* Sum = context->Input<Tensor>(3);
|
||||
Status NchwcConv::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto* W = context->Input<Tensor>(1);
|
||||
const auto* B = context->Input<Tensor>(2);
|
||||
const auto* Sum = context->Input<Tensor>(3);
|
||||
|
||||
ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
|
||||
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const TensorShape& W_shape = W->Shape();
|
||||
const auto& X_shape = X->Shape();
|
||||
const auto& W_shape = W->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
|
||||
const size_t nchwc_block_size = MlasNchwcGetBlockSize();
|
||||
|
|
@ -131,36 +130,20 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
|
|||
Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
|
||||
Tensor* Y = context->Output(0, Y_dims);
|
||||
T* y_data = Y->template MutableData<T>();
|
||||
auto* Y = context->Output(0, Y_dims);
|
||||
auto* y_data = Y->template MutableData<float>();
|
||||
|
||||
// Check for the optional Conv/Sum fusion.
|
||||
if (Sum != nullptr) {
|
||||
const auto& sum_shape = Sum->Shape();
|
||||
ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
|
||||
// If the output was not allocated inplace with the sum tensor, then copy here.
|
||||
const float* sum_data = Sum->template Data<T>();
|
||||
const auto* sum_data = Sum->template Data<float>();
|
||||
if (y_data != sum_data) {
|
||||
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(T));
|
||||
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
MLAS_ACTIVATION Activation;
|
||||
if (ConvBase::activation_.empty()) {
|
||||
Activation.ActivationKind = MlasIdentityActivation;
|
||||
} else if (ConvBase::activation_ == "Relu") {
|
||||
Activation.ActivationKind = MlasReluActivation;
|
||||
} else if (ConvBase::activation_ == "LeakyRelu") {
|
||||
Activation.ActivationKind = MlasLeakyReluActivation;
|
||||
Activation.alpha = ConvBase::alpha_;
|
||||
} else if (ConvBase::activation_ == "Tanh") {
|
||||
Activation.ActivationKind = MlasTanhActivation;
|
||||
} else if (ConvBase::activation_ == "Sigmoid") {
|
||||
Activation.ActivationKind = MlasLogisticActivation;
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", ConvBase::activation_);
|
||||
}
|
||||
|
||||
MlasNchwcConv(kernel_shape.size(),
|
||||
X_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
|
|
@ -173,7 +156,7 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
|
|||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
y_data,
|
||||
&Activation,
|
||||
&activation_,
|
||||
Sum == nullptr,
|
||||
const_cast<concurrency::ThreadPool*>(static_cast<OpKernelContextInternal*>(context)->GetOperatorThreadPool()));
|
||||
|
||||
|
|
@ -181,9 +164,9 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
|
||||
const TensorShape& X_shape = X->Shape();
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
|
||||
|
||||
|
|
@ -193,7 +176,7 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind
|
|||
|
||||
std::vector<int64_t> pads = pads_;
|
||||
std::vector<int64_t> output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_);
|
||||
Tensor* Y = context->Output(0, output_dims);
|
||||
auto* Y = context->Output(0, output_dims);
|
||||
|
||||
MlasNchwcPool(kind,
|
||||
2,
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@
|
|||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
#include "core/providers/cpu/nn/conv_base.h"
|
||||
#include "core/providers/cpu/nn/pool.h"
|
||||
#include "contrib_ops/cpu/fused_activation.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -34,15 +35,16 @@ class ReorderOutput : public OpKernel {
|
|||
int64_t channels_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NchwcConv : public Conv<T> {
|
||||
class NchwcConv : public OpKernel, public ConvBase {
|
||||
public:
|
||||
NchwcConv(const OpKernelInfo& info) : Conv<T>(info) {
|
||||
Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
|
||||
NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
|
||||
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MLAS_ACTIVATION activation_;
|
||||
};
|
||||
|
||||
class NchwcPoolBase : public PoolBase {
|
||||
|
|
|
|||
|
|
@ -10,5 +10,5 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
7,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gemm<float, float, float, float>);
|
||||
}
|
||||
Gemm<float>);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,10 +11,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T_X,
|
||||
typename T_W,
|
||||
typename T_B,
|
||||
typename T_Y>
|
||||
template <typename T>
|
||||
class Gemm : public OpKernel {
|
||||
public:
|
||||
Gemm(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -40,75 +37,47 @@ class Gemm : public OpKernel {
|
|||
|
||||
int64_t M = helper.M();
|
||||
int64_t N = helper.N();
|
||||
int64_t K = helper.K();
|
||||
auto Y = context->Output(0, TensorShape({M, N}));
|
||||
auto Y = context->Output(0, {M, N});
|
||||
// if input is emtpy tensor, return directly as nothing need to be calculated.
|
||||
if (M == 0 || N == 0)
|
||||
return Status::OK();
|
||||
T_Y* y_data = Y->template MutableData<T_Y>();
|
||||
T* y_data = Y->template MutableData<T>();
|
||||
|
||||
//bias
|
||||
// Todo: we might should move this part into math::gemm to let eigen
|
||||
// have better chance to further optimize it.
|
||||
// Broadcast the bias as needed.
|
||||
if (beta_ != 0) {
|
||||
auto output_mat = EigenMatrixMapRowMajor<T_Y>(
|
||||
Y->template MutableData<T_Y>(),
|
||||
M,
|
||||
N);
|
||||
output_mat.setZero();
|
||||
|
||||
auto& b_shape = B->Shape();
|
||||
// if B is (), (1,) or (1, 1), add the scalar
|
||||
auto output_mat = EigenMatrixMapRowMajor<T>(y_data, M, N);
|
||||
const auto& b_shape = B->Shape();
|
||||
const T* b_data = B->template Data<T>();
|
||||
if (b_shape.Size() == 1) {
|
||||
output_mat.array() += *(B->template Data<T_B>());
|
||||
}
|
||||
// B is (N,)
|
||||
else if (b_shape.NumDimensions() == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<T_B>(
|
||||
B->template Data<T_B>(),
|
||||
N);
|
||||
output_mat.rowwise() += bias_vec.transpose();
|
||||
} else if (b_shape.NumDimensions() == 2) {
|
||||
// B is (), (1,) or (1, 1), set the scalar
|
||||
output_mat.setConstant(*b_data);
|
||||
} else if (b_shape.NumDimensions() == 1 || b_shape[0] == 1) {
|
||||
// B is (N,) or (1, N)
|
||||
output_mat.rowwise() = ConstEigenVectorMap<T>(b_data, N).transpose();
|
||||
} else if (b_shape[1] == 1) {
|
||||
// B is (M, 1)
|
||||
if (b_shape[1] == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<T_B>(
|
||||
B->template Data<T_B>(),
|
||||
M);
|
||||
output_mat.colwise() += bias_vec;
|
||||
}
|
||||
// B is (1, N)
|
||||
else if (b_shape[0] == 1) {
|
||||
auto bias_vec = ConstEigenVectorMap<T_B>(
|
||||
B->template Data<T_B>(),
|
||||
N);
|
||||
output_mat.rowwise() += bias_vec.transpose();
|
||||
}
|
||||
output_mat.colwise() = ConstEigenVectorMap<T>(b_data, M);
|
||||
} else {
|
||||
// B is (M, N), no broadcast needed.
|
||||
else {
|
||||
auto bias_mat = ConstEigenMatrixMapRowMajor<T_B>(
|
||||
B->template Data<T_B>(),
|
||||
M,
|
||||
N);
|
||||
output_mat += bias_mat;
|
||||
}
|
||||
output_mat = ConstEigenMatrixMapRowMajor<T>(b_data, M, N);
|
||||
}
|
||||
}
|
||||
|
||||
// W * x
|
||||
math::Gemm<T_X, CPUMathUtil>(
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
trans_A_,
|
||||
trans_B_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
helper.K(),
|
||||
alpha_,
|
||||
X->template Data<T_X>(),
|
||||
W->template Data<T_W>(),
|
||||
X->template Data<T>(),
|
||||
W->template Data<T>(),
|
||||
beta_,
|
||||
y_data,
|
||||
&CPUMathUtil::Instance());
|
||||
|
||||
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);
|
||||
FuseActivation<T>(activation_, y_data, M * N, leaky_relu_alpha_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -119,7 +88,7 @@ class Gemm : public OpKernel {
|
|||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
// For fused gemm + activation
|
||||
std::string activation_;
|
||||
float leaky_relu_alpha_;
|
||||
|
|
|
|||
|
|
@ -1,13 +1,148 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
#include "core/providers/cpu/nn/conv.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <>
|
||||
template <typename T>
|
||||
Status Conv<T>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t C = X->Shape()[1];
|
||||
const int64_t M = W->Shape()[0];
|
||||
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
|
||||
|
||||
std::vector<int64_t> kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
|
||||
|
||||
bool Is2DKernel = kernel_shape.size() == 2;
|
||||
std::vector<int64_t> pads(pads_);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_shape.size() * 2, 0);
|
||||
}
|
||||
std::vector<int64_t> dilations(dilations_);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
std::vector<int64_t> strides(strides_);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Y_dims;
|
||||
Y_dims.insert(Y_dims.begin(), {N, M});
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
|
||||
Tensor* Y = context->Output(0, TensorShape(Y_dims));
|
||||
TensorShape output_shape = Y->Shape().Slice(2);
|
||||
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
const int64_t X_offset = C / group_ * input_image_size;
|
||||
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
|
||||
const int64_t W_offset = W->Shape().Size() / group_;
|
||||
const int64_t kernel_dim = C / group_ * kernel_size;
|
||||
const int64_t col_buffer_size = kernel_dim * output_image_size;
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
T* col_buffer_data = static_cast<T*>(col_buffer.get());
|
||||
|
||||
const T* Xdata = X->template Data<T>();
|
||||
T* Ydata = Y->template MutableData<T>();
|
||||
|
||||
TensorShape image_shape = X->Shape().Slice(1);
|
||||
std::vector<int64_t> col_buffer_shape{kernel_dim};
|
||||
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
|
||||
output_shape.GetDims().end());
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
if (Is2DKernel) {
|
||||
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
Xdata + group_id * X_offset,
|
||||
C / group_,
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
kernel_shape[0],
|
||||
kernel_shape[1],
|
||||
dilations[0],
|
||||
dilations[1],
|
||||
pads[0],
|
||||
pads[1],
|
||||
pads[2],
|
||||
pads[3],
|
||||
strides[0],
|
||||
strides[1],
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
} else {
|
||||
math::Im2colNd<T, CPUMathUtil, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
image_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
C * input_image_size,
|
||||
col_buffer_size,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M / group_,
|
||||
output_image_size,
|
||||
kernel_dim,
|
||||
1,
|
||||
W->template Data<T>() + group_id * W_offset,
|
||||
col_buffer_data,
|
||||
0,
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_image_size, M);
|
||||
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
|
|
@ -45,27 +180,12 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
const auto* Xdata = X->template Data<float>();
|
||||
const auto* Bdata = B != nullptr ? B->template Data<float>() : nullptr;
|
||||
auto* Ydata = Y->template MutableData<float>();
|
||||
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
|
||||
if (kernel_rank == 2 || kernel_rank == 3) {
|
||||
MLAS_ACTIVATION Activation;
|
||||
if (activation_.empty()) {
|
||||
Activation.ActivationKind = MlasIdentityActivation;
|
||||
} else if (activation_ == "Relu") {
|
||||
Activation.ActivationKind = MlasReluActivation;
|
||||
} else if (activation_ == "LeakyRelu") {
|
||||
Activation.ActivationKind = MlasLeakyReluActivation;
|
||||
Activation.alpha = alpha_;
|
||||
} else if (activation_ == "Tanh") {
|
||||
Activation.ActivationKind = MlasTanhActivation;
|
||||
} else if (activation_ == "Sigmoid") {
|
||||
Activation.ActivationKind = MlasLogisticActivation;
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_);
|
||||
}
|
||||
|
||||
// Get access to the internal threadpool
|
||||
// Temporarily derive concurrency parameters without access to session state
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
|
|
@ -85,7 +205,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
strides.data(),
|
||||
output_shape.GetDims().data(),
|
||||
static_cast<size_t>(M / group_),
|
||||
&Activation,
|
||||
&activation_,
|
||||
&WorkingBufferSize,
|
||||
const_cast<concurrency::ThreadPool*>(thread_pool));
|
||||
|
||||
|
|
@ -95,7 +215,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
MlasConv(&Parameters,
|
||||
Xdata,
|
||||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
Bdata,
|
||||
static_cast<float*>(working_buffer.get()),
|
||||
Ydata,
|
||||
const_cast<concurrency::ThreadPool*>(thread_pool));
|
||||
|
|
@ -147,13 +267,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<float>(Ydata, output_image_size, M);
|
||||
auto Bvec = ConstEigenVectorMap<float>(B->template Data<float>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
MlasActivation(&activation_, Ydata, Bdata, M, Ydata, output_image_size, output_image_size);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
|
|
@ -168,4 +282,5 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Conv<float>);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,23 +1,10 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/nn/conv_base.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -30,4 +17,17 @@ class Conv : public OpKernel, public ConvBase {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <>
|
||||
class Conv<float> : public OpKernel, public ConvBase {
|
||||
public:
|
||||
Conv<float>(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
|
||||
activation_.ActivationKind = MlasIdentityActivation;
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
protected:
|
||||
MLAS_ACTIVATION activation_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,153 +0,0 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/nn/conv.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T>
|
||||
Status Conv<T>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t C = X->Shape()[1];
|
||||
const int64_t M = W->Shape()[0];
|
||||
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
|
||||
|
||||
std::vector<int64_t> kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
|
||||
|
||||
bool Is2DKernel = kernel_shape.size() == 2;
|
||||
std::vector<int64_t> pads(pads_);
|
||||
if (pads.empty()) {
|
||||
pads.resize(kernel_shape.size() * 2, 0);
|
||||
}
|
||||
std::vector<int64_t> dilations(dilations_);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
std::vector<int64_t> strides(strides_);
|
||||
if (strides.empty()) {
|
||||
strides.resize(kernel_shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Y_dims;
|
||||
Y_dims.insert(Y_dims.begin(), {N, M});
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
|
||||
Tensor* Y = context->Output(0, TensorShape(Y_dims));
|
||||
TensorShape output_shape = Y->Shape().Slice(2);
|
||||
|
||||
const int64_t input_image_size = input_shape.Size();
|
||||
const int64_t output_image_size = output_shape.Size();
|
||||
const int64_t kernel_size = TensorShape(kernel_shape).Size();
|
||||
const int64_t X_offset = C / group_ * input_image_size;
|
||||
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
|
||||
const int64_t W_offset = W->Shape().Size() / group_;
|
||||
const int64_t kernel_dim = C / group_ * kernel_size;
|
||||
const int64_t col_buffer_size = kernel_dim * output_image_size;
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size);
|
||||
BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
|
||||
T* col_buffer_data = static_cast<T*>(col_buffer.get());
|
||||
|
||||
const T* Xdata = X->template Data<T>();
|
||||
T* Ydata = Y->template MutableData<T>();
|
||||
|
||||
TensorShape image_shape = X->Shape().Slice(1);
|
||||
std::vector<int64_t> col_buffer_shape{kernel_dim};
|
||||
col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
|
||||
output_shape.GetDims().end());
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < group_; ++group_id) {
|
||||
if (Is2DKernel) {
|
||||
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
|
||||
Xdata + group_id * X_offset,
|
||||
C / group_,
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
kernel_shape[0],
|
||||
kernel_shape[1],
|
||||
dilations[0],
|
||||
dilations[1],
|
||||
pads[0],
|
||||
pads[1],
|
||||
pads[2],
|
||||
pads[3],
|
||||
strides[0],
|
||||
strides[1],
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
} else {
|
||||
math::Im2colNd<T, CPUMathUtil, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
image_shape.GetDims().data(),
|
||||
col_buffer_shape.data(),
|
||||
C * input_image_size,
|
||||
col_buffer_size,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
math::Gemm<T, CPUMathUtil>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M / group_,
|
||||
output_image_size,
|
||||
kernel_dim,
|
||||
1,
|
||||
W->template Data<T>() + group_id * W_offset,
|
||||
col_buffer_data,
|
||||
0,
|
||||
Ydata + group_id * Y_offset,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
if (B != nullptr) {
|
||||
auto Ymatrix = EigenMatrixMap<T>(Ydata, output_image_size, M);
|
||||
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const;
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -157,7 +157,7 @@ TEST(GemmOpTest, GemmScalarBroadcast) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Gemm2DBroadcast) {
|
||||
TEST(GemmOpTest, Gemm2DBroadcast_1) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
|
|
@ -176,6 +176,26 @@ TEST(MathOpTest, Gemm2DBroadcast) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, Gemm2DBroadcast_2) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
test.AddAttribute("alpha", 1.0f);
|
||||
test.AddAttribute("beta", 1.0f);
|
||||
|
||||
// Same as GemmBroadcast, but adding the unnecessary second dimension.
|
||||
test.AddInput<float>("A", {2, 4},
|
||||
{1.0f, 2.0f, 3.0f, 4.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f});
|
||||
test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 1.0f));
|
||||
test.AddInput<float>("C", {1, 3}, std::vector<float>{1.0f, 2.0f, 3.0f});
|
||||
test.AddOutput<float>("Y", {2, 3},
|
||||
{11.0f, 12.0f, 13.0f,
|
||||
-9.0f, -8.0f, -7.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GemmOpTest, GemmFalseBroadcast) {
|
||||
OpTester test("Gemm");
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue