From 0c91b643fed7f6be75ec2d7c75e9c4a26738e86f Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 7 May 2021 14:02:26 +0800 Subject: [PATCH] Bugfix for Scatter and GatherElementsGrad (#7593) * bugfix for scatter and gather elements grad * resolve comments --- .../core/providers/cpu/tensor/gather_elements.cc | 2 +- onnxruntime/core/providers/cpu/tensor/scatter.cc | 4 +++- .../providers/cuda/tensor/scatter_elements.cc | 4 +++- .../test/providers/cpu/tensor/scatter_op_test.cc | 16 ++++++++++++++++ .../test/gradient/gradient_ops_test.cc | 14 ++++++++++++++ .../cuda/gather_elements_grad_test.cc | 11 +++++++++++ .../cpu/tensor/gather_elements_grad.cc | 4 +++- .../cuda/tensor/gather_elements_grad.cc | 4 +++- 8 files changed, 54 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/gather_elements.cc b/onnxruntime/core/providers/cpu/tensor/gather_elements.cc index 73782d7afd..b09148440d 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_elements.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_elements.cc @@ -261,7 +261,7 @@ Status GatherElements::ValidateInputShapes(const TensorShape& input_data_shape, for (int64_t i = 0; i < indices_rank; ++i) { // for all axes except the axis of interest, // make sure that the corresponding 'indices' shape - // value if within bounds of the corresponding 'data' shape + // value is within bounds of the corresponding 'data' shape if (i != axis) { if (indices_shape[i] < 0 || indices_shape[i] > input_data_shape[i]) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 5814eaa5f1..8ab616caf1 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -279,7 +279,9 @@ Status Scatter::Compute(OpKernelContext* context) const { } for (size_t i = 0; i < input_dims.size(); ++i) { - if (input_dims[i] < indices_dims[i]) { + // For 
all axes except the axis of interest, make sure that the corresponding 'indices' shape + // value is within bounds of the corresponding 'data' shape. + if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, " is greater than input dim=", input_dims[i]); } diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index 1536efe168..8142888402 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -143,7 +143,9 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { } for (size_t i = 0; i < input_dims.size(); ++i) { - if (input_dims[i] < indices_dims[i]) { + // For all axes except the axis of interest, make sure that the corresponding 'indices' shape + // value is within bounds of the corresponding 'data' shape. 
+ if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, " is greater than input dim=", input_dims[i]); } diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index cb95e77333..9b9c32074d 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -278,5 +278,21 @@ TEST(Scatter, SameUpdateWithoutAxis) { scatter_same_updates_tests("ScatterElements", 11); } +static void scatter_with_larger_indices_on_axis_tests(const char* op_name, int op_version) { + OpTester test(op_name, op_version); + test.AddAttribute<int64_t>("axis", 1); + + test.AddInput<float>("data", {1, 2}, {1.0f, 2.0f}); + test.AddInput<int64_t>("indices", {1, 4}, {0, 0, 0, 0}); + test.AddInput<float>("updates", {1, 4}, {3.0f, 3.0f, 3.0f, 3.0f}); + test.AddOutput<float>("y", {1, 2}, {3.0f, 2.0f}); + test.Run(); +} + +TEST(Scatter, LargerIndicesOnAxis) { + scatter_with_larger_indices_on_axis_tests("Scatter", 9); + scatter_with_larger_indices_on_axis_tests("ScatterElements", 11); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 6bbf6fa28a..fa59607cff 100755 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -2385,6 +2385,20 @@ TEST(GradientCheckerTest, GatherElementsGrad) { {MakeAttribute("axis", axis)}); EXPECT_IS_TINY(max_error); } + + { + // GatherElementsGradWithLargerIndiceOnAxis + TensorInfo data_info({2, 2}, true); + TensorInfo indice_info({2, 4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>()); + std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4}, {1, 1, 1, 1, 1, 1, 1, 1}}; + + TensorInfo y_info({2, 4}, true); + int64_t axis = 1; + + 
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axis", axis)}); + EXPECT_IS_TINY(max_error); + } } TEST(GradientCheckerTest, TopKGrad) { diff --git a/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc index 88a41ca884..c949badfb8 100644 --- a/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/gather_elements_grad_test.cc @@ -222,6 +222,17 @@ TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) { test.Run(); } +TEST(GatherElementsGrad, LargerIndicesOnAxis) { + onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain); + test.AddAttribute<int64_t>("axis", 1); + test.AddInput<float>("dY", {1, 4}, {1.1f, 2.2f, 3.3f, 4.4f}); + std::vector<int64_t> data_shape = {1, 2}; + test.AddInput<int64_t>("data_shape", {2}, data_shape); + test.AddInput<int64_t>("indices", {1, 4}, {0, 1, 0, 1}); + test.AddOutput<float>("dX", {1, 2}, {4.4f, 6.6f}); + test.Run(); +} + } // namespace test } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.cc b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.cc index d4dc79de8e..bebe7627a8 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.cc @@ -64,7 +64,9 @@ Status GatherElementsGrad::Compute(OpKernelContext* context) const { } for (size_t i = 0; i < output_dims.size(); ++i) { - if (output_dims[i] < indices_dims[i]) { + // For all axes except the axis of interest, make sure that the corresponding 'indices' shape + // value is within bounds of the corresponding 'data' shape. 
+ if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, " is greater than Output dim=", output_dims[i]); } diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc index dae18fc35f..2a7c32ea1d 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_elements_grad.cc @@ -108,7 +108,9 @@ Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const { } for (size_t i = 0; i < output_dims.size(); ++i) { - if (output_dims[i] < indices_dims[i]) { + // For all axes except the axis of interest, make sure that the corresponding 'indices' shape + // value is within bounds of the corresponding 'data' shape. + if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i, " is greater than Output dim=", output_dims[i]); }