diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index c34259fb96..054355609b 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -397,25 +397,11 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{ "test_bitwise_xor_i32_2d", "pending opset 18 support"},
{ "test_bitwise_xor_ui8_bcast_4v3d", "pending opset 18 support"},
{ "test_bitwise_xor_ui64_bcast_3v1d", "pending opset 18 support"},
- { "test_center_crop_pad_crop", "pending opset 18 support"},
- { "test_center_crop_pad_crop_and_pad", "pending opset 18 support"},
- { "test_center_crop_pad_crop_and_pad_expanded", "pending opset 18 support"},
- { "test_center_crop_pad_crop_axes_chw", "pending opset 18 support"},
- { "test_center_crop_pad_crop_axes_chw_expanded", "pending opset 18 support"},
- { "test_center_crop_pad_crop_axes_hwc", "pending opset 18 support"},
- { "test_center_crop_pad_crop_axes_hwc_expanded", "pending opset 18 support"},
- { "test_center_crop_pad_crop_expanded", "pending opset 18 support"},
- { "test_center_crop_pad_pad", "pending opset 18 support"},
- { "test_center_crop_pad_pad_expanded", "pending opset 18 support"},
{ "test_col2im", "pending opset 18 support"},
{ "test_col2im_5d", "pending opset 18 support"},
{ "test_col2im_dilations", "pending opset 18 support"},
{ "test_col2im_pads", "pending opset 18 support"},
{ "test_col2im_strides", "pending opset 18 support"},
- { "test_constant_pad", "pending opset 18 support"},
- { "test_constant_pad_axes", "pending opset 18 support"},
- { "test_edge_pad", "pending opset 18 support"},
- { "test_reflect_pad", "pending opset 18 support"},
{ "test_scatter_elements_with_axis", "pending opset 18 support"},
{ "test_scatter_elements_without_axis", "pending opset 18 support"},
{ "test_scatter_elements_with_duplicate_indices", "pending opset 18 support"},
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 42c9be410b..964799c3a0 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -219,7 +219,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float)|
|||[9, 15]|**T** = tensor(float)|
|||[7, 8]|**T** = tensor(float)|
-|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(double), tensor(float)|
|ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index f9d963eda0..b7c369e173 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -665,7 +665,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, NonZero);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pad);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL1);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2);
@@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
@@ -1855,7 +1856,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc
index f14ed4c2f3..0e35398eb2 100644
--- a/onnxruntime/core/providers/cpu/tensor/pad.cc
+++ b/onnxruntime/core/providers/cpu/tensor/pad.cc
@@ -4,6 +4,7 @@
#include "core/providers/cpu/tensor/pad.h"
#include "core/framework/op_kernel_type_control_utils.h"
+#include "core/providers/common.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math.h"
@@ -66,10 +67,24 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
uint8_t,
bool);
+ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(
+ kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0,
+ float,
+ double,
+ int32_t,
+ int64_t,
+ uint32_t,
+ uint64_t,
+ int8_t,
+ uint8_t,
+ bool);
+
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0, int32_t, int64_t);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0, int32_t, int64_t);
+ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(
+ kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0, int32_t, int64_t);
} // namespace op_kernel_type_control
using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
@@ -78,6 +93,9 @@ using EnabledPad11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0);
using EnabledPad13Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, Pad, 13, Input, 0);
+using EnabledPad18Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(
+ kCpuExecutionProvider, kOnnxDomain, Pad, 18, Input, 0);
+
using AllEnabledPadTypes =
utils::TypeSetUnion<
@@ -106,13 +124,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
BuildKernelDefConstraintsFromTypeList()),
Pad);
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
+ Pad,
+ 13, 17,
+ KernelDefBuilder().TypeConstraint(
+ "T",
+ BuildKernelDefConstraintsFromTypeList()),
+ Pad);
+
ONNX_CPU_OPERATOR_KERNEL(
Pad,
- 13,
+ 18,
KernelDefBuilder()
.TypeConstraint(
"T",
- BuildKernelDefConstraintsFromTypeList()),
+ BuildKernelDefConstraintsFromTypeList()),
Pad);
@@ -463,6 +489,20 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) {
return result;
}
+template
+void ComputePadWithAxes(
+ gsl::span pads_tensor_raw_data,
+ gsl::span axes_tensor_raw_data,
+ size_t data_rank,
+ PadsVector& pads) {
+ size_t axes_size = axes_tensor_raw_data.size();
+ for (size_t i = 0; i < axes_size; ++i) {
+ int64_t axis = HandleNegativeAxis(onnxruntime::narrow(axes_tensor_raw_data[i]), data_rank);
+ pads[onnxruntime::narrow(axis)] = pads_tensor_raw_data[i]; // xi_begin
+ pads[data_rank + onnxruntime::narrow(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end
+ }
+}
+
Status Pad::Compute(OpKernelContext* ctx) const {
const Tensor& input_tensor = *ctx->Input(0);
MLDataType data_type = input_tensor.DataType();
@@ -479,20 +519,41 @@ Status Pad::Compute(OpKernelContext* ctx) const {
const Tensor& pads_tensor = *ctx->Input(1);
auto pads_tensor_dims = pads_tensor.Shape().GetDims();
- ORT_ENFORCE(pads_tensor.IsDataType(),
- "Pads tensor should be an INT64 tensor");
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
- "Pads tensor should be a 1D tensor of shape [2 * input_rank] "
- "or a 2D tensor of shape [1, 2 * input_rank]");
-
+ "Pads tensor should be a 1D tensor of shape [2 * num_axes] "
+ "or a 2D tensor of shape [1, 2 * num_axes]");
const int64_t* pads_tensor_raw_data = pads_tensor.Data();
size_t pads_size = static_cast(pads_tensor.Shape().Size());
- ORT_ENFORCE(pads_size == 2 * data_rank,
- "Pads tensor size should be equal to twice the input dimension count ");
-
pads.reserve(2 * data_rank);
- for (size_t i = 0; i < pads_size; ++i) {
- pads.push_back(pads_tensor_raw_data[i]);
+
+ const Tensor* axes_tensor = ctx->Input(3);
+ if (axes_tensor) {
+ const auto& axes_tensor_dims = axes_tensor->Shape().GetDims();
+ ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor ");
+ int64_t axes_size = axes_tensor_dims[0];
+
+ pads.resize(2 * data_rank, 0);
+ if (axes_tensor->IsDataType()) {
+ const int32_t* axes_tensor_raw_data = axes_tensor->Data();
+ ComputePadWithAxes(
+ {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)},
+ {axes_tensor_raw_data, onnxruntime::narrow(axes_size)},
+ data_rank,
+ pads);
+ } else if(axes_tensor->IsDataType()) {
+ const int64_t* axes_tensor_raw_data = axes_tensor->Data();
+ ComputePadWithAxes(
+ {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)},
+ {axes_tensor_raw_data, onnxruntime::narrow(axes_size)},
+ data_rank,
+ pads);
+ }
+ } else {
+ ORT_ENFORCE(pads_size == 2 * data_rank,
+ "Pads tensor size should be equal to twice the input dimension count ");
+ for (size_t i = 0; i < pads_size; ++i) {
+ pads.push_back(pads_tensor_raw_data[i]);
+ }
}
// Separate out any negative pads into the slices array
@@ -525,6 +586,7 @@ Status Pad::Compute(OpKernelContext* ctx) const {
ORT_THROW("Unsupported input data type of ", data_type);
}
}
+
pads_to_use = &pads;
slices_to_use = &slices;
} else {
diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
index 4042162a26..a0c5e9f06e 100644
--- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
@@ -805,5 +805,28 @@ TEST(PadOpTest, BoolType) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+TEST(PadOpTest, ConstantPadAxes) {
+ OpTester test("Pad", 18);
+ test.AddAttribute("mode", "constant");
+ test.AddInput("data", {1, 2, 2, 2},
+ {
+ 1, 1,
+ 1, 1,
+ 1, 1,
+ 1, 1});
+ test.AddInput("pads", {4}, {0, 1, 0, 1});
+ test.AddInput("value", {1}, {0});
+ test.AddInput("axes", {2}, {1, 3});
+ test.AddOutput("output", {1, 2, 2, 4},
+ {
+ 0, 1, 1, 0,
+ 0, 1, 1, 0,
+ 0, 1, 1, 0,
+ 0, 1, 1, 0
+ }
+ );
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index c91b616aa5..57a2eda7df 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -117,14 +117,9 @@
"^test_add_uint8_cuda",
"^test_roialign_aligned_*",
"^test_bitwise_*",
- "^test_center_crop_pad_*",
"^test_clip_default_int8_max_expanded_cpu",
"^test_clip_default_int8_min_expanded_cpu",
"^test_col2im_*",
- "^test_constant_pad_axes_cpu",
- "^test_constant_pad_cpu",
- "^test_edge_pad_cpu",
- "^test_reflect_pad_cpu",
"^test_softplus_example_expanded_cpu",
"^test_softplus_expanded_cpu",
"^test_split_*",