diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a01bd0e34b..7b04f3d679 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1152,7 +1152,8 @@ Do not modify directly.*
|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**
or
*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**
or
*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**
or
*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**
or
*in* input:**T**
*out* outputs:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp
index df99a83c7c..638d31c82c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp
@@ -44,5 +44,6 @@ public:
DML_OP_DEFINE_CREATION_FUNCTION(Split7, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Split11, VersionedKernel);
DML_OP_DEFINE_CREATION_FUNCTION(Split13, VersionedKernel);
+DML_OP_DEFINE_CREATION_FUNCTION(Split18, VersionedKernel);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 53bce5c715..83df19b5b6 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -280,6 +280,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Flatten);
DML_OP_EXTERN_CREATION_FUNCTION(Split7);
DML_OP_EXTERN_CREATION_FUNCTION(Split11);
DML_OP_EXTERN_CREATION_FUNCTION(Split13);
+DML_OP_EXTERN_CREATION_FUNCTION(Split18);
DML_OP_EXTERN_CREATION_FUNCTION(Transpose);
DML_OP_EXTERN_CREATION_FUNCTION(Tile);
DML_OP_EXTERN_CREATION_FUNCTION(Concat);
@@ -629,6 +630,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO_VER( 13, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Moves splits from constant parameter to dynamic input.
+ {REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
index b47e0c5ed1..a4c7b2fde3 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
@@ -64,6 +64,7 @@ namespace AttrName
static constexpr const char* NewAxis = "new_axis";
static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes";
static constexpr const char* NormalizeVariance = "normalize_variance";
+ static constexpr const char* NumOutputs = "num_outputs";
static constexpr const char* P = "p";
static constexpr const char* PaddingMode = "padding_mode";
static constexpr const char* OutputHeight = "output_height";
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 403f15ebf4..13c7e9d0a4 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -8,6 +8,13 @@
namespace OperatorHelper
{
+ template
+ T DivideRoundUp(T x, T y)
+ {
+ assert(y != 0);
+ return (x + y - 1) / y;
+ }
+
bool ContainsEmptyDimensions(gsl::span dimensions)
{
return std::find(dimensions.begin(), dimensions.end(), 0u) != dimensions.end();
@@ -923,6 +930,25 @@ namespace OperatorHelper
const uint32_t inputDimCount = gsl::narrow_cast(inputDimensions.size());
const uint32_t axis = operatorAttributes.GetOptionalAttribute(AttrName::Axis, 0);
m_axis = static_cast(HandleNegativeAxis(axis, inputDimCount));
+
+ if (opsetVersion >= 18) // num_outputs attribute is only defined in opset18.
+ {
+ const uint32_t numOutputs = operatorAttributes.GetOptionalAttribute(AttrName::NumOutputs, 0);
+ if (numOutputs > 0)
+ {
+ ML_CHECK_VALID_ARGUMENT(m_split.size() == 0);
+ auto inputSizeAlongAxis = inputDimensions.at(m_axis);
+ auto outputSizeAlongAxis = DivideRoundUp(inputSizeAlongAxis, numOutputs);
+ m_split.resize(numOutputs, outputSizeAlongAxis);
+ // Every output has the same size except potentially the last one, which may be smaller.
+ m_split.back() = static_cast(inputSizeAlongAxis - (numOutputs - 1) * outputSizeAlongAxis);
+ }
+ else
+ {
+ // There is no num_outputs attribute set, so splits must be set.
+ ML_CHECK_VALID_ARGUMENT(m_split.size() > 0);
+ }
+ }
}
std::vector SplitHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 8629668bd4..2815cec715 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1484,6 +1484,7 @@ using ShapeInferenceHelper_Flatten13 = FlattenHelper;
using ShapeInferenceHelper_Split7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Split11 = VersionedOpsetHelper;
using ShapeInferenceHelper_Split13 = VersionedOpsetHelper;
+using ShapeInferenceHelper_Split18 = VersionedOpsetHelper;
using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 054faad5ba..10ba8bb156 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -400,6 +400,7 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceMin = 18;
static const int sc_sinceVer_ReduceProd = 18;
static const int sc_sinceVer_ReduceSumSquare = 18;
+ static const int sc_sinceVer_Split = 18;
}
namespace MsftOperatorSet1
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index a1dd470a2d..7712a0a5bf 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -720,7 +720,13 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
3.f, 4.f}});
int64_t num_outputs = 0;
- RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
+ const std::unordered_set excluded_providers =
+ {
+ kTensorrtExecutionProvider,
+ kQnnExecutionProvider,
+ kDmlExecutionProvider, // Error message differs from expected CPU EP error message.
+ };
+ RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Attribute `num_outputs` value cannot be lower than 1");
outputs.clear();
@@ -730,7 +736,7 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
{0.f, 0.f}});
num_outputs = 3;
- RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
+ RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Invalid num_outputs value of 3. Size of dimension being split is 2");
}