support Pad(18) (#14219)

This commit is contained in:
liqun Fu 2023-01-23 12:14:35 -08:00 committed by GitHub
parent f03c507cf0
commit 05915d8393
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 103 additions and 34 deletions

View file

@ -397,25 +397,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{ "test_bitwise_xor_i32_2d", "pending opset 18 support"},
{ "test_bitwise_xor_ui8_bcast_4v3d", "pending opset 18 support"},
{ "test_bitwise_xor_ui64_bcast_3v1d", "pending opset 18 support"},
{ "test_center_crop_pad_crop", "pending opset 18 support"},
{ "test_center_crop_pad_crop_and_pad", "pending opset 18 support"},
{ "test_center_crop_pad_crop_and_pad_expanded", "pending opset 18 support"},
{ "test_center_crop_pad_crop_axes_chw", "pending opset 18 support"},
{ "test_center_crop_pad_crop_axes_chw_expanded", "pending opset 18 support"},
{ "test_center_crop_pad_crop_axes_hwc", "pending opset 18 support"},
{ "test_center_crop_pad_crop_axes_hwc_expanded", "pending opset 18 support"},
{ "test_center_crop_pad_crop_expanded", "pending opset 18 support"},
{ "test_center_crop_pad_pad", "pending opset 18 support"},
{ "test_center_crop_pad_pad_expanded", "pending opset 18 support"},
{ "test_col2im", "pending opset 18 support"},
{ "test_col2im_5d", "pending opset 18 support"},
{ "test_col2im_dilations", "pending opset 18 support"},
{ "test_col2im_pads", "pending opset 18 support"},
{ "test_col2im_strides", "pending opset 18 support"},
{ "test_constant_pad", "pending opset 18 support"},
{ "test_constant_pad_axes", "pending opset 18 support"},
{ "test_edge_pad", "pending opset 18 support"},
{ "test_reflect_pad", "pending opset 18 support"},
{ "test_scatter_elements_with_axis", "pending opset 18 support"},
{ "test_scatter_elements_without_axis", "pending opset 18 support"},
{ "test_scatter_elements_with_duplicate_indices", "pending opset 18 support"},

View file

@ -219,7 +219,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float)|
|||[9, 15]|**T** = tensor(float)|
|||[7, 8]|**T** = tensor(float)|
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(double), tensor(float)|
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|

View file

@ -665,7 +665,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, NonZero);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pad);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2);
@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
@ -1855,7 +1856,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t,
NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float,
ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t,
@ -2130,6 +2131,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,

View file

@ -4,6 +4,7 @@
#include "core/providers/cpu/tensor/pad.h"
#include "core/framework/op_kernel_type_control_utils.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math.h"
@ -66,10 +67,24 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
uint8_t,
bool);
// Opset-18 entry: the default type list declared for input 0 of Pad on the
// CPU execution provider.
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0,
float,
double,
int32_t,
int64_t,
uint32_t,
uint64_t,
int8_t,
uint8_t,
bool);
// Required-type declarations for input 0 of Pad, one per opset (11, 13, 18);
// each names int32_t and int64_t.
// NOTE(review): "required" presumably means these types stay enabled even when
// the build trims the default list - confirm against op_kernel_type_control.h.
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0, int32_t, int64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0, int32_t, int64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0, int32_t, int64_t);
} // namespace op_kernel_type_control
using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
@ -78,6 +93,9 @@ using EnabledPad11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
// Type lists produced by the type-control machinery for input 0 of Pad at
// opsets 13 and 18. Each alias is the list handed to
// BuildKernelDefConstraintsFromTypeList when the matching kernel's "T"
// constraint is built.
using EnabledPad13Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0);
using EnabledPad18Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0);
using AllEnabledPadTypes =
utils::TypeSetUnion<
@ -106,13 +124,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
BuildKernelDefConstraintsFromTypeList<EnabledPad11Types>()),
Pad);
// Registers the CPU Pad kernel for the versioned range given as `13, 17`, with
// type constraint "T" built from EnabledPad13Types and implemented by class Pad.
// NOTE(review): reading `13, 17` as an inclusive opset range follows the
// VERSIONED macro name - confirm against the macro definition.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Pad,
13, 17,
KernelDefBuilder().TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<EnabledPad13Types>()),
Pad);
ONNX_CPU_OPERATOR_KERNEL(
Pad,
13,
18,
KernelDefBuilder()
.TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<EnabledPad13Types>()),
BuildKernelDefConstraintsFromTypeList<EnabledPad18Types>()),
Pad);
@ -463,6 +489,20 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) {
return result;
}
// Scatters per-axis pad amounts into the full-rank `pads` vector.
//
// pads_tensor_raw_data  pad amounts for the listed axes only, laid out as
//                       [begin(axes[0]) .. begin(axes[k-1]), end(axes[0]) .. end(axes[k-1])]
//                       where k == axes_tensor_raw_data.size().
// axes_tensor_raw_data  axis ids (element type T); each is passed through
//                       HandleNegativeAxis together with data_rank before use.
// data_rank             rank of the tensor being padded.
// pads                  output; written at [axis] (begin) and [data_rank + axis] (end),
//                       so it must already hold 2 * data_rank entries. Entries for axes
//                       that are not listed are left as they were.
template <class T>
void ComputePadWithAxes(
    gsl::span<const int64_t> pads_tensor_raw_data,
    gsl::span<const T> axes_tensor_raw_data,
    size_t data_rank,
    PadsVector& pads) {
  const size_t num_axes = axes_tensor_raw_data.size();
  // begin_pos walks the first half of the pads data, end_pos the second half.
  size_t end_pos = num_axes;
  for (size_t begin_pos = 0; begin_pos < num_axes; ++begin_pos, ++end_pos) {
    const int64_t requested_axis = onnxruntime::narrow<int64_t>(axes_tensor_raw_data[begin_pos]);
    const int64_t resolved_axis = HandleNegativeAxis(requested_axis, data_rank);
    const size_t dim = onnxruntime::narrow<size_t>(resolved_axis);
    pads[dim] = pads_tensor_raw_data[begin_pos];            // xi_begin
    pads[data_rank + dim] = pads_tensor_raw_data[end_pos];  // xi_end
  }
}
Status Pad::Compute(OpKernelContext* ctx) const {
const Tensor& input_tensor = *ctx->Input<Tensor>(0);
MLDataType data_type = input_tensor.DataType();
@ -479,20 +519,41 @@ Status Pad::Compute(OpKernelContext* ctx) const {
const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
auto pads_tensor_dims = pads_tensor.Shape().GetDims();
ORT_ENFORCE(pads_tensor.IsDataType<int64_t>(),
"Pads tensor should be an INT64 tensor");
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
"Pads tensor should be a 1D tensor of shape [2 * input_rank] "
"or a 2D tensor of shape [1, 2 * input_rank]");
"Pads tensor should be a 1D tensor of shape [2 * num_axes] "
"or a 2D tensor of shape [1, 2 * num_axes]");
const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
ORT_ENFORCE(pads_size == 2 * data_rank,
"Pads tensor size should be equal to twice the input dimension count ");
pads.reserve(2 * data_rank);
for (size_t i = 0; i < pads_size; ++i) {
pads.push_back(pads_tensor_raw_data[i]);
const Tensor* axes_tensor = ctx->Input<Tensor>(3);
if (axes_tensor) {
const auto& axes_tensor_dims = axes_tensor->Shape().GetDims();
ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor ");
int64_t axes_size = axes_tensor_dims[0];
pads.resize(2 * data_rank, 0);
if (axes_tensor->IsDataType<int32_t>()) {
const int32_t* axes_tensor_raw_data = axes_tensor->Data<int32_t>();
ComputePadWithAxes<int32_t>(
{pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
{axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
data_rank,
pads);
} else if(axes_tensor->IsDataType<int64_t>()) {
const int64_t* axes_tensor_raw_data = axes_tensor->Data<int64_t>();
ComputePadWithAxes<int64_t>(
{pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
{axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
data_rank,
pads);
}
} else {
ORT_ENFORCE(pads_size == 2 * data_rank,
"Pads tensor size should be equal to twice the input dimension count ");
for (size_t i = 0; i < pads_size; ++i) {
pads.push_back(pads_tensor_raw_data[i]);
}
}
// Separate out any negative pads into the slices array
@ -525,6 +586,7 @@ Status Pad::Compute(OpKernelContext* ctx) const {
ORT_THROW("Unsupported input data type of ", data_type);
}
}
pads_to_use = &pads;
slices_to_use = &slices;
} else {

View file

@ -805,5 +805,28 @@ TEST(PadOpTest, BoolType) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
// Exercises opset-18 Pad in "constant" mode with the `axes` input supplied.
// With axes {1, 3}, pads {0, 1, 0, 1} read as
// [begin(axis 1), begin(axis 3), end(axis 1), end(axis 3)]: axis 1 gets no
// padding and axis 3 gets one element on each side, so the 1x2x2x2 input of
// ones becomes a 1x2x2x4 output whose rows are {0, 1, 1, 0}.
TEST(PadOpTest, ConstantPadAxes) {
OpTester test("Pad", 18);
test.AddAttribute("mode", "constant");
test.AddInput<int32_t>("data", {1, 2, 2, 2},
{
1, 1,
1, 1,
1, 1,
1, 1});
test.AddInput<int64_t>("pads", {4}, {0, 1, 0, 1});
// Fill value used for the padded positions (the zeros in the expected output).
test.AddInput<int32_t>("value", {1}, {0});
// axes is given as int32 here; pads above is int64.
test.AddInput<int32_t>("axes", {2}, {1, 3});
test.AddOutput<int32_t>("output", {1, 2, 2, 4},
{
0, 1, 1, 0,
0, 1, 1, 0,
0, 1, 1, 0,
0, 1, 1, 0
}
);
// NOTE(review): the trailing list presumably names execution providers to
// skip (TensorRT here) - confirm against OpTester::Run.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
} // namespace test
} // namespace onnxruntime

View file

@ -117,14 +117,9 @@
"^test_add_uint8_cuda",
"^test_roialign_aligned_*",
"^test_bitwise_*",
"^test_center_crop_pad_*",
"^test_clip_default_int8_max_expanded_cpu",
"^test_clip_default_int8_min_expanded_cpu",
"^test_col2im_*",
"^test_constant_pad_axes_cpu",
"^test_constant_pad_cpu",
"^test_edge_pad_cpu",
"^test_reflect_pad_cpu",
"^test_softplus_example_expanded_cpu",
"^test_softplus_expanded_cpu",
"^test_split_*",