From d3cb2a5572b879441debcaac99da86d706f28869 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 24 Sep 2019 17:11:01 -0700 Subject: [PATCH] Add ScatterElements CPU kernel (#1796) * Support ScatterElements CPU kernel * Nits and remove test exclusion * PR feedback * Fix bug * Fix test * Remove unused variable * PR comments --- .../providers/cpu/cpu_execution_provider.cc | 6 +- .../core/providers/cpu/tensor/scatter.cc | 37 +++++-- onnxruntime/test/onnx/main.cc | 2 - .../providers/cpu/tensor/scatter_op_test.cc | 102 ++++++++++++++---- .../test/python/onnx_backend_test_series.py | 1 - 5 files changed, 114 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 43f6fadae4..b615d3ae27 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -251,7 +251,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asi class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scan); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scatter); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Scatter); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, TfIdfVectorizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, TfIdfVectorizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, TfIdfVectorizer); @@ -317,6 +317,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, So class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -553,7 +554,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -619,6 +620,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 403ef62ec3..7b96e80464 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -10,10 +10,11 @@ namespace onnxruntime { class Scatter final : public OpKernel { public: - Scatter(const OpKernelInfo& info) : OpKernel(info) { + explicit Scatter(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); } + ~Scatter() = default; Status Compute(OpKernelContext* context) const override; @@ -21,9 +22,18 @@ class Scatter final : public OpKernel { int64_t axis_; }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Scatter, - 9, + 9, 10, + KernelDefBuilder() + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + Scatter); + +ONNX_CPU_OPERATOR_KERNEL( + ScatterElements, + 11, KernelDefBuilder() .MayInplace(0, 0) .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) @@ -34,14 +44,25 @@ template Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input, const int64_t axis, Tensor* data_output) { const TensorShape& input_data_shape = data_input->Shape(); - const Tin* indices_data = indices_input->template Data(); + const Tin* indices_data_raw= indices_input->template Data(); const auto num_indices = indices_input->Shape().Size(); + + std::vector indices_data; + indices_data.reserve(num_indices); + + auto axis_dim_limit = input_data_shape[axis]; + for (int64_t i = 0; i < num_indices; ++i) { - Tin idx = indices_data[i]; - if (idx < 0 || idx >= input_data_shape[axis]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices element out of data bounds, idx=", idx, - " data_dim=", input_data_shape[axis]); + Tin idx = indices_data_raw[i]; + + if (idx < -axis_dim_limit || idx >= axis_dim_limit) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices element out of data bounds, idx=", idx, + " must be within the inclusive range [", -axis_dim_limit, + ",", axis_dim_limit - 1, "]"); } + + indices_data.push_back(idx < 0 ? idx + static_cast(axis_dim_limit) : idx); } const auto input_elements = input_data_shape.Size(); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 41cd6f4bf1..1262439acd 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -406,8 +406,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"unique_sorted_axis_3d", "Unique not implemented yet"}, {"unique_sorted_axis", "Unique not implemented yet"}, {"unique_sorted_with_negative_axis", "Unique not implemented yet"}, - {"scatter_elements_with_axis", "not implemented yet"}, - {"scatter_elements_without_axis", "not implemented yet"}, {"round", "not implemented yet"}, {"gather_elements_1", "not implemented yet"}, {"gather_elements_0", "not implemented yet"}, diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index fc9d72a07d..7dd0337074 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -7,10 +7,8 @@ namespace onnxruntime { namespace test { -const int Scatter_ver = 9; - -TEST(ScatterOpTest, WithoutAxis) { - OpTester test("Scatter", Scatter_ver); +static void scatter_without_axis_tests(const char* op_name, int op_version) { + OpTester test(op_name, op_version); std::vector input; input.resize(3 * 3); @@ -32,8 +30,13 @@ TEST(ScatterOpTest, WithoutAxis) { test.Run(); } -TEST(ScatterOpTest, WithAxis) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, WithoutAxis) { + scatter_without_axis_tests("Scatter", 9); + scatter_without_axis_tests("ScatterElements", 11); +} + +static void scatter_with_axis_tests(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 1); test.AddInput("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); @@ -43,8 +46,13 @@ TEST(ScatterOpTest, WithAxis) { test.Run(); } -TEST(ScatterOpTest, WithAxisThreeDims) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, WithAxis) { + scatter_with_axis_tests("Scatter", 9); + scatter_with_axis_tests("ScatterElements", 11); +} + +static void scatter_three_dim_with_axis_0(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 0); test.AddInput("data", {1, 3, 3}, @@ -67,8 +75,13 @@ TEST(ScatterOpTest, WithAxisThreeDims) { test.Run(); } -TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, ThreeDimsWithAxis_0) { + scatter_three_dim_with_axis_0("Scatter", 9); + scatter_three_dim_with_axis_0("ScatterElements", 11); +} + +static void scatter_three_dim_with_axis_2(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 2); test.AddInput("data", {1, 3, 3}, @@ -92,8 +105,13 @@ TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) { test.Run(); } -TEST(ScatterOpTest, WithAxisStrings) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, ThreeDimsWithAxis_2) { + scatter_three_dim_with_axis_2("Scatter", 9); + scatter_three_dim_with_axis_2("ScatterElements", 11); +} + +static void scatter_string(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 1); test.AddInput("data", {1, 5}, {"1.0f", "2.0f", "3.0f", "4.0f", "5.0f"}); @@ -103,8 +121,13 @@ TEST(ScatterOpTest, WithAxisStrings) { test.Run(); } -TEST(ScatterOpTest, NegativeAxis) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, String) { + scatter_string("Scatter", 9); + scatter_string("ScatterElements", 11); +} + +static void scatter_negative_axis(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", -1); test.AddInput("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); @@ -114,8 +137,13 @@ TEST(ScatterOpTest, NegativeAxis) { test.Run(); } -TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, NegativeAxis) { + scatter_negative_axis("Scatter", 9); + scatter_negative_axis("ScatterElements", 11); +} + +static void scatter_indices_updates_dont_match(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 1); test.AddInput("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); @@ -125,8 +153,13 @@ TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) { test.Run(OpTester::ExpectResult::kExpectFailure, "Indices vs updates dimensions differs at position=1 3 vs 2"); } -TEST(ScatterOpTest, ValidIndex) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, IndicesUpdatesDontMatch) { + scatter_indices_updates_dont_match("Scatter", 9); + scatter_indices_updates_dont_match("ScatterElements", 11); +} + +static void scatter_valid_index(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 0); test.AddInput("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); @@ -136,15 +169,42 @@ TEST(ScatterOpTest, ValidIndex) { test.Run(); } -TEST(ScatterOpTest, InvalidIndex) { - OpTester test("Scatter", Scatter_ver); +TEST(Scatter, ValidAxis) { + scatter_valid_index("Scatter", 9); + scatter_valid_index("ScatterElements", 11); +} + +static void scatter_invalid_index(const char* op_name, int op_version) { + OpTester test(op_name, op_version); test.AddAttribute("axis", 0); test.AddInput("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); test.AddInput("indices", {1, 1, 1}, {4}); test.AddInput("updates", {1, 1, 1}, {5.0f}); test.AddOutput("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); - test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 data_dim=4"); + test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]"); } + +TEST(Scatter, InvalidIndex) { + scatter_invalid_index("Scatter", 9); + scatter_invalid_index("ScatterElements", 11); +} + +static void scatter_valid_negative_index(const char* op_name, int op_version) { + OpTester test(op_name, op_version); + test.AddAttribute("axis", 0); + + test.AddInput("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddInput("indices", {1, 1, 1}, {-1}); + test.AddInput("updates", {1, 1, 1}, {5.0f}); + test.AddOutput("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); + test.Run(); +} + +TEST(Scatter, ValidNegativeIndex) { + scatter_valid_negative_index("Scatter", 9); + scatter_valid_negative_index("ScatterElements", 11); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index b923003ed6..7805a79bd1 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -115,7 +115,6 @@ def create_backend_test(testname=None): '^test_dynamicquantizelinear_max_adjusted_expanded*', '^test_dynamicquantizelinear_min_adjusted_expanded*', '^test_gather_elements*', - '^test_scatter_elements*', '^test_top_k*', '^test_unique_*', '^test_mod_float_mixed_sign_example_cpu.*', #onnxruntime::Mod::Compute fmod_ was false. fmod attribute must be true for float, float16 and double types