diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 419f7524be..b46e94b739 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1080,7 +1080,8 @@ Do not modify directly.*
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
-|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp
index 68b567e086..86598fce62 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp
@@ -108,6 +108,22 @@ public:
}
};
+void CALLBACK QueryScatter(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
+{
+ *isSupported = false;
+
+ MLOperatorAttributes attributes(context);
+
+ // DML does not support reduction.
+ std::string reduction = attributes.GetOptionalAttribute(AttrName::Reduction, "none");
+ if (reduction != "none")
+ {
+ return;
+ }
+
+ *isSupported = true;
+}
+
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 1be82b8cd5..480b59c47a 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -274,6 +274,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
DML_OP_EXTERN_QUERY_FUNCTION(Resize);
+DML_OP_EXTERN_QUERY_FUNCTION(Scatter);
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
DML_OP_EXTERN_QUERY_FUNCTION(RecurrentNeuralNetwork);
DML_OP_EXTERN_QUERY_FUNCTION(BatchNormalization);
@@ -488,6 +489,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
{REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
+ {REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)},
{REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
index 5224103793..fe59778d08 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
@@ -69,6 +69,7 @@ namespace AttrName
static constexpr const char* OutputWidth = "output_width";
static constexpr const char* Pads = "pads";
static constexpr const char* PooledShape = "pooled_shape";
+ static constexpr const char* Reduction = "reduction";
static constexpr const char* Reverse = "reverse";
static constexpr const char* SampleSize = "sample_size";
static constexpr const char* SamplingRatio = "sampling_ratio";
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 29d70aa084..a673da6176 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -378,6 +378,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Where = 16;
static const int sc_sinceVer_GreaterOrEqual = 16;
static const int sc_sinceVer_LessOrEqual = 16;
+ static const int sc_sinceVer_ScatterND = 16;
} // namespace OnnxOperatorSet16
namespace OnnxOperatorSet17