mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Add MaxPool FP16 in XnnPack EP (#22258)
### Description Add support for FP16 kernels in the XnnPack execution provider for MaxPool operations. Fixes: [AB#50332](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/50332) ### Motivation and Context The major purpose of this pull request is to add some common variables/functions and set up a consistent style for adding FP16 kernels in the XnnPack EP. ---------
This commit is contained in:
parent
c73e6afa6c
commit
bbb54985a8
4 changed files with 83 additions and 5 deletions
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/providers/utils.h"
|
||||
#include "core/providers/xnnpack/xnnpack_init.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
// to sanity check output shape
|
||||
|
|
@ -54,6 +55,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
|
|||
// According to ONNX, the input of MaxPool can be fp16/fp32/fp64 or i8/u8.
|
||||
if (x_type == nullptr ||
|
||||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
// FLOAT16 is only accepted when XNNPACK_FP16_SUPPORTED is defined, because pool_fp16_op_test can also be enabled by other preprocessor definitions (for example, COREML_ENABLE_MLPROGRAM).
|
||||
#ifdef XNNPACK_FP16_SUPPORTED
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
|
||||
#endif
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
|
||||
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
|
||||
break;
|
||||
|
|
@ -193,9 +198,19 @@ MaxPool::MaxPool(const OpKernelInfo& info)
|
|||
stride_height, stride_width,
|
||||
dilation_height, dilation_width,
|
||||
output_min, output_max, flags, &p);
|
||||
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
maxpool_type_ = OpComputeType::op_compute_type_fp16;
|
||||
const float output_min = -65504.0;
|
||||
const float output_max = 65504.0;
|
||||
status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right,
|
||||
input_padding_bottom, input_padding_left,
|
||||
pooling_height, pooling_width,
|
||||
stride_height, stride_width,
|
||||
dilation_height, dilation_width,
|
||||
output_min, output_max, flags, &p);
|
||||
} else {
|
||||
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto()));
|
||||
ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8, but got ", stype);
|
||||
ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8|FLOAT16, but got ", stype);
|
||||
}
|
||||
ORT_ENFORCE(status == xnn_status_success, "xnn_create_max_pooling2d_nhwc_",
|
||||
OpTypeToString(maxpool_type_), "failed. Status:", status);
|
||||
|
|
@ -225,10 +240,12 @@ Status MaxPool::Compute(OpKernelContext* context) const {
|
|||
pthreadpool_t threadpool = GetThreadPool();
|
||||
|
||||
auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32;
|
||||
if (maxpool_type_ == OpComputeType::op_compute_type_qu8)
|
||||
if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8;
|
||||
else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8;
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
|
||||
reshape_fn = xnn_reshape_max_pooling2d_nhwc_f16;
|
||||
}
|
||||
|
||||
auto status = reshape_fn(op0_.get(), N, H, W,
|
||||
|
|
@ -244,8 +261,10 @@ Status MaxPool::Compute(OpKernelContext* context) const {
|
|||
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
|
||||
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
|
||||
} else {
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
|
||||
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
|
||||
} else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
|
||||
status = xnn_setup_max_pooling2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
|
||||
}
|
||||
|
||||
if (status != xnn_status_success) {
|
||||
|
|
@ -285,5 +304,24 @@ ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionPro
|
|||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>()}),
|
||||
MaxPool);
|
||||
|
||||
#ifdef XNNPACK_FP16_SUPPORTED
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, MLFloat16, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, MLFloat16, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, MLFloat16, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
MaxPool);
|
||||
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, MLFloat16, kXnnpackExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
MaxPool);
|
||||
#endif
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
BuildKernelCreateInfo< \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)>
|
||||
|
||||
#define KERNEL_CREATE_INFO_VERSIONED_TYPED(Start, End, Type, Op, Domain) \
|
||||
BuildKernelCreateInfo< \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Type, Op)>
|
||||
|
||||
#define KERNEL_CREATE_INFO(Start, Op, Domain) \
|
||||
BuildKernelCreateInfo< \
|
||||
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)>
|
||||
|
|
@ -39,6 +43,19 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
BuildKernelCreateInfo< \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>
|
||||
|
||||
#ifdef XNNPACK_FP16_SUPPORTED
|
||||
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \
|
||||
startver, endver, MLFloat16, name)
|
||||
|
||||
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \
|
||||
MLFloat16, name)
|
||||
#else
|
||||
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name)
|
||||
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name)
|
||||
#endif
|
||||
|
||||
// Layout sensitive operators in NHWC domain
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
|
||||
|
|
@ -68,6 +85,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSIn
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
|
||||
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
|
||||
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
|
||||
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
|
||||
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
|
||||
|
||||
// ONNX operators
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm);
|
||||
|
|
@ -138,6 +159,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain),
|
||||
|
||||
KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),
|
||||
|
||||
#ifdef XNNPACK_FP16_SUPPORTED
|
||||
KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
|
||||
KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
|
||||
KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
|
||||
KERNEL_CREATE_INFO_TYPED(12, MLFloat16, MaxPool, kMSInternalNHWCDomain),
|
||||
#endif
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -46,6 +46,18 @@ namespace xnnpack {
|
|||
#define XNN_ALLOCATION_ALIGNMENT 16
|
||||
#endif
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
|
||||
#define XNN_ARCH_ARM64 1
|
||||
#else
|
||||
#define XNN_ARCH_ARM64 0
|
||||
#endif
|
||||
|
||||
// fp16 support can vary on a kernel-by-kernel basis. Keep it simple and limit it to arm64 for now.
|
||||
// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels.
|
||||
#if XNN_ARCH_ARM64
|
||||
#define XNNPACK_FP16_SUPPORTED
|
||||
#endif
|
||||
|
||||
std::pair<AllocatorPtr&, xnn_allocator*> GetStoredAllocator();
|
||||
|
||||
} // namespace xnnpack
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
|
||||
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED)
|
||||
|
||||
#include "core/providers/cpu/nn/pool.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
|
|
|||
Loading…
Reference in a new issue