diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 979324d074..42c9be410b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -304,10 +304,12 @@ Do not modify directly.* |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d05801348b..f9d963eda0 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -765,8 +765,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, If class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where); @@ -830,6 +830,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split); #if !defined(DISABLE_OPTIONAL_TYPE) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement); @@ -2018,8 +2020,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { RoiAlign)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2128,11 +2130,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceSumSquare)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_OPTIONAL_TYPE) BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 7e5a304445..2d932dfe59 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -86,9 +86,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( .TypeConstraint("Tind", BuildKernelDefConstraints()), Scatter); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ScatterElements, 16, + 17, + KernelDefBuilder() + .MayInplace(0, 0) + .TypeConstraint("T", + BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraints()), + Scatter); + +ONNX_CPU_OPERATOR_KERNEL( + ScatterElements, + 18, KernelDefBuilder() .MayInplace(0, 0) .TypeConstraint("T", @@ -166,6 +177,76 @@ struct Func_Mul { } }; +template +struct Func_Min { + void operator()(T* a, const T* b) const { + (*a) = (*a) < (*b) ? (*a) : (*b); + } +}; + +template <> +struct Func_Min { + void operator()(bool*, const bool*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min { + void operator()(std::string*, const std::string*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min { + void operator()(MLFloat16*, const MLFloat16*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min { + void operator()(BFloat16*, const BFloat16*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); + } +}; + +template +struct Func_Max { + void operator()(T* a, const T* b) const { + (*a) = (*a) > (*b) ? (*a) : (*b); + } +}; + +template <> +struct Func_Max { + void operator()(bool*, const bool*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max { + void operator()(std::string*, const std::string*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max { + void operator()(MLFloat16*, const MLFloat16*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max { + void operator()(BFloat16*, const BFloat16*) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); + } +}; + template Status GetIndices( const Tensor& data_input, const Tensor& indices_input, int64_t axis, @@ -317,7 +398,13 @@ struct ScatterDataDispatchTarget { else if(reduction == "mul") return ScatterData( Func_Mul(), data_input, indices_data, updates_input, axis, data_output); - else // if (reduction == "none") + else if (reduction == "min") + return ScatterData( + Func_Min(), data_input, indices_data, updates_input, axis, data_output); + else if (reduction == "max") + return ScatterData( + Func_Max(), data_input, indices_data, updates_input, axis, data_output); + else // if (reduction == "none") return ScatterData( Func_Assignment(), data_input, indices_data, updates_input, axis, data_output); } diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc index d7ede706bd..82120a775f 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc @@ -38,9 +38,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( BuildKernelDefConstraintsFromTypeList()), ScatterND); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ScatterND, 16, + 17, + KernelDefBuilder() + .TypeConstraint("T", + BuildKernelDefConstraintsFromTypeList()), + ScatterND); + +ONNX_CPU_OPERATOR_KERNEL( + ScatterND, + 18, KernelDefBuilder() .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), @@ -263,6 +272,84 @@ struct Func_Mul_ND { } }; +template +struct Func_Min_ND { + void operator()(T* a, const T* b, uint64_t element_to_copy) const { + while (element_to_copy-- > 0) { + (*a) = (*a) < (*b) ? (*a) : (*b); + a++; + b++; + } + } +}; + +template <> +struct Func_Min_ND { + void operator()(bool*, const bool*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min_ND { + void operator()(std::string*, const std::string*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min_ND { + void operator()(MLFloat16*, const MLFloat16*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'."); + } +}; + +template <> +struct Func_Min_ND { + void operator()(BFloat16*, const BFloat16*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'."); + } +}; + +template +struct Func_Max_ND { + void operator()(T* a, const T* b, uint64_t element_to_copy) const { + while (element_to_copy-- > 0) { + (*a) = (*a) > (*b) ? (*a) : (*b); + a++; + b++; + } + } +}; + +template <> +struct Func_Max_ND { + void operator()(bool*, const bool*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max_ND { + void operator()(std::string*, const std::string*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max_ND { + void operator()(MLFloat16*, const MLFloat16*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'."); + } +}; + +template <> +struct Func_Max_ND { + void operator()(BFloat16*, const BFloat16*, uint64_t) const { + ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'."); + } +}; + template struct ScatterNDDispatchTarget { Status operator()(OpKernelContext* context, concurrency::ThreadPool* tp, ScatterND::Reduction reduction) const { @@ -285,6 +372,20 @@ struct ScatterNDDispatchTarget { prepare.input_base + i * prepare.element_to_copy, prepare.element_to_copy); } break; + case ScatterND::Reduction::Min: { + auto func = Func_Min_ND(); + func( + prepare.output_base + prepare.element_offsets[onnxruntime::narrow(i)], + prepare.input_base + i * prepare.element_to_copy, + prepare.element_to_copy); + } break; + case ScatterND::Reduction::Max: { + auto func = Func_Max_ND(); + func( + prepare.output_base + prepare.element_offsets[onnxruntime::narrow(i)], + prepare.input_base + i * prepare.element_to_copy, + prepare.element_to_copy); + } break; default: case ScatterND::Reduction::None: { auto func = Func_Copy_ND(); diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h index 8ccdcb7d93..263f41d197 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h @@ -18,7 +18,9 @@ class ScatterND final : public OpKernel { enum class Reduction : int { None = 0, Add, - Mul + Mul, + Min, + Max, }; explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) { @@ -30,6 +32,10 @@ class ScatterND final : public OpKernel { reduction_ = Reduction::Add; else if (reduction == "mul") reduction_ = Reduction::Mul; + else if (reduction == "min") + reduction_ = Reduction::Min; + else if (reduction == "max") + reduction_ = Reduction::Max; } } diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index d91511e7d9..c91b616aa5 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -104,8 +104,6 @@ "^test_identity_opt", // Following tests are for opset 16 ops and are not yet implemented in ORT "^test_roialign_aligned_*", - "^test_scatternd_*", - "^test_scatter_elements_with_duplicate_indices", //GPU failures "^test_batchnorm_epsilon_training_mode_cuda", "^test_batchnorm_example_training_mode_cuda", @@ -127,7 +125,6 @@ "^test_constant_pad_cpu", "^test_edge_pad_cpu", "^test_reflect_pad_cpu", - "^test_scatter_elements_*", "^test_softplus_example_expanded_cpu", "^test_softplus_expanded_cpu", "^test_split_*", @@ -312,6 +309,9 @@ "^test_reduce_sum_do_not_keepdims_example", "^test_reduce_sum_do_not_keepdims_random", "^test_scatter_elements_with_negative_indices", + "^test_scatter_elements_with_duplicate_indices*", + "^test_scatternd_add*", + "^test_scatternd_multiply*", "^test_squeeze", "^test_squeeze_negative_axes" ], @@ -332,6 +332,9 @@ "^test_reduce_sum_default_axes_keepdims*", "^test_reduce_sum_empty_axes_input_noop*", "^test_scatter_elements_with_negative_indices_cpu", + "^test_scatter_elements_with_duplicate_indices*", + "^test_scatternd_add*", + "^test_scatternd_multiply*", "^test_squeeze_*", // Does not support axes as input "^test_pow_types_float", // Runs disabled pow tests from the "current_failing_tests" list at the top "^test_loop11",