Support ScatterND(18) and ScatterElements(18) (#14224)

This commit is contained in:
liqun Fu 2023-01-19 13:54:20 -08:00 committed by GitHub
parent d2c3d8eb38
commit 5d6a049141
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 218 additions and 15 deletions

View file

@ -304,10 +304,12 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float)|

View file

@ -765,8 +765,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, If
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where);
@ -830,6 +830,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split);
#if !defined(DISABLE_OPTIONAL_TYPE)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement);
@ -2018,8 +2020,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float,
GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, Where)>,
@ -2128,11 +2130,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Split)>,
#if !defined(DISABLE_OPTIONAL_TYPE)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalHasElement)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, OptionalGetElement)>,
#endif
#endif
};
for (auto& function_table_entry : function_table) {

View file

@ -86,9 +86,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
Scatter<EnabledScatterElementsDataTypes>);
ONNX_CPU_OPERATOR_KERNEL(
// CPU kernel registration entry for the ScatterElements operator, passed the
// version arguments 16 and 17 (presumably the inclusive opset range this entry
// serves -- confirm against the ONNX_CPU_OPERATOR_VERSIONED_KERNEL definition).
// The kernel definition requests MayInplace(0, 0), constrains "T" to the types in
// EnabledScatterElementsDataTypes and "Tind" to int32_t / int64_t, and binds the
// entry to the Scatter<EnabledScatterElementsDataTypes> kernel class.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ScatterElements,
16,
17,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<EnabledScatterElementsDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraints<int32_t, int64_t>()),
Scatter<EnabledScatterElementsDataTypes>);
ONNX_CPU_OPERATOR_KERNEL(
ScatterElements,
18,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T",
@ -166,6 +177,76 @@ struct Func_Mul<BFloat16> {
}
};
// Element-wise 'min' reduction functor used when scattering a single element.
// After the call, *target holds *target if it compared strictly less than *update,
// otherwise it holds *update (so ties and unordered comparisons take the update value).
template <class T>
struct Func_Min {
  void operator()(T* target, const T* update) const {
    // Only overwrite when the current value is NOT strictly smaller than the update.
    if (!((*target) < (*update))) {
      (*target) = (*update);
    }
  }
};
// Explicit specialization of Func_Min for bool.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min<bool> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(bool*, const bool*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min for std::string.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min<std::string> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(std::string*, const std::string*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min for MLFloat16.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min<MLFloat16> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min for BFloat16.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min<BFloat16> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(BFloat16*, const BFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};
// Element-wise 'max' reduction functor used when scattering a single element.
// After the call, *target holds *target if it compared strictly greater than *update,
// otherwise it holds *update (so ties and unordered comparisons take the update value).
template <class T>
struct Func_Max {
  void operator()(T* target, const T* update) const {
    // Only overwrite when the current value is NOT strictly larger than the update.
    if (!((*target) > (*update))) {
      (*target) = (*update);
    }
  }
};
// Explicit specialization of Func_Max for bool.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max<bool> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(bool*, const bool*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max for std::string.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max<std::string> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(std::string*, const std::string*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max for MLFloat16.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max<MLFloat16> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max for BFloat16.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max<BFloat16> {
// Both pointer parameters are left unnamed because the body never reads them.
void operator()(BFloat16*, const BFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};
template <class TIndex>
Status GetIndices(
const Tensor& data_input, const Tensor& indices_input, int64_t axis,
@ -317,7 +398,13 @@ struct ScatterDataDispatchTarget {
else if(reduction == "mul")
return ScatterData<TData>(
Func_Mul<TData>(), data_input, indices_data, updates_input, axis, data_output);
else // if (reduction == "none")
else if (reduction == "min")
return ScatterData<TData>(
Func_Min<TData>(), data_input, indices_data, updates_input, axis, data_output);
else if (reduction == "max")
return ScatterData<TData>(
Func_Max<TData>(), data_input, indices_data, updates_input, axis, data_output);
else // if (reduction == "none")
return ScatterData<TData>(
Func_Assignment<TData>(), data_input, indices_data, updates_input, axis, data_output);
}

View file

@ -38,9 +38,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
ScatterND);
ONNX_CPU_OPERATOR_KERNEL(
// CPU kernel registration entry for the ScatterND operator, passed the version
// arguments 16 and 17 (presumably the inclusive opset range this entry serves --
// confirm against the ONNX_CPU_OPERATOR_VERSIONED_KERNEL definition).
// The kernel definition constrains "T" to the types in EnabledScatterNDDataTypes
// and binds the entry to the ScatterND kernel class.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ScatterND,
16,
17,
KernelDefBuilder()
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
ScatterND);
ONNX_CPU_OPERATOR_KERNEL(
ScatterND,
18,
KernelDefBuilder()
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<EnabledScatterNDDataTypes>()),
@ -263,6 +272,84 @@ struct Func_Mul_ND<BFloat16> {
}
};
// Slice-wise 'min' reduction functor for ScatterND.
// For each of the first `element_to_copy` positions, dest[i] keeps its value only if it
// compared strictly less than src[i]; otherwise it is overwritten with src[i].
// A count of zero leaves the destination untouched.
template <class T>
struct Func_Min_ND {
  void operator()(T* dest, const T* src, uint64_t element_to_copy) const {
    for (uint64_t pos = 0; pos < element_to_copy; ++pos) {
      T& current = dest[pos];
      const T& candidate = src[pos];
      if (!(current < candidate)) {
        current = candidate;
      }
    }
  }
};
// Explicit specialization of Func_Min_ND for bool.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min_ND<bool> {
// All parameters are left unnamed because the body never reads them.
void operator()(bool*, const bool*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min_ND for std::string.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min_ND<std::string> {
// All parameters are left unnamed because the body never reads them.
void operator()(std::string*, const std::string*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min_ND for MLFloat16.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min_ND<MLFloat16> {
// All parameters are left unnamed because the body never reads them.
void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
}
};
// Explicit specialization of Func_Min_ND for BFloat16.
// No 'min' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Min_ND<BFloat16> {
// All parameters are left unnamed because the body never reads them.
void operator()(BFloat16*, const BFloat16*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'min'.");
}
};
// Slice-wise 'max' reduction functor for ScatterND.
// For each of the first `element_to_copy` positions, dest[i] keeps its value only if it
// compared strictly greater than src[i]; otherwise it is overwritten with src[i].
// A count of zero leaves the destination untouched.
template <class T>
struct Func_Max_ND {
  void operator()(T* dest, const T* src, uint64_t element_to_copy) const {
    for (uint64_t pos = 0; pos < element_to_copy; ++pos) {
      T& current = dest[pos];
      const T& candidate = src[pos];
      if (!(current > candidate)) {
        current = candidate;
      }
    }
  }
};
// Explicit specialization of Func_Max_ND for bool.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max_ND<bool> {
// All parameters are left unnamed because the body never reads them.
void operator()(bool*, const bool*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: bool data type is not supported with ScatterND opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max_ND for std::string.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max_ND<std::string> {
// All parameters are left unnamed because the body never reads them.
void operator()(std::string*, const std::string*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: string data type is not supported with ScatterND opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max_ND for MLFloat16.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max_ND<MLFloat16> {
// All parameters are left unnamed because the body never reads them.
void operator()(MLFloat16*, const MLFloat16*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
}
};
// Explicit specialization of Func_Max_ND for BFloat16.
// No 'max' reduction is provided for this element type: calling the functor
// invokes ORT_NOT_IMPLEMENTED with a message naming the type and the reduction.
template <>
struct Func_Max_ND<BFloat16> {
// All parameters are left unnamed because the body never reads them.
void operator()(BFloat16*, const BFloat16*, uint64_t) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: BFloat16 data type is not supported with ScatterND opset 18 when reduction is 'max'.");
}
};
template <typename TData>
struct ScatterNDDispatchTarget {
Status operator()(OpKernelContext* context, concurrency::ThreadPool* tp, ScatterND::Reduction reduction) const {
@ -285,6 +372,20 @@ struct ScatterNDDispatchTarget {
prepare.input_base + i * prepare.element_to_copy,
prepare.element_to_copy);
} break;
case ScatterND::Reduction::Min: {
auto func = Func_Min_ND<TData>();
func(
prepare.output_base + prepare.element_offsets[onnxruntime::narrow<size_t>(i)],
prepare.input_base + i * prepare.element_to_copy,
prepare.element_to_copy);
} break;
case ScatterND::Reduction::Max: {
auto func = Func_Max_ND<TData>();
func(
prepare.output_base + prepare.element_offsets[onnxruntime::narrow<size_t>(i)],
prepare.input_base + i * prepare.element_to_copy,
prepare.element_to_copy);
} break;
default:
case ScatterND::Reduction::None: {
auto func = Func_Copy_ND<TData>();

View file

@ -18,7 +18,9 @@ class ScatterND final : public OpKernel {
enum class Reduction : int {
None = 0,
Add,
Mul
Mul,
Min,
Max,
};
explicit ScatterND(const OpKernelInfo& info) : OpKernel(info) {
@ -30,6 +32,10 @@ class ScatterND final : public OpKernel {
reduction_ = Reduction::Add;
else if (reduction == "mul")
reduction_ = Reduction::Mul;
else if (reduction == "min")
reduction_ = Reduction::Min;
else if (reduction == "max")
reduction_ = Reduction::Max;
}
}

View file

@ -104,8 +104,6 @@
"^test_identity_opt",
// Following tests are for opset 16 ops and are not yet implemented in ORT
"^test_roialign_aligned_*",
"^test_scatternd_*",
"^test_scatter_elements_with_duplicate_indices",
//GPU failures
"^test_batchnorm_epsilon_training_mode_cuda",
"^test_batchnorm_example_training_mode_cuda",
@ -127,7 +125,6 @@
"^test_constant_pad_cpu",
"^test_edge_pad_cpu",
"^test_reflect_pad_cpu",
"^test_scatter_elements_*",
"^test_softplus_example_expanded_cpu",
"^test_softplus_expanded_cpu",
"^test_split_*",
@ -312,6 +309,9 @@
"^test_reduce_sum_do_not_keepdims_example",
"^test_reduce_sum_do_not_keepdims_random",
"^test_scatter_elements_with_negative_indices",
"^test_scatter_elements_with_duplicate_indices*",
"^test_scatternd_add*",
"^test_scatternd_multiply*",
"^test_squeeze",
"^test_squeeze_negative_axes"
],
@ -332,6 +332,9 @@
"^test_reduce_sum_default_axes_keepdims*",
"^test_reduce_sum_empty_axes_input_noop*",
"^test_scatter_elements_with_negative_indices_cpu",
"^test_scatter_elements_with_duplicate_indices*",
"^test_scatternd_add*",
"^test_scatternd_multiply*",
"^test_squeeze_*", // Does not support axes as input
"^test_pow_types_float", // Runs disabled pow tests from the "current_failing_tests" list at the top
"^test_loop11",