DML EP Register Split18 (#15931)

Register Split18 for DirectML

Split13 was previously implemented. Split18 adds a new attribute called
"num_outputs" that must be used mutually exclusively with the "split"
input.

The "num_outputs" attribute wil split the tensor evenly (and handles odd
uneven splits). To implement, the DML split tensor just needs to be
overridden in the presence of the num_output attribute.

---------

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
Sheil Kumar 2023-05-16 11:58:19 -07:00 committed by GitHub
parent 04ea561fc8
commit a7ad859e3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 3 deletions

View file

@ -1152,7 +1152,8 @@ Do not modify directly.*
|Softsign|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|

View file

@ -44,5 +44,6 @@ public:
DML_OP_DEFINE_CREATION_FUNCTION(Split7, VersionedKernel<DmlOperatorSplit, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Split11, VersionedKernel<DmlOperatorSplit, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Split13, VersionedKernel<DmlOperatorSplit, 13>);
DML_OP_DEFINE_CREATION_FUNCTION(Split18, VersionedKernel<DmlOperatorSplit, 18>);
} // namespace Dml

View file

@ -280,6 +280,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Flatten);
DML_OP_EXTERN_CREATION_FUNCTION(Split7);
DML_OP_EXTERN_CREATION_FUNCTION(Split11);
DML_OP_EXTERN_CREATION_FUNCTION(Split13);
DML_OP_EXTERN_CREATION_FUNCTION(Split18);
DML_OP_EXTERN_CREATION_FUNCTION(Transpose);
DML_OP_EXTERN_CREATION_FUNCTION(Tile);
DML_OP_EXTERN_CREATION_FUNCTION(Concat);
@ -629,6 +630,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO_VER( 13, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Moves splits from constant parameter to dynamic input.
{REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},

View file

@ -64,6 +64,7 @@ namespace AttrName
static constexpr const char* NewAxis = "new_axis";
static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes";
static constexpr const char* NormalizeVariance = "normalize_variance";
static constexpr const char* NumOutputs = "num_outputs";
static constexpr const char* P = "p";
static constexpr const char* PaddingMode = "padding_mode";
static constexpr const char* OutputHeight = "output_height";

View file

@ -8,6 +8,13 @@
namespace OperatorHelper
{
template <typename T = uint32_t>
T DivideRoundUp(T x, T y)
{
assert(y != 0);
return (x + y - 1) / y;
}
bool ContainsEmptyDimensions(gsl::span<const DimensionType> dimensions)
{
return std::find(dimensions.begin(), dimensions.end(), 0u) != dimensions.end();
@ -923,6 +930,25 @@ namespace OperatorHelper
const uint32_t inputDimCount = gsl::narrow_cast<int32_t>(inputDimensions.size());
const uint32_t axis = operatorAttributes.GetOptionalAttribute<int32_t>(AttrName::Axis, 0);
m_axis = static_cast<int>(HandleNegativeAxis(axis, inputDimCount));
if (opsetVersion >= 18) // num_outputs attribute is only defined in opset18.
{
const uint32_t numOutputs = operatorAttributes.GetOptionalAttribute<int32_t>(AttrName::NumOutputs, 0);
if (numOutputs > 0)
{
ML_CHECK_VALID_ARGUMENT(m_split.size() == 0);
auto inputSizeAlongAxis = inputDimensions.at(m_axis);
auto outputSizeAlongAxis = DivideRoundUp(inputSizeAlongAxis, numOutputs);
m_split.resize(numOutputs, outputSizeAlongAxis);
// Every output has the same size except potentially the last one, which may be smaller.
m_split.back() = static_cast<int>(inputSizeAlongAxis - (numOutputs - 1) * outputSizeAlongAxis);
}
else
{
// There is no num_outputs attribute set, so splits must be set.
ML_CHECK_VALID_ARGUMENT(m_split.size() > 0);
}
}
}
std::vector<EdgeShapes> SplitHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const

View file

@ -1484,6 +1484,7 @@ using ShapeInferenceHelper_Flatten13 = FlattenHelper;
using ShapeInferenceHelper_Split7 = VersionedOpsetHelper<SplitHelper, 7>;
using ShapeInferenceHelper_Split11 = VersionedOpsetHelper<SplitHelper, 11>;
using ShapeInferenceHelper_Split13 = VersionedOpsetHelper<SplitHelper, 13>;
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper<SplitHelper, 18>;
using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper<SliceHelper, 7>;

View file

@ -400,6 +400,7 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceMin = 18;
static const int sc_sinceVer_ReduceProd = 18;
static const int sc_sinceVer_ReduceSumSquare = 18;
static const int sc_sinceVer_Split = 18;
}
namespace MsftOperatorSet1

View file

@ -720,7 +720,13 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
3.f, 4.f}});
int64_t num_outputs = 0;
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
const std::unordered_set<std::string> excluded_providers =
{
kTensorrtExecutionProvider,
kQnnExecutionProvider,
kDmlExecutionProvider, // Error message differs from expected CPU EP error message.
};
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Attribute `num_outputs` value cannot be lower than 1");
outputs.clear();
@ -730,7 +736,7 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
{0.f, 0.f}});
num_outputs = 3;
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Invalid num_outputs value of 3. Size of dimension being split is 2");
}