diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 979324d074..42c9be410b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -304,10 +304,12 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
-|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index d05801348b..f9d963eda0 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -765,8 +765,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, If
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where);
@@ -830,6 +830,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
#if !defined(DISABLE_OPTIONAL_TYPE)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
@@ -2018,8 +2020,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
RoiAlign)>,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2128,11 +2130,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceSumSquare)>,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#if !defined(DISABLE_OPTIONAL_TYPE)
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-#endif
+#endif
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc
index 7e5a304445..2d932dfe59 100644
--- a/onnxruntime/core/providers/cpu/tensor/scatter.cc
+++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -86,9 +86,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
.TypeConstraint("Tind", BuildKernelDefConstraints()),
Scatter);
-ONNX_CPU_OPERATOR_KERNEL(
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ScatterElements,
16,
+ 17,
+ KernelDefBuilder()
+ .MayInplace(0, 0)
+ .TypeConstraint("T",
+ BuildKernelDefConstraintsFromTypeList())
+ .TypeConstraint("Tind", BuildKernelDefConstraints()),
+ Scatter);
+
+ONNX_CPU_OPERATOR_KERNEL(
+ ScatterElements,
+ 18,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T",
@@ -166,6 +177,76 @@ struct Func_Mul {
}
};
+template
+struct Func_Min {
+ void operator()(T* a, const T* b) const {
+ (*a) = (*a) < (*b) ? (*a) : (*b);
+ }
+};
+
+template <>
+struct Func_Min {
+ void operator()(bool*, const bool*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min {
+ void operator()(std::string*, const std::string*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min {
+ void operator()(MLFloat16*, const MLFloat16*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min {
+ void operator()(BFloat16*, const BFloat16*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
+ }
+};
+
+template
+struct Func_Max {
+ void operator()(T* a, const T* b) const {
+ (*a) = (*a) > (*b) ? (*a) : (*b);
+ }
+};
+
+template <>
+struct Func_Max {
+ void operator()(bool*, const bool*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max {
+ void operator()(std::string*, const std::string*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max {
+ void operator()(MLFloat16*, const MLFloat16*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max {
+ void operator()(BFloat16*, const BFloat16*) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
+ }
+};
+
template
Status GetIndices(
const Tensor& data_input, const Tensor& indices_input, int64_t axis,
@@ -317,7 +398,13 @@ struct ScatterDataDispatchTarget {
else if(reduction == "mul")
return ScatterData(
Func_Mul(), data_input, indices_data, updates_input, axis, data_output);
- else // if (reduction == "none")
+ else if (reduction == "min")
+ return ScatterData(
+ Func_Min(), data_input, indices_data, updates_input, axis, data_output);
+ else if (reduction == "max")
+ return ScatterData(
+ Func_Max(), data_input, indices_data, updates_input, axis, data_output);
+ else // if (reduction == "none")
return ScatterData(
Func_Assignment(), data_input, indices_data, updates_input, axis, data_output);
}
diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc
index d7ede706bd..82120a775f 100644
--- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc
+++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc
@@ -38,9 +38,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
BuildKernelDefConstraintsFromTypeList()),
ScatterND);
-ONNX_CPU_OPERATOR_KERNEL(
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ScatterND,
16,
+ 17,
+ KernelDefBuilder()
+ .TypeConstraint("T",
+ BuildKernelDefConstraintsFromTypeList()),
+ ScatterND);
+
+ONNX_CPU_OPERATOR_KERNEL(
+ ScatterND,
+ 18,
KernelDefBuilder()
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList()),
@@ -263,6 +272,84 @@ struct Func_Mul_ND {
}
};
+template
+struct Func_Min_ND {
+ void operator()(T* a, const T* b, uint64_t element_to_copy) const {
+ while (element_to_copy-- > 0) {
+ (*a) = (*a) < (*b) ? (*a) : (*b);
+ a++;
+ b++;
+ }
+ }
+};
+
+template <>
+struct Func_Min_ND {
+ void operator()(bool*, const bool*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min_ND {
+ void operator()(std::string*, const std::string*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min_ND {
+ void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
+ }
+};
+
+template <>
+struct Func_Min_ND {
+ void operator()(BFloat16*, const BFloat16*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
+ }
+};
+
+template
+struct Func_Max_ND {
+ void operator()(T* a, const T* b, uint64_t element_to_copy) const {
+ while (element_to_copy-- > 0) {
+ (*a) = (*a) > (*b) ? (*a) : (*b);
+ a++;
+ b++;
+ }
+ }
+};
+
+template <>
+struct Func_Max_ND {
+ void operator()(bool*, const bool*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max_ND {
+ void operator()(std::string*, const std::string*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max_ND {
+ void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
+ }
+};
+
+template <>
+struct Func_Max_ND {
+ void operator()(BFloat16*, const BFloat16*, uint64_t) const {
+ ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
+ }
+};
+
template
struct ScatterNDDispatchTarget {
Status operator()(OpKernelContext* context, concurrency::ThreadPool* tp, ScatterND::Reduction reduction) const {
@@ -285,6 +372,20 @@ struct ScatterNDDispatchTarget {
prepare.input_base + i * prepare.element_to_copy,
prepare.element_to_copy);
} break;
+ case ScatterND::Reduction::Min: {
+ auto func = Func_Min_ND();
+ func(
+ prepare.output_base + prepare.element_offsets[onnxruntime::narrow(i)],
+ prepare.input_base + i * prepare.element_to_copy,
+ prepare.element_to_copy);
+ } break;
+ case ScatterND::Reduction::Max: {
+ auto func = Func_Max_ND();
+ func(
+ prepare.output_base + prepare.element_offsets[onnxruntime::narrow(i)],
+ prepare.input_base + i * prepare.element_to_copy,
+ prepare.element_to_copy);
+ } break;
default:
case ScatterND::Reduction::None: {
auto func = Func_Copy_ND();
diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h
index 8ccdcb7d93..263f41d197 100644
--- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h
+++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h
@@ -18,7 +18,9 @@ class ScatterND final : public OpKernel {
enum class Reduction : int {
None = 0,
Add,
- Mul
+ Mul,
+ Min,
+ Max,
};
explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) {
@@ -30,6 +32,10 @@ class ScatterND final : public OpKernel {
reduction_ = Reduction::Add;
else if (reduction == "mul")
reduction_ = Reduction::Mul;
+ else if (reduction == "min")
+ reduction_ = Reduction::Min;
+ else if (reduction == "max")
+ reduction_ = Reduction::Max;
}
}
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index d91511e7d9..c91b616aa5 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -104,8 +104,6 @@
"^test_identity_opt",
// Following tests are for opset 16 ops and are not yet implemented in ORT
"^test_roialign_aligned_*",
- "^test_scatternd_*",
- "^test_scatter_elements_with_duplicate_indices",
//GPU failures
"^test_batchnorm_epsilon_training_mode_cuda",
"^test_batchnorm_example_training_mode_cuda",
@@ -127,7 +125,6 @@
"^test_constant_pad_cpu",
"^test_edge_pad_cpu",
"^test_reflect_pad_cpu",
- "^test_scatter_elements_*",
"^test_softplus_example_expanded_cpu",
"^test_softplus_expanded_cpu",
"^test_split_*",
@@ -312,6 +309,9 @@
"^test_reduce_sum_do_not_keepdims_example",
"^test_reduce_sum_do_not_keepdims_random",
"^test_scatter_elements_with_negative_indices",
+ "^test_scatter_elements_with_duplicate_indices*",
+ "^test_scatternd_add*",
+ "^test_scatternd_multiply*",
"^test_squeeze",
"^test_squeeze_negative_axes"
],
@@ -332,6 +332,9 @@
"^test_reduce_sum_default_axes_keepdims*",
"^test_reduce_sum_empty_axes_input_noop*",
"^test_scatter_elements_with_negative_indices_cpu",
+ "^test_scatter_elements_with_duplicate_indices*",
+ "^test_scatternd_add*",
+ "^test_scatternd_multiply*",
"^test_squeeze_*", // Does not support axes as input
"^test_pow_types_float", // Runs disabled pow tests from the "current_failing_tests" list at the top
"^test_loop11",