diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e1659d6dd1..b570ef7f44 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1044,7 +1044,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||7+|**T** = tensor(float), tensor(float16)|
-|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**
or
*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
index 84046f74ea..a014db5adb 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp
@@ -15,7 +15,7 @@ public:
{
const uint32_t inputCount = kernelInfo.GetInputCount();
ML_CHECK_VALID_ARGUMENT((opsetVersion >= 2 && opsetVersion < 11 && inputCount == 1)
- || (opsetVersion >= 11 && inputCount >= 2 && inputCount <= 3));
+ || (opsetVersion >= 11 && inputCount >= 2 && inputCount <= 4));
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
std::vector> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
@@ -68,12 +68,12 @@ public:
paddingDesc.EndPadding = m_endPadding.data();
// PaddingValueDataType will always be equal to inputDataTensorDataType
// Assigning paddingValueDataType to inputDataTensorDataType because this field
- // has to be assigned even if program does not go through below conditional
+ // has to be assigned even if program does not go through below conditional
// logic for some corner test case (like opsetVersion >= 11, but no validInput at index 2)
// Same applies to paddingValue.
paddingDesc.PaddingValueDataType = this->m_inputTensorDescs[0].GetDmlDataType();
CastToClampedScalarUnion(paddingDesc.PaddingValueDataType, 0.0f, /*out*/&paddingDesc.PaddingValue);
-
+
// Read the constant value which can come from an attribute or tensor.
if (opsetVersion >= 11)
{
@@ -107,7 +107,7 @@ void CALLBACK QueryPad(IMLOperatorSupportQueryContextPrivate* context, /*out*/ b
*isSupported = true;
MLOperatorAttributes attributes(context);
-
+
std::vector padding = attributes.GetOptionalAttributeVectorInt32(AttrName::Pads);
*isSupported = std::none_of(padding.begin(), padding.end(), [](int32_t padCount) {return padCount < 0; });
}
@@ -115,5 +115,6 @@ void CALLBACK QueryPad(IMLOperatorSupportQueryContextPrivate* context, /*out*/ b
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel);
+DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 44300a5f68..13919420ee 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -292,6 +292,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Slice13);
DML_OP_EXTERN_CREATION_FUNCTION(Pad7);
DML_OP_EXTERN_CREATION_FUNCTION(Pad11);
DML_OP_EXTERN_CREATION_FUNCTION(Pad13);
+DML_OP_EXTERN_CREATION_FUNCTION(Pad18);
DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth);
DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace);
DML_OP_EXTERN_CREATION_FUNCTION(Sqrt);
@@ -650,6 +651,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryPad)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
+ {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index bb484ec424..370f336ff5 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -41,6 +41,21 @@ namespace OperatorHelper
}
}
+ void HandleEmptyAxes(
+ /*inout*/std::vector& axes,
+ gsl::span inputShape,
+ bool treatEmptyAsNop
+ )
+ {
+ // If axes is not specified, reduce over all the dimensions.
+ // If empty axes should be treated as a nop, then just leave them as-is.
+ if (axes.empty() && !treatEmptyAsNop)
+ {
+ axes.resize(inputShape.size());
+ std::iota(axes.begin(), axes.end(), 0);
+ }
+ }
+
float CastFloat16ToFloat32(uint16_t input)
{
// Promote float16m10e5s1 to float32m23e8s1.
@@ -1122,12 +1137,36 @@ namespace OperatorHelper
}
ML_CHECK_VALID_ARGUMENT(padding.size() % 2 == 0, "Padding must be even count, including begin/end pairs.");
+ std::vector inputShape = shapeInformation.GetInputTensorShape(0);
+ uint32_t dimCount = gsl::narrow_cast(inputShape.size());
+ m_startPadding.resize(dimCount, 0);
+ m_endPadding.resize(dimCount, 0);
+ std::vector axes;
- uint32_t dimCount = gsl::narrow_cast(padding.size() / 2);
- m_startPadding.resize(dimCount);
- m_endPadding.resize(dimCount);
- std::copy(padding.begin(), padding.begin() + dimCount, m_startPadding.begin());
- std::copy(padding.begin() + dimCount, padding.begin() + dimCount * 2, m_endPadding.begin());
+ // Handle possible axes input
+ if (opsetVersion >= 18)
+ {
+ if (kernelInformation.IsInputValid(3))
+ {
+ ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(3), /*out*/ axes);
+ }
+ HandleEmptyAxes(axes, inputShape, false);
+ ML_CHECK_VALID_ARGUMENT(axes.size() * 2 == padding.size(), "The number of elements in padding should be 2 times the number of axes.");
+ HandleNegativeAxes(axes, dimCount);
+ }
+ else
+ {
+ HandleEmptyAxes(axes, inputShape, false);
+ }
+
+ uint32_t numAxes = gsl::narrow_cast(axes.size());
+ for (int32_t i = 0; i < axes.size(); i++)
+ {
+ auto xi_begin = padding[i];
+ auto xi_end = padding[i+axes.size()];
+ m_startPadding[axes[i]] = xi_begin;
+ m_endPadding[axes[i]] = xi_end;
+ }
}
std::vector PaddingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
@@ -1360,21 +1399,6 @@ namespace OperatorHelper
}
}
- void ReduceHelperBase::HandleEmptyAxes(
- /*inout*/std::vector& axes,
- gsl::span inputShape,
- bool treatEmptyAsNop
- )
- {
- // If axes is not specified, reduce over all the dimensions.
- // If empty axes should be treated as a nop, then just leave them as-is.
- if (axes.empty() && !treatEmptyAsNop)
- {
- axes.resize(inputShape.size());
- std::iota(axes.begin(), axes.end(), 0);
- }
- }
-
void EinSumHelper::Initialize()
{
ParseEquationComponents();
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 20ba5ad7a0..485e20c1df 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -687,13 +687,6 @@ public:
std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-private:
- static void HandleEmptyAxes(
- /*inout*/std::vector& onnxAxes,
- gsl::span inputShape,
- bool treatEmptyAsNop
- );
-
protected:
std::vector m_axes;
int m_keepDims = 0; // Keep the dimensions rather than removing size 1 dimensions.
@@ -1526,6 +1519,7 @@ using ShapeInferenceHelper_Slice13 = VersionedOpsetHelper; // N
using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper;
+using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper;
using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper;
using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 0332d51a97..c1e525400b 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -404,6 +404,7 @@ namespace OperatorHelper
static const int sc_sinceVer_BitwiseOr = 18;
static const int sc_sinceVer_BitwiseXor = 18;
static const int sc_sinceVer_BitwiseNot = 18;
+ static const int sc_sinceVer_Pad = 18;
static const int sc_sinceVer_Split = 18;
}
diff --git a/onnxruntime/test/providers/cpu/tensor/pad_test.cc b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
index 9b6c3f1b36..98ded07f8c 100644
--- a/onnxruntime/test/providers/cpu/tensor/pad_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/pad_test.cc
@@ -1011,6 +1011,24 @@ TEST(PadOpTest, ConstantPadAxesTest3) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kNnapiExecutionProvider});
}
+TEST(PadOpTest, ConstantPadAxesTest4) {
+ OpTester test("Pad", 18);
+ test.AddAttribute("mode", "constant");
+ test.AddInput("data", {1, 2, 2, 2},
+ {1.0f, 1.0f,
+ 1.0f, 1.0f,
+ 1.0f, 1.0f,
+ 1.0f, 1.0f});
+ test.AddInput("pads", {8}, {0, 0, 0, 1, 0, 0, 0, 1}, true /* pads_is_initializer */);
+ test.AddInput("value", {1}, {0.0f}, true /* value_is_initializer */);
+ test.AddOutput("output", {1, 2, 2, 4},
+ {0.0f, 1.0f, 1.0f, 0.0f,
+ 0.0f, 1.0f, 1.0f, 0.0f,
+ 0.0f, 1.0f, 1.0f, 0.0f,
+ 0.0f, 1.0f, 1.0f, 0.0f});
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kNnapiExecutionProvider});
+}
+
TEST(PadOpTest, ConstantPadAxesOutOfOrder) {
// Specified out of order axes values
OpTester test("Pad", 18);