mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
DML EP Register Split18 (#15931)
Register Split18 for DirectML Split13 was previously implemented. Split18 adds a new attribute called "num_outputs" that must be used mutually exclusively with the "split" input. The "num_outputs" attribute wil split the tensor evenly (and handles odd uneven splits). To implement, the DML split tensor just needs to be overridden in the presence of the num_output attribute. --------- Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
parent
04ea561fc8
commit
a7ad859e3a
8 changed files with 42 additions and 3 deletions
|
|
@ -1152,7 +1152,8 @@ Do not modify directly.*
|
|||
|Softsign|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -44,5 +44,6 @@ public:
|
|||
DML_OP_DEFINE_CREATION_FUNCTION(Split7, VersionedKernel<DmlOperatorSplit, 7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Split11, VersionedKernel<DmlOperatorSplit, 11>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Split13, VersionedKernel<DmlOperatorSplit, 13>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Split18, VersionedKernel<DmlOperatorSplit, 18>);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -280,6 +280,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Flatten);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Split7);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Split11);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Split13);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Split18);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Transpose);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Tile);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Concat);
|
||||
|
|
@ -629,6 +630,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_VER( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 13, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Moves splits from constant parameter to dynamic input.
|
||||
{REG_INFO_VER( 18, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ namespace AttrName
|
|||
static constexpr const char* NewAxis = "new_axis";
|
||||
static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes";
|
||||
static constexpr const char* NormalizeVariance = "normalize_variance";
|
||||
static constexpr const char* NumOutputs = "num_outputs";
|
||||
static constexpr const char* P = "p";
|
||||
static constexpr const char* PaddingMode = "padding_mode";
|
||||
static constexpr const char* OutputHeight = "output_height";
|
||||
|
|
|
|||
|
|
@ -8,6 +8,13 @@
|
|||
|
||||
namespace OperatorHelper
|
||||
{
|
||||
template <typename T = uint32_t>
|
||||
T DivideRoundUp(T x, T y)
|
||||
{
|
||||
assert(y != 0);
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
bool ContainsEmptyDimensions(gsl::span<const DimensionType> dimensions)
|
||||
{
|
||||
return std::find(dimensions.begin(), dimensions.end(), 0u) != dimensions.end();
|
||||
|
|
@ -923,6 +930,25 @@ namespace OperatorHelper
|
|||
const uint32_t inputDimCount = gsl::narrow_cast<int32_t>(inputDimensions.size());
|
||||
const uint32_t axis = operatorAttributes.GetOptionalAttribute<int32_t>(AttrName::Axis, 0);
|
||||
m_axis = static_cast<int>(HandleNegativeAxis(axis, inputDimCount));
|
||||
|
||||
if (opsetVersion >= 18) // num_outputs attribute is only defined in opset18.
|
||||
{
|
||||
const uint32_t numOutputs = operatorAttributes.GetOptionalAttribute<int32_t>(AttrName::NumOutputs, 0);
|
||||
if (numOutputs > 0)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(m_split.size() == 0);
|
||||
auto inputSizeAlongAxis = inputDimensions.at(m_axis);
|
||||
auto outputSizeAlongAxis = DivideRoundUp(inputSizeAlongAxis, numOutputs);
|
||||
m_split.resize(numOutputs, outputSizeAlongAxis);
|
||||
// Every output has the same size except potentially the last one, which may be smaller.
|
||||
m_split.back() = static_cast<int>(inputSizeAlongAxis - (numOutputs - 1) * outputSizeAlongAxis);
|
||||
}
|
||||
else
|
||||
{
|
||||
// There is no num_outputs attribute set, so splits must be set.
|
||||
ML_CHECK_VALID_ARGUMENT(m_split.size() > 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> SplitHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
|
|
|
|||
|
|
@ -1484,6 +1484,7 @@ using ShapeInferenceHelper_Flatten13 = FlattenHelper;
|
|||
using ShapeInferenceHelper_Split7 = VersionedOpsetHelper<SplitHelper, 7>;
|
||||
using ShapeInferenceHelper_Split11 = VersionedOpsetHelper<SplitHelper, 11>;
|
||||
using ShapeInferenceHelper_Split13 = VersionedOpsetHelper<SplitHelper, 13>;
|
||||
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper<SplitHelper, 18>;
|
||||
using ShapeInferenceHelper_Transpose = TransposeHelper;
|
||||
using ShapeInferenceHelper_Concat = ConcatHelper;
|
||||
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper<SliceHelper, 7>;
|
||||
|
|
|
|||
|
|
@ -400,6 +400,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_ReduceMin = 18;
|
||||
static const int sc_sinceVer_ReduceProd = 18;
|
||||
static const int sc_sinceVer_ReduceSumSquare = 18;
|
||||
static const int sc_sinceVer_Split = 18;
|
||||
}
|
||||
|
||||
namespace MsftOperatorSet1
|
||||
|
|
|
|||
|
|
@ -720,7 +720,13 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
|
|||
3.f, 4.f}});
|
||||
|
||||
int64_t num_outputs = 0;
|
||||
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
|
||||
const std::unordered_set<std::string> excluded_providers =
|
||||
{
|
||||
kTensorrtExecutionProvider,
|
||||
kQnnExecutionProvider,
|
||||
kDmlExecutionProvider, // Error message differs from expected CPU EP error message.
|
||||
};
|
||||
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
|
||||
"Attribute `num_outputs` value cannot be lower than 1");
|
||||
|
||||
outputs.clear();
|
||||
|
|
@ -730,7 +736,7 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
|
|||
{0.f, 0.f}});
|
||||
|
||||
num_outputs = 3;
|
||||
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, true, true, num_outputs, false,
|
||||
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
|
||||
"Invalid num_outputs value of 3. Size of dimension being split is 2");
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue