DML EP Register ScatterND-16 (#14240)

This PR registers ScatterND-16 to the DML EP

- CPU fallback is added if the reduction attribute is in use, as this is
not yet supported by DML.

Co-authored-by: Numfor Mbiziwo-Tiapo <numform@microsoft.com>
This commit is contained in:
Numfor Tiapo 2023-01-12 10:39:25 -08:00 committed by GitHub
parent 8f7eb75c3e
commit dee36f8ade
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 1 deletions

View file

@ -1080,7 +1080,8 @@ Do not modify directly.*
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|

View file

@ -108,6 +108,22 @@ public:
}
};
void CALLBACK QueryScatter(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
*isSupported = false;
MLOperatorAttributes attributes(context);
// DML does not support reduction.
std::string reduction = attributes.GetOptionalAttribute<std::string>(AttrName::Reduction, "none");
if (reduction != "none")
{
return;
}
*isSupported = true;
}
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter);

View file

@ -274,6 +274,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
DML_OP_EXTERN_QUERY_FUNCTION(Resize);
DML_OP_EXTERN_QUERY_FUNCTION(Scatter);
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
DML_OP_EXTERN_QUERY_FUNCTION(RecurrentNeuralNetwork);
DML_OP_EXTERN_QUERY_FUNCTION(BatchNormalization);
@ -488,6 +489,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
{REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
{REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)},
{REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},

View file

@ -69,6 +69,7 @@ namespace AttrName
static constexpr const char* OutputWidth = "output_width";
static constexpr const char* Pads = "pads";
static constexpr const char* PooledShape = "pooled_shape";
static constexpr const char* Reduction = "reduction";
static constexpr const char* Reverse = "reverse";
static constexpr const char* SampleSize = "sample_size";
static constexpr const char* SamplingRatio = "sampling_ratio";

View file

@ -378,6 +378,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Where = 16;
static const int sc_sinceVer_GreaterOrEqual = 16;
static const int sc_sinceVer_LessOrEqual = 16;
static const int sc_sinceVer_ScatterND = 16;
} // namespace OnnxOperatorSet16
namespace OnnxOperatorSet17