diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 419f7524be..b46e94b739 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1080,7 +1080,8 @@ Do not modify directly.* |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp index 68b567e086..86598fce62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp @@ -108,6 +108,22 @@ public: } }; +void CALLBACK QueryScatter(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported) +{ + *isSupported = false; + + MLOperatorAttributes attributes(context); + + // DML does not support reduction. + std::string reduction = attributes.GetOptionalAttribute(AttrName::Reduction, "none"); + if (reduction != "none") + { + return; + } + + *isSupported = true; +} + DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 1be82b8cd5..480b59c47a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -274,6 +274,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(NonZero); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); DML_OP_EXTERN_QUERY_FUNCTION(Resize); +DML_OP_EXTERN_QUERY_FUNCTION(Scatter); DML_OP_EXTERN_QUERY_FUNCTION(EinSum); DML_OP_EXTERN_QUERY_FUNCTION(RecurrentNeuralNetwork); DML_OP_EXTERN_QUERY_FUNCTION(BatchNormalization); @@ -488,6 +489,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)}, {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)}, {REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)}, + {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)}, {REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 5224103793..fe59778d08 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -69,6 +69,7 @@ namespace AttrName static constexpr const char* OutputWidth = "output_width"; static constexpr const char* Pads = "pads"; static constexpr const char* PooledShape = "pooled_shape"; + static constexpr const char* Reduction = "reduction"; static constexpr const char* Reverse = "reverse"; static constexpr const char* SampleSize = "sample_size"; static constexpr const char* SamplingRatio = "sampling_ratio"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 29d70aa084..a673da6176 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -378,6 +378,7 @@ namespace OperatorHelper static const int sc_sinceVer_Where = 16; static const int sc_sinceVer_GreaterOrEqual = 16; static const int sc_sinceVer_LessOrEqual = 16; + static const int sc_sinceVer_ScatterND = 16; } // namespace OnnxOperatorSet16 namespace OnnxOperatorSet17