mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
support Pad(18) (#14219)
This commit is contained in:
parent
f03c507cf0
commit
05915d8393
6 changed files with 103 additions and 34 deletions
|
|
@ -397,25 +397,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{ "test_bitwise_xor_i32_2d", "pending opset 18 support"},
|
||||
{ "test_bitwise_xor_ui8_bcast_4v3d", "pending opset 18 support"},
|
||||
{ "test_bitwise_xor_ui64_bcast_3v1d", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_and_pad", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_and_pad_expanded", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_axes_chw", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_axes_chw_expanded", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_axes_hwc", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_axes_hwc_expanded", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_crop_expanded", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_pad", "pending opset 18 support"},
|
||||
{ "test_center_crop_pad_pad_expanded", "pending opset 18 support"},
|
||||
{ "test_col2im", "pending opset 18 support"},
|
||||
{ "test_col2im_5d", "pending opset 18 support"},
|
||||
{ "test_col2im_dilations", "pending opset 18 support"},
|
||||
{ "test_col2im_pads", "pending opset 18 support"},
|
||||
{ "test_col2im_strides", "pending opset 18 support"},
|
||||
{ "test_constant_pad", "pending opset 18 support"},
|
||||
{ "test_constant_pad_axes", "pending opset 18 support"},
|
||||
{ "test_edge_pad", "pending opset 18 support"},
|
||||
{ "test_reflect_pad", "pending opset 18 support"},
|
||||
{ "test_scatter_elements_with_axis", "pending opset 18 support"},
|
||||
{ "test_scatter_elements_without_axis", "pending opset 18 support"},
|
||||
{ "test_scatter_elements_with_duplicate_indices", "pending opset 18 support"},
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ Do not modify directly.*
|
|||
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float)|
|
||||
|||[9, 15]|**T** = tensor(float)|
|
||||
|||[7, 8]|**T** = tensor(float)|
|
||||
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[2, 10]|**T** = tensor(double), tensor(float)|
|
||||
|ParametricSoftplus|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|
|
|
|||
|
|
@ -665,7 +665,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, NonZero);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL1);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2);
|
||||
|
|
@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
|
||||
|
|
@ -1855,7 +1856,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t,
|
||||
NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float,
|
||||
ReduceL1)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t,
|
||||
|
|
@ -2130,6 +2131,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
|
||||
ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/providers/cpu/tensor/pad.h"
|
||||
|
||||
#include "core/framework/op_kernel_type_control_utils.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -66,10 +67,24 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
|
|||
uint8_t,
|
||||
bool);
|
||||
|
||||
// Default element types for input 0 of the CPU Pad kernel at opset 18.
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
    kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0,
    float, double,
    int32_t, int64_t,
    uint32_t, uint64_t,
    int8_t, uint8_t,
    bool);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0, int32_t, int64_t);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0, int32_t, int64_t);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0, int32_t, int64_t);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
|
||||
|
|
@ -78,6 +93,9 @@ using EnabledPad11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
|
|||
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
|
||||
using EnabledPad13Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
|
||||
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0);
|
||||
using EnabledPad18Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
|
||||
kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0);
|
||||
|
||||
|
||||
using AllEnabledPadTypes =
|
||||
utils::TypeSetUnion<
|
||||
|
|
@ -106,13 +124,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
BuildKernelDefConstraintsFromTypeList<EnabledPad11Types>()),
|
||||
Pad);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Pad,
|
||||
13, 17,
|
||||
KernelDefBuilder().TypeConstraint(
|
||||
"T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPad13Types>()),
|
||||
Pad);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Pad,
|
||||
13,
|
||||
18,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPad13Types>()),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledPad18Types>()),
|
||||
Pad);
|
||||
|
||||
|
||||
|
|
@ -463,6 +489,20 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) {
|
|||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ComputePadWithAxes(
|
||||
gsl::span<const int64_t> pads_tensor_raw_data,
|
||||
gsl::span<const T> axes_tensor_raw_data,
|
||||
size_t data_rank,
|
||||
PadsVector& pads) {
|
||||
size_t axes_size = axes_tensor_raw_data.size();
|
||||
for (size_t i = 0; i < axes_size; ++i) {
|
||||
int64_t axis = HandleNegativeAxis(onnxruntime::narrow<int64_t>(axes_tensor_raw_data[i]), data_rank);
|
||||
pads[onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[i]; // xi_begin
|
||||
pads[data_rank + onnxruntime::narrow<size_t>(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end
|
||||
}
|
||||
}
|
||||
|
||||
Status Pad::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor& input_tensor = *ctx->Input<Tensor>(0);
|
||||
MLDataType data_type = input_tensor.DataType();
|
||||
|
|
@ -479,20 +519,41 @@ Status Pad::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
|
||||
auto pads_tensor_dims = pads_tensor.Shape().GetDims();
|
||||
ORT_ENFORCE(pads_tensor.IsDataType<int64_t>(),
|
||||
"Pads tensor should be an INT64 tensor");
|
||||
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
|
||||
"Pads tensor should be a 1D tensor of shape [2 * input_rank] "
|
||||
"or a 2D tensor of shape [1, 2 * input_rank]");
|
||||
|
||||
"Pads tensor should be a 1D tensor of shape [2 * num_axes] "
|
||||
"or a 2D tensor of shape [1, 2 * num_axes]");
|
||||
const int64_t* pads_tensor_raw_data = pads_tensor.Data<int64_t>();
|
||||
size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
|
||||
ORT_ENFORCE(pads_size == 2 * data_rank,
|
||||
"Pads tensor size should be equal to twice the input dimension count ");
|
||||
|
||||
pads.reserve(2 * data_rank);
|
||||
for (size_t i = 0; i < pads_size; ++i) {
|
||||
pads.push_back(pads_tensor_raw_data[i]);
|
||||
|
||||
const Tensor* axes_tensor = ctx->Input<Tensor>(3);
|
||||
if (axes_tensor) {
|
||||
const auto& axes_tensor_dims = axes_tensor->Shape().GetDims();
|
||||
ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor ");
|
||||
int64_t axes_size = axes_tensor_dims[0];
|
||||
|
||||
pads.resize(2 * data_rank, 0);
|
||||
if (axes_tensor->IsDataType<int32_t>()) {
|
||||
const int32_t* axes_tensor_raw_data = axes_tensor->Data<int32_t>();
|
||||
ComputePadWithAxes<int32_t>(
|
||||
{pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
|
||||
{axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
|
||||
data_rank,
|
||||
pads);
|
||||
} else if(axes_tensor->IsDataType<int64_t>()) {
|
||||
const int64_t* axes_tensor_raw_data = axes_tensor->Data<int64_t>();
|
||||
ComputePadWithAxes<int64_t>(
|
||||
{pads_tensor_raw_data, onnxruntime::narrow<size_t>(2 * axes_size)},
|
||||
{axes_tensor_raw_data, onnxruntime::narrow<size_t>(axes_size)},
|
||||
data_rank,
|
||||
pads);
|
||||
}
|
||||
} else {
|
||||
ORT_ENFORCE(pads_size == 2 * data_rank,
|
||||
"Pads tensor size should be equal to twice the input dimension count ");
|
||||
for (size_t i = 0; i < pads_size; ++i) {
|
||||
pads.push_back(pads_tensor_raw_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Separate out any negative pads into the slices array
|
||||
|
|
@ -525,6 +586,7 @@ Status Pad::Compute(OpKernelContext* ctx) const {
|
|||
ORT_THROW("Unsupported input data type of ", data_type);
|
||||
}
|
||||
}
|
||||
|
||||
pads_to_use = &pads;
|
||||
slices_to_use = &slices;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -805,5 +805,28 @@ TEST(PadOpTest, BoolType) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// Opset-18 Pad with the optional `axes` input in constant mode.
// Only axes 1 and 3 are listed; `pads` is laid out as
// [begin_axis1, begin_axis3, end_axis1, end_axis3] = [0, 1, 0, 1], so axis 1
// keeps its size and the last axis grows from 2 to 4 with one zero on each side.
TEST(PadOpTest, ConstantPadAxes) {
  OpTester test("Pad", 18);
  test.AddAttribute("mode", "constant");

  // 1x2x2x2 input filled with ones.
  test.AddInput<int32_t>("data", {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1});
  test.AddInput<int64_t>("pads", {4}, {0, 1, 0, 1});
  test.AddInput<int32_t>("value", {1}, {0});
  test.AddInput<int32_t>("axes", {2}, {1, 3});

  // Every innermost row becomes {pad, 1, 1, pad} with pad == 0.
  test.AddOutput<int32_t>("output", {1, 2, 2, 4},
                          {0, 1, 1, 0,
                           0, 1, 1, 0,
                           0, 1, 1, 0,
                           0, 1, 1, 0});

  // The TensorRT execution provider is excluded from this run.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -117,14 +117,9 @@
|
|||
"^test_add_uint8_cuda",
|
||||
"^test_roialign_aligned_*",
|
||||
"^test_bitwise_*",
|
||||
"^test_center_crop_pad_*",
|
||||
"^test_clip_default_int8_max_expanded_cpu",
|
||||
"^test_clip_default_int8_min_expanded_cpu",
|
||||
"^test_col2im_*",
|
||||
"^test_constant_pad_axes_cpu",
|
||||
"^test_constant_pad_cpu",
|
||||
"^test_edge_pad_cpu",
|
||||
"^test_reflect_pad_cpu",
|
||||
"^test_softplus_example_expanded_cpu",
|
||||
"^test_softplus_expanded_cpu",
|
||||
"^test_split_*",
|
||||
|
|
|
|||
Loading…
Reference in a new issue