From 175acf08f470db0bb2e4b8eefe55cdeb87c8b132 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Tue, 30 Nov 2021 21:17:32 -0800 Subject: [PATCH] ScatterND supports negative indices (#9739) * ScatterND supports negative indices --- .../core/providers/cpu/tensor/scatter_nd.cc | 17 +++++++--- .../providers/cuda/tensor/scatter_nd_impl.cu | 16 +++++++--- .../cpu/tensor/scatter_nd_op_test.cc | 31 +++++++++++++++++-- .../python/orttraining_test_ortmodule_api.py | 28 +++++++++++++++++ 4 files changed, 81 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc index 2b3ee69561..b5a19fd44c 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc @@ -131,7 +131,6 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co element_counts[i] = input_strides[i]; } - int64_t err_indice = 0; p.element_bytes = input_tensor->DataType()->Size(); p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); p.bytes_to_copy = p.element_bytes * p.element_to_copy; @@ -150,13 +149,23 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co for (int64_t i = 0; i < offset_count; ++i) { for (int64_t j = 0; j < last_indice_dimension; ++j) { auto indice = *(indice_offset + i * last_indice_dimension + j); - if (indice < 0 || indice >= input_shape[j]) { - err_indice = indice; + + if (indice >= 0) { + if (indice >= input_shape[j]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice); + } + } else { + if (indice < -input_shape[j]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice); + } else { + indice += input_shape[j]; + } } + p.element_offsets[i] += indice * element_counts[j]; } } - return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); + return Status::OK(); } Status ScatterND::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu index 0651049a5f..e9199b5e1b 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu @@ -34,11 +34,19 @@ __global__ void _ScatterNDKernel( // This would have been an error in the CPU kernel, but throwing in the CUDA EP // is hard. This is the approach taken by other frameworks for out of bound indices // in their corresponding GPU backends as well. - if (index < 0) - index = 0; + // index >= -dim_value && index < dim_value - else if (index >= dim_value) - index = dim_value - 1; + if (index >= 0) { + if (index >= dim_value) { + index = dim_value - 1; + } + } else { + if (index < -dim_value) { + index = 0; + } else { + index += dim_value; + } + } data_offset += (index * element_count_dim); } diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index 145fbacdb3..a3d943ea78 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -39,6 +39,15 @@ TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64) { test.Run(); } +TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64_neg_indices) { + OpTester test("ScatterND", 11); + test.AddInput ("data", {2,2}, {1LL,1LL,2LL,2LL}); + test.AddInput ("indices", {2,2}, {0LL,0LL,-1LL,-1LL}); + test.AddInput("updates", {2}, {0LL,3LL}); + test.AddOutput("output", {2,2}, {0LL,1LL,2LL,3LL}); + test.Run(); +} + TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) { OpTester test1("ScatterND", 11); test1.AddInput("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"}); @@ -55,6 +64,22 @@ TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) { test2.Run(); } +TEST(ScatterNDOpTest, ScatterND_matrice_string_int64_neg_indices) { + OpTester test1("ScatterND", 11); + test1.AddInput("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"}); + test1.AddInput("indices", {2,1,2}, {0,-1,-1,0}); + test1.AddInput("updates", {2,1,2}, {"air","bob","terry","smart"}); + test1.AddOutput("output", {2,2,2}, {"egg","dance","air","bob","terry","smart","laugh","kite"}); + test1.Run(); + + OpTester test2("ScatterND", 11); + test2.AddInput("data", {3,3}, {"egg","","air","","terry","smart","laugh","","hop"}); + test2.AddInput("indices", {3,2}, {-1,-2,1,0,0,-2}); + test2.AddInput("updates", {3}, {"kite","bob","dance"}); + test2.AddOutput("output", {3,3}, {"egg","dance","air","bob","terry","smart","laugh","kite","hop"}); + test2.Run(); +} + TEST(ScatterNDOpTest, ScatterND_slice_float_int64_t) { OpTester test("ScatterND", 11); test.AddInput("data", {2,2}, {0.0f,0.1f,0.1f,0.1f}); @@ -76,14 +101,14 @@ TEST(ScatterNDOpTest, ScatterND_slice_double_int64_t) { TEST(ScatterNDOpTest, ScatterND_3tensor_int64) { OpTester test1("ScatterND", 11); test1.AddInput("data", {2,2,2}, {0LL,1LL,1LL,1LL,1LL,1LL,6LL,7LL}); - test1.AddInput("indices", {2,2}, {0LL,1LL,1LL,0LL}); + test1.AddInput("indices", {2,2}, {0LL,1LL,-1LL,0LL}); test1.AddInput("updates", {2,2}, {2LL,3LL,4LL,5LL}); test1.AddOutput("output", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL}); test1.Run(); OpTester test2("ScatterND", 11); test2.AddInput("data", {2,2,2}, {0,0,2,3,4,0,6,7}); - test2.AddInput("indices", {2,3}, {0,0,1,1,0,1}); + test2.AddInput("indices", {2,3}, {0,0,1,-1,0,-1}); test2.AddInput("updates", {2}, {1,5}); test2.AddOutput("output", {2,2,2}, {0,1,2,3,4,5,6,7}); test2.Run(); @@ -142,7 +167,7 @@ TEST(ScatterNDOpTest, ScatterND_batched_3tensor_int64) { OpTester test2("ScatterND", 11); test2.AddInput("data", {2,2,2}, {0,0,2,0,4,0,0,7}); - test2.AddInput("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0}); + test2.AddInput("indices", {2,2,3}, {0,0,-1,-1,0,-1,0,1,-1,1,-1,0}); test2.AddInput("updates", {2,2}, {1,5,3,6}); test2.AddOutput("output", {2,2,2}, {0,1,2,3,4,5,6,7}); test2.Run(); diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 121775d695..7bb0e675b6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -658,6 +658,34 @@ def test_gradient_correctness(): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) +@pytest.mark.parametrize("device", ['cpu', 'cuda']) +@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]], + [[ 2, 3, 4, 4],[ 0, 1, 4, 4]])) +def test_scatternd_correctness(device, indices): + class NeuralNetScatterND(torch.nn.Module): + def __init__(self): + super(NeuralNetScatterND, self).__init__() + + def forward(self, rerouted_output, dispatch_mask, expert_output): + rerouted_output[dispatch_mask] = expert_output + return rerouted_output + + pt_model = NeuralNetScatterND().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, rerouted_output, dispatch_mask, expert_output): + prediction = model(rerouted_output, dispatch_mask, expert_output) + return prediction + + rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device) + dispatch_mask = torch.tensor(indices, device=device) + expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device) + + pt_prediction = run_step(pt_model, rerouted_output, dispatch_mask, expert_output) + ort_prediction = run_step(ort_model, rerouted_output, dispatch_mask, expert_output) + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5) + + @pytest.mark.parametrize("use_fp16", [False, True]) @pytest.mark.parametrize("input_requires_grad", [False, True]) def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):