From a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Tue, 16 May 2023 11:58:19 -0700 Subject: [PATCH] DML EP Register Split18 (#15931) Register Split18 for DirectML Split13 was previously implemented. Split18 adds a new attribute called "num_outputs" that must be used mutually exclusively with the "split" input. The "num_outputs" attribute will split the tensor evenly (and handles uneven splits, where the last output may be smaller). To implement, the DML split tensor just needs to be overridden in the presence of the num_outputs attribute. --------- Co-authored-by: Dwayne Robinson --- docs/OperatorKernels.md | 3 ++- .../src/Operators/DmlOperatorSplit.cpp | 1 + .../src/Operators/OperatorRegistration.cpp | 2 ++ .../dml/OperatorAuthorHelper/Attributes.h | 1 + .../OperatorAuthorHelper/OperatorHelper.cpp | 26 +++++++++++++++++++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../OperatorAuthorHelper/OperatorVersions.h | 1 + .../providers/cpu/tensor/split_op_test.cc | 10 +++++-- 8 files changed, 42 insertions(+), 3 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a01bd0e34b..7b04f3d679 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1152,7 +1152,8 @@ Do not modify directly.* |Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**

or

*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**

or

*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**

or

*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**

or

*in* input:**T**
*out* outputs:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp index df99a83c7c..638d31c82c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSplit.cpp @@ -44,5 +44,6 @@ public: DML_OP_DEFINE_CREATION_FUNCTION(Split7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Split11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Split13, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Split18, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 53bce5c715..83df19b5b6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -280,6 +280,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Flatten); DML_OP_EXTERN_CREATION_FUNCTION(Split7); DML_OP_EXTERN_CREATION_FUNCTION(Split11); DML_OP_EXTERN_CREATION_FUNCTION(Split13); +DML_OP_EXTERN_CREATION_FUNCTION(Split18); DML_OP_EXTERN_CREATION_FUNCTION(Transpose); DML_OP_EXTERN_CREATION_FUNCTION(Tile); DML_OP_EXTERN_CREATION_FUNCTION(Concat); @@ -629,6 +630,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_VER( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. 
{REG_INFO_VER( 13, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Moves splits from constant parameter to dynamic input. + {REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index b47e0c5ed1..a4c7b2fde3 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -64,6 +64,7 @@ namespace AttrName static constexpr const char* NewAxis = "new_axis"; static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes"; static constexpr const char* NormalizeVariance = "normalize_variance"; + static constexpr const char* NumOutputs = "num_outputs"; static constexpr const char* P = "p"; static constexpr const char* PaddingMode = "padding_mode"; static constexpr const char* OutputHeight = "output_height"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 403f15ebf4..13c7e9d0a4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -8,6 +8,13 @@ namespace OperatorHelper { + template + T DivideRoundUp(T x, T y) + { + assert(y != 0); + return (x + y - 1) / y; + } + bool ContainsEmptyDimensions(gsl::span dimensions) { return std::find(dimensions.begin(), 
dimensions.end(), 0u) != dimensions.end(); @@ -923,6 +930,25 @@ namespace OperatorHelper const uint32_t inputDimCount = gsl::narrow_cast(inputDimensions.size()); const uint32_t axis = operatorAttributes.GetOptionalAttribute(AttrName::Axis, 0); m_axis = static_cast(HandleNegativeAxis(axis, inputDimCount)); + + if (opsetVersion >= 18) // num_outputs attribute is only defined in opset18. + { + const uint32_t numOutputs = operatorAttributes.GetOptionalAttribute(AttrName::NumOutputs, 0); + if (numOutputs > 0) + { + ML_CHECK_VALID_ARGUMENT(m_split.size() == 0); + auto inputSizeAlongAxis = inputDimensions.at(m_axis); + auto outputSizeAlongAxis = DivideRoundUp(inputSizeAlongAxis, numOutputs); + m_split.resize(numOutputs, outputSizeAlongAxis); + // Every output has the same size except potentially the last one, which may be smaller. + m_split.back() = static_cast(inputSizeAlongAxis - (numOutputs - 1) * outputSizeAlongAxis); + } + else + { + // There is no num_outputs attribute set, so splits must be set. 
+ ML_CHECK_VALID_ARGUMENT(m_split.size() > 0); + } + } } std::vector SplitHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 8629668bd4..2815cec715 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1484,6 +1484,7 @@ using ShapeInferenceHelper_Flatten13 = FlattenHelper; using ShapeInferenceHelper_Split7 = VersionedOpsetHelper; using ShapeInferenceHelper_Split11 = VersionedOpsetHelper; using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; +using ShapeInferenceHelper_Split18 = VersionedOpsetHelper; using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 054faad5ba..10ba8bb156 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -400,6 +400,7 @@ namespace OperatorHelper static const int sc_sinceVer_ReduceMin = 18; static const int sc_sinceVer_ReduceProd = 18; static const int sc_sinceVer_ReduceSumSquare = 18; + static const int sc_sinceVer_Split = 18; } namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index a1dd470a2d..7712a0a5bf 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -720,7 +720,13 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) { 3.f, 4.f}}); int64_t num_outputs = 0; - RunTest(axis, {}, input, outputs, 
{kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false, + const std::unordered_set excluded_providers = + { + kTensorrtExecutionProvider, + kQnnExecutionProvider, + kDmlExecutionProvider, // Error message differs from expected CPU EP error message. + }; + RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false, "Attribute `num_outputs` value cannot be lower than 1"); outputs.clear(); @@ -730,7 +736,7 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) { {0.f, 0.f}}); num_outputs = 3; - RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false, + RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false, "Invalid num_outputs value of 3. Size of dimension being split is 2"); }