diff --git a/onnxruntime/contrib_ops/cpu/fused_activation.cc b/onnxruntime/contrib_ops/cpu/fused_activation.cc new file mode 100644 index 0000000000..ad18bc0839 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/fused_activation.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/fused_activation.h" + +namespace onnxruntime { + +common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) { + // Convert the activation parameters from the node into a MLAS_ACTIVATION. + activation.ActivationKind = MlasIdentityActivation; + + std::string activation_type; + if (info.GetAttr("activation", &activation_type).IsOK()) { + if (activation_type == "Relu") { + activation.ActivationKind = MlasReluActivation; + } else if (activation_type == "LeakyRelu") { + activation.ActivationKind = MlasLeakyReluActivation; + activation.alpha = info.GetAttrOrDefault("alpha", 0.01f); + } else if (activation_type == "Tanh") { + activation.ActivationKind = MlasTanhActivation; + } else if (activation_type == "Sigmoid") { + activation.ActivationKind = MlasLogisticActivation; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + activation_type); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_activation.h b/onnxruntime/contrib_ops/cpu/fused_activation.h new file mode 100644 index 0000000000..0121a2038e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/fused_activation.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { + +common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation); + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.cc b/onnxruntime/contrib_ops/cpu/fused_conv.cc index ae8f81e812..2e07fa27d7 100644 --- a/onnxruntime/contrib_ops/cpu/fused_conv.cc +++ b/onnxruntime/contrib_ops/cpu/fused_conv.cc @@ -1,16 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "fused_conv.h" +#include "core/providers/cpu/nn/conv.h" +#include "contrib_ops/cpu/fused_activation.h" namespace onnxruntime { namespace contrib { + +class FusedConvFloat final : public Conv { + public: + FusedConvFloat(const OpKernelInfo& info) : Conv(info) { + ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + } +}; + ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( FusedConv, 1, float, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - FusedConv); + FusedConvFloat); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.h b/onnxruntime/contrib_ops/cpu/fused_conv.h deleted file mode 100644 index 329eb82990..0000000000 --- a/onnxruntime/contrib_ops/cpu/fused_conv.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/cpu/nn/conv_impl.h" - -namespace onnxruntime { -namespace contrib { - -template -class FusedConv : public Conv { - public: - FusedConv(const OpKernelInfo& info) : Conv(info) { - Conv::activation_ = info.GetAttrOrDefault("activation", ""); - Conv::alpha_ = info.GetAttrOrDefault("alpha", 0.01f); - } - - Status Compute(OpKernelContext* context) const override { - return Conv::Compute(context); - } -}; -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.cc b/onnxruntime/contrib_ops/cpu/fused_gemm.cc index e3bfe5b388..d743a3fcad 100644 --- a/onnxruntime/contrib_ops/cpu/fused_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/fused_gemm.cc @@ -1,15 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "fused_gemm.h" +#include "core/providers/cpu/math/gemm.h" namespace onnxruntime { namespace contrib { + +template +class FusedGemm final : public Gemm { + public: + FusedGemm(const OpKernelInfo& info) : Gemm(info) { + Gemm::activation_ = info.GetAttrOrDefault("activation", ""); + Gemm::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f); + } +}; + ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( FusedGemm, 1, float, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - FusedGemm); + FusedGemm); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.h b/onnxruntime/contrib_ops/cpu/fused_gemm.h deleted file mode 100644 index 5be1b34cb4..0000000000 --- a/onnxruntime/contrib_ops/cpu/fused_gemm.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/providers/cpu/math/gemm.h" - -namespace onnxruntime { -namespace contrib { -template -class FusedGemm : public Gemm { - public: - FusedGemm(const OpKernelInfo& info) : Gemm(info) { - Gemm::activation_ = info.GetAttrOrDefault("activation", ""); - Gemm::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f); - } - - Status Compute(OpKernelContext* context) const override { - return Gemm::Compute(context); - } -}; -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index a55bab68ba..b5625551ad 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -34,7 +34,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( KernelDefBuilder() .MayInplace(3, 0) .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcConv); + NchwcConv); ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( MaxPool, @@ -70,39 +70,38 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( template Status ReorderInput::Compute(OpKernelContext* context) const { - const Tensor* X = context->Input(0); - const TensorShape& X_shape = X->Shape(); + const auto* X = context->Input(0); + const auto& X_shape = X->Shape(); ORT_ENFORCE(X_shape.NumDimensions() == 4); ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); - Tensor* Y = context->Output(0, X_shape); + auto* Y = context->Output(0, X_shape); MlasReorderInput(X_shape.GetDims().data(), X->template Data(), Y->template MutableData()); return Status::OK(); } template Status ReorderOutput::Compute(OpKernelContext* context) const { - const Tensor* X = context->Input(0); - const TensorShape& X_shape = X->Shape(); + const auto* X = context->Input(0); + const auto& X_shape = X->Shape(); ORT_ENFORCE(X_shape.NumDimensions() == 4); std::vector Y_shape(X_shape.GetDims()); ORT_ENFORCE(channels_ <= Y_shape[1]); Y_shape[1] = channels_; - Tensor* Y = context->Output(0, Y_shape); + auto* Y = context->Output(0, 
Y_shape); MlasReorderOutput(Y_shape.data(), X->template Data(), Y->template MutableData()); return Status::OK(); } -template -Status NchwcConv::Compute(OpKernelContext* context) const { - const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = context->Input(2); - const Tensor* Sum = context->Input(3); +Status NchwcConv::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* B = context->Input(2); + const auto* Sum = context->Input(3); ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W)); - const TensorShape& X_shape = X->Shape(); - const TensorShape& W_shape = W->Shape(); + const auto& X_shape = X->Shape(); + const auto& W_shape = W->Shape(); ORT_ENFORCE(X_shape.NumDimensions() == 4); const size_t nchwc_block_size = MlasNchwcGetBlockSize(); @@ -131,36 +130,20 @@ Status NchwcConv::Compute(OpKernelContext* context) const { Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]}); TensorShape input_shape = X->Shape().Slice(2); ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); - Tensor* Y = context->Output(0, Y_dims); - T* y_data = Y->template MutableData(); + auto* Y = context->Output(0, Y_dims); + auto* y_data = Y->template MutableData(); // Check for the optional Conv/Sum fusion. if (Sum != nullptr) { const auto& sum_shape = Sum->Shape(); ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match"); // If the output was not allocated inplace with the sum tensor, then copy here. 
- const float* sum_data = Sum->template Data(); + const auto* sum_data = Sum->template Data(); if (y_data != sum_data) { - memcpy(y_data, sum_data, sum_shape.Size() * sizeof(T)); + memcpy(y_data, sum_data, sum_shape.Size() * sizeof(float)); } } - MLAS_ACTIVATION Activation; - if (ConvBase::activation_.empty()) { - Activation.ActivationKind = MlasIdentityActivation; - } else if (ConvBase::activation_ == "Relu") { - Activation.ActivationKind = MlasReluActivation; - } else if (ConvBase::activation_ == "LeakyRelu") { - Activation.ActivationKind = MlasLeakyReluActivation; - Activation.alpha = ConvBase::alpha_; - } else if (ConvBase::activation_ == "Tanh") { - Activation.ActivationKind = MlasTanhActivation; - } else if (ConvBase::activation_ == "Sigmoid") { - Activation.ActivationKind = MlasLogisticActivation; - } else { - ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", ConvBase::activation_); - } - MlasNchwcConv(kernel_shape.size(), X_shape.GetDims().data(), kernel_shape.data(), @@ -173,7 +156,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const { W->template Data(), B != nullptr ? 
B->template Data() : nullptr, y_data, - &Activation, + &activation_, Sum == nullptr, const_cast(static_cast(context)->GetOperatorThreadPool())); @@ -181,9 +164,9 @@ Status NchwcConv::Compute(OpKernelContext* context) const { } Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const { - const Tensor* X = context->Input(0); + const auto* X = context->Input(0); - const TensorShape& X_shape = X->Shape(); + const auto& X_shape = X->Shape(); ORT_ENFORCE(X_shape.NumDimensions() == 4); ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); @@ -193,7 +176,7 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind std::vector pads = pads_; std::vector output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_); - Tensor* Y = context->Output(0, output_dims); + auto* Y = context->Output(0, output_dims); MlasNchwcPool(kind, 2, diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index d9dbdc500e..65045cd0ee 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -5,8 +5,9 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cpu/nn/conv_impl.h" +#include "core/providers/cpu/nn/conv_base.h" #include "core/providers/cpu/nn/pool.h" +#include "contrib_ops/cpu/fused_activation.h" namespace onnxruntime { namespace contrib { @@ -34,15 +35,16 @@ class ReorderOutput : public OpKernel { int64_t channels_; }; -template -class NchwcConv : public Conv { +class NchwcConv : public OpKernel, public ConvBase { public: - NchwcConv(const OpKernelInfo& info) : Conv(info) { - Conv::activation_ = info.GetAttrOrDefault("activation", ""); - Conv::alpha_ = info.GetAttrOrDefault("alpha", 0.01f); + NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); } Status Compute(OpKernelContext* context) const 
override; + + private: + MLAS_ACTIVATION activation_; }; class NchwcPoolBase : public PoolBase { diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 6554e3c601..c54e910b07 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -10,5 +10,5 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 7, 9, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gemm); -} \ No newline at end of file + Gemm); +} diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index c72c5bc1e0..a3aa724ab4 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -11,10 +11,7 @@ namespace onnxruntime { -template +template class Gemm : public OpKernel { public: Gemm(const OpKernelInfo& info) : OpKernel(info) { @@ -40,75 +37,47 @@ class Gemm : public OpKernel { int64_t M = helper.M(); int64_t N = helper.N(); - int64_t K = helper.K(); - auto Y = context->Output(0, TensorShape({M, N})); + auto Y = context->Output(0, {M, N}); // if input is emtpy tensor, return directly as nothing need to be calculated. if (M == 0 || N == 0) return Status::OK(); - T_Y* y_data = Y->template MutableData(); + T* y_data = Y->template MutableData(); - //bias - // Todo: we might should move this part into math::gemm to let eigen - // have better chance to further optimize it. + // Broadcast the bias as needed. 
if (beta_ != 0) { - auto output_mat = EigenMatrixMapRowMajor( - Y->template MutableData(), - M, - N); - output_mat.setZero(); - - auto& b_shape = B->Shape(); - // if B is (), (1,) or (1, 1), add the scalar + auto output_mat = EigenMatrixMapRowMajor(y_data, M, N); + const auto& b_shape = B->Shape(); + const T* b_data = B->template Data(); if (b_shape.Size() == 1) { - output_mat.array() += *(B->template Data()); - } - // B is (N,) - else if (b_shape.NumDimensions() == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } else if (b_shape.NumDimensions() == 2) { + // B is (), (1,) or (1, 1), set the scalar + output_mat.setConstant(*b_data); + } else if (b_shape.NumDimensions() == 1 || b_shape[0] == 1) { + // B is (N,) or (1, N) + output_mat.rowwise() = ConstEigenVectorMap(b_data, N).transpose(); + } else if (b_shape[1] == 1) { // B is (M, 1) - if (b_shape[1] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - M); - output_mat.colwise() += bias_vec; - } - // B is (1, N) - else if (b_shape[0] == 1) { - auto bias_vec = ConstEigenVectorMap( - B->template Data(), - N); - output_mat.rowwise() += bias_vec.transpose(); - } + output_mat.colwise() = ConstEigenVectorMap(b_data, M); + } else { // B is (M, N), no broadcast needed. 
- else { - auto bias_mat = ConstEigenMatrixMapRowMajor( - B->template Data(), - M, - N); - output_mat += bias_mat; - } + output_mat = ConstEigenMatrixMapRowMajor(b_data, M, N); } } // W * x - math::Gemm( + math::Gemm( trans_A_, trans_B_, M, N, - K, + helper.K(), alpha_, - X->template Data(), - W->template Data(), + X->template Data(), + W->template Data(), beta_, y_data, &CPUMathUtil::Instance()); - FuseActivation(activation_, y_data, M * N, leaky_relu_alpha_); + FuseActivation(activation_, y_data, M * N, leaky_relu_alpha_); return Status::OK(); } @@ -119,7 +88,7 @@ class Gemm : public OpKernel { float alpha_; float beta_; -protected: + protected: // For fused gemm + activation std::string activation_; float leaky_relu_alpha_; diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index 7505aca264..70071b8dfe 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -1,13 +1,148 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +/** +* Copyright (c) 2016-present, Facebook, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +/* Modifications Copyright (c) Microsoft. 
*/ +#include "core/providers/cpu/nn/conv.h" #include "core/framework/op_kernel_context_internal.h" -#include "core/providers/cpu/nn/conv_impl.h" #include "core/util/math_cpuonly.h" namespace onnxruntime { -template <> +template +Status Conv::Compute(OpKernelContext* context) const { + size_t num_inputs = OpKernel::Node().InputDefs().size(); + + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; + const int64_t N = X->Shape()[0]; + const int64_t C = X->Shape()[1]; + const int64_t M = W->Shape()[0]; + ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + + std::vector kernel_shape; + ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + + bool Is2DKernel = kernel_shape.size() == 2; + std::vector pads(pads_); + if (pads.empty()) { + pads.resize(kernel_shape.size() * 2, 0); + } + std::vector dilations(dilations_); + if (dilations.empty()) { + dilations.resize(kernel_shape.size(), 1); + } + std::vector strides(strides_); + if (strides.empty()) { + strides.resize(kernel_shape.size(), 1); + } + + std::vector Y_dims; + Y_dims.insert(Y_dims.begin(), {N, M}); + TensorShape input_shape = X->Shape().Slice(2); + ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + Tensor* Y = context->Output(0, TensorShape(Y_dims)); + TensorShape output_shape = Y->Shape().Slice(2); + + const int64_t input_image_size = input_shape.Size(); + const int64_t output_image_size = output_shape.Size(); + const int64_t kernel_size = TensorShape(kernel_shape).Size(); + const int64_t X_offset = C / group_ * input_image_size; + const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; + const int64_t W_offset = W->Shape().Size() / group_; + const int64_t kernel_dim = C / group_ * kernel_size; + const int64_t col_buffer_size = kernel_dim * output_image_size; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + + 
auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size); + BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); + T* col_buffer_data = static_cast(col_buffer.get()); + + const T* Xdata = X->template Data(); + T* Ydata = Y->template MutableData(); + + TensorShape image_shape = X->Shape().Slice(1); + std::vector col_buffer_shape{kernel_dim}; + col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(), + output_shape.GetDims().end()); + + for (int image_id = 0; image_id < N; ++image_id) { + for (int group_id = 0; group_id < group_; ++group_id) { + if (Is2DKernel) { + math::Im2col( + Xdata + group_id * X_offset, + C / group_, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + col_buffer_data, + &CPUMathUtil::Instance()); + } else { + math::Im2colNd()( + Xdata + group_id * X_offset, + image_shape.GetDims().data(), + col_buffer_shape.data(), + C * input_image_size, + col_buffer_size, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_shape.size()), + col_buffer_data, + &CPUMathUtil::Instance()); + } + math::Gemm( + CblasNoTrans, + CblasNoTrans, + M / group_, + output_image_size, + kernel_dim, + 1, + W->template Data() + group_id * W_offset, + col_buffer_data, + 0, + Ydata + group_id * Y_offset, + &CPUMathUtil::Instance()); + } + + if (B != nullptr) { + auto Ymatrix = EigenMatrixMap(Ydata, output_image_size, M); + auto Bvec = ConstEigenVectorMap(B->template Data(), M); + Ymatrix.rowwise() += Bvec.transpose(); + } + + Xdata += X_offset * group_; + Ydata += Y_offset * group_; + } + + return Status::OK(); +} + Status Conv::Compute(OpKernelContext* context) const { size_t num_inputs = OpKernel::Node().InputDefs().size(); const auto* X = context->Input(0); @@ -45,27 +180,12 @@ Status Conv::Compute(OpKernelContext* context) const { 
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); const auto* Xdata = X->template Data(); + const auto* Bdata = B != nullptr ? B->template Data() : nullptr; auto* Ydata = Y->template MutableData(); const size_t kernel_rank = kernel_shape.size(); if (kernel_rank == 2 || kernel_rank == 3) { - MLAS_ACTIVATION Activation; - if (activation_.empty()) { - Activation.ActivationKind = MlasIdentityActivation; - } else if (activation_ == "Relu") { - Activation.ActivationKind = MlasReluActivation; - } else if (activation_ == "LeakyRelu") { - Activation.ActivationKind = MlasLeakyReluActivation; - Activation.alpha = alpha_; - } else if (activation_ == "Tanh") { - Activation.ActivationKind = MlasTanhActivation; - } else if (activation_ == "Sigmoid") { - Activation.ActivationKind = MlasLogisticActivation; - } else { - ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_); - } - // Get access to the internal threadpool // Temporarily derive concurrency parameters without access to session state auto ctx_internal = static_cast(context); @@ -85,7 +205,7 @@ Status Conv::Compute(OpKernelContext* context) const { strides.data(), output_shape.GetDims().data(), static_cast(M / group_), - &Activation, + &activation_, &WorkingBufferSize, const_cast(thread_pool)); @@ -95,7 +215,7 @@ Status Conv::Compute(OpKernelContext* context) const { MlasConv(&Parameters, Xdata, W->template Data(), - B != nullptr ? 
B->template Data() : nullptr, + Bdata, static_cast(working_buffer.get()), Ydata, const_cast(thread_pool)); @@ -147,13 +267,7 @@ Status Conv::Compute(OpKernelContext* context) const { &CPUMathUtil::Instance()); } - if (B != nullptr) { - auto Ymatrix = EigenMatrixMap(Ydata, output_image_size, M); - auto Bvec = ConstEigenVectorMap(B->template Data(), M); - Ymatrix.rowwise() += Bvec.transpose(); - } - - FuseActivation(activation_, Ydata, Y_offset * group_, alpha_); + MlasActivation(&activation_, Ydata, Bdata, M, Ydata, output_image_size, output_image_size); Xdata += X_offset * group_; Ydata += Y_offset * group_; @@ -168,4 +282,5 @@ ONNX_CPU_OPERATOR_KERNEL( 1, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Conv); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv.h b/onnxruntime/core/providers/cpu/nn/conv.h index cf6484ab41..3e366e1b49 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.h +++ b/onnxruntime/core/providers/cpu/nn/conv.h @@ -1,23 +1,10 @@ -/** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -/* Modifications Copyright (c) Microsoft. */ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#pragma once #include "core/providers/cpu/nn/conv_base.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -30,4 +17,17 @@ class Conv : public OpKernel, public ConvBase { Status Compute(OpKernelContext* context) const override; }; +template <> +class Conv : public OpKernel, public ConvBase { + public: + Conv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + activation_.ActivationKind = MlasIdentityActivation; + } + + Status Compute(OpKernelContext* context) const override; + + protected: + MLAS_ACTIVATION activation_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv_impl.h b/onnxruntime/core/providers/cpu/nn/conv_impl.h deleted file mode 100644 index 679630f465..0000000000 --- a/onnxruntime/core/providers/cpu/nn/conv_impl.h +++ /dev/null @@ -1,153 +0,0 @@ -/** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -/* Modifications Copyright (c) Microsoft. */ - -#pragma once - -#include "core/providers/cpu/nn/conv.h" -#include "core/util/math.h" -#include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" - -namespace onnxruntime { - -template -Status Conv::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); - - const auto* X = context->Input(0); - const auto* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? 
context->Input(2) : nullptr; - const int64_t N = X->Shape()[0]; - const int64_t C = X->Shape()[1]; - const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); - - std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); - - bool Is2DKernel = kernel_shape.size() == 2; - std::vector pads(pads_); - if (pads.empty()) { - pads.resize(kernel_shape.size() * 2, 0); - } - std::vector dilations(dilations_); - if (dilations.empty()) { - dilations.resize(kernel_shape.size(), 1); - } - std::vector strides(strides_); - if (strides.empty()) { - strides.resize(kernel_shape.size(), 1); - } - - std::vector Y_dims; - Y_dims.insert(Y_dims.begin(), {N, M}); - TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - TensorShape output_shape = Y->Shape().Slice(2); - - const int64_t input_image_size = input_shape.Size(); - const int64_t output_image_size = output_shape.Size(); - const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; - const int64_t col_buffer_size = kernel_dim * output_image_size; - - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); - - auto col_data = alloc->Alloc(sizeof(T) * col_buffer_size); - BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); - T* col_buffer_data = static_cast(col_buffer.get()); - - const T* Xdata = X->template Data(); - T* Ydata = Y->template MutableData(); - - TensorShape image_shape = X->Shape().Slice(1); - std::vector col_buffer_shape{kernel_dim}; - col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(), - 
output_shape.GetDims().end()); - - for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { - if (Is2DKernel) { - math::Im2col( - Xdata + group_id * X_offset, - C / group_, - input_shape[0], - input_shape[1], - kernel_shape[0], - kernel_shape[1], - dilations[0], - dilations[1], - pads[0], - pads[1], - pads[2], - pads[3], - strides[0], - strides[1], - col_buffer_data, - &CPUMathUtil::Instance()); - } else { - math::Im2colNd()( - Xdata + group_id * X_offset, - image_shape.GetDims().data(), - col_buffer_shape.data(), - C * input_image_size, - col_buffer_size, - kernel_shape.data(), - strides.data(), - dilations.data(), - pads.data(), - static_cast(kernel_shape.size()), - col_buffer_data, - &CPUMathUtil::Instance()); - } - math::Gemm( - CblasNoTrans, - CblasNoTrans, - M / group_, - output_image_size, - kernel_dim, - 1, - W->template Data() + group_id * W_offset, - col_buffer_data, - 0, - Ydata + group_id * Y_offset, - &CPUMathUtil::Instance()); - } - - if (B != nullptr) { - auto Ymatrix = EigenMatrixMap(Ydata, output_image_size, M); - auto Bvec = ConstEigenVectorMap(B->template Data(), M); - Ymatrix.rowwise() += Bvec.transpose(); - } - FuseActivation(activation_, Ydata, Y_offset * group_, alpha_); - - Xdata += X_offset * group_; - Ydata += Y_offset * group_; - } - - return Status::OK(); -} - -template <> -Status Conv::Compute(OpKernelContext* context) const; - -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 2fb17717fa..8673f9cc6b 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -157,7 +157,7 @@ TEST(GemmOpTest, GemmScalarBroadcast) { test.Run(); } -TEST(MathOpTest, Gemm2DBroadcast) { +TEST(GemmOpTest, Gemm2DBroadcast_1) { OpTester test("Gemm"); test.AddAttribute("transA", (int64_t)0); @@ -176,6 +176,26 @@ TEST(MathOpTest, Gemm2DBroadcast) { 
test.Run(); } +TEST(GemmOpTest, Gemm2DBroadcast_2) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Same as GemmBroadcast, but C is given as (1, N), i.e. with an unnecessary leading dimension. + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddInput("C", {1, 3}, std::vector{1.0f, 2.0f, 3.0f}); + test.AddOutput("Y", {2, 3}, + {11.0f, 12.0f, 13.0f, + -9.0f, -8.0f, -7.0f}); + test.Run(); +} + TEST(GemmOpTest, GemmFalseBroadcast) { OpTester test("Gemm");