ScatterND supports negative indices (#9739)

* ScatterND supports negative indices
This commit is contained in:
Sherlock 2021-11-30 21:17:32 -08:00 committed by GitHub
parent c161813217
commit 175acf08f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 81 additions and 11 deletions

View file

@ -131,7 +131,6 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
element_counts[i] = input_strides[i];
}
int64_t err_indice = 0;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
@ -150,13 +149,23 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
for (int64_t i = 0; i < offset_count; ++i) {
for (int64_t j = 0; j < last_indice_dimension; ++j) {
auto indice = *(indice_offset + i * last_indice_dimension + j);
if (indice < 0 || indice >= input_shape[j]) {
err_indice = indice;
if (indice >= 0) {
if (indice >= input_shape[j]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
}
} else {
if (indice < -input_shape[j]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
} else {
indice += input_shape[j];
}
}
p.element_offsets[i] += indice * element_counts[j];
}
}
return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
return Status::OK();
}
Status ScatterND::Compute(OpKernelContext* context) const {

View file

@ -34,11 +34,19 @@ __global__ void _ScatterNDKernel(
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
// is hard. This is the approach taken by other frameworks for out of bound indices
// in their corresponding GPU backends as well.
if (index < 0)
index = 0;
// index >= -dim_value && index < dim_value
else if (index >= dim_value)
index = dim_value - 1;
if (index >= 0) {
if (index >= dim_value) {
index = dim_value - 1;
}
} else {
if (index < -dim_value) {
index = 0;
} else {
index += dim_value;
}
}
data_offset += (index * element_count_dim);
}

View file

@ -39,6 +39,15 @@ TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64) {
test.Run();
}
// ScatterND (opset 11) on a 2x2 int64 matrix where one index tuple is negative.
TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64_neg_indices) {
  OpTester tester("ScatterND", 11);

  tester.AddInput<int64_t>("data", {2, 2}, {1LL, 1LL, 2LL, 2LL});
  // Two full-rank index tuples: (0, 0) and (-1, -1).
  tester.AddInput<int64_t>("indices", {2, 2}, {0LL, 0LL, -1LL, -1LL});
  tester.AddInput<int64_t>("updates", {2}, {0LL, 3LL});

  // The expected output has update 0 in the first element and update 3 in the
  // last element, i.e. (-1, -1) is expected to address element [1][1].
  tester.AddOutput<int64_t>("output", {2, 2}, {0LL, 1LL, 2LL, 3LL});

  tester.Run();
}
TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
OpTester test1("ScatterND", 11);
test1.AddInput<std::string>("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"});
@ -55,6 +64,22 @@ TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
test2.Run();
}
// ScatterND (opset 11) on string tensors with negative index components.
TEST(ScatterNDOpTest, ScatterND_matrice_string_int64_neg_indices) {
  // Case 1: 2x2x2 input with two-component index tuples, so each update is a
  // slice of two strings. The tuples are (0, -1) and (-1, 0); the expected
  // output has the slices at [0][1] and [1][0] replaced.
  OpTester slice_case("ScatterND", 11);
  slice_case.AddInput<std::string>("data", {2, 2, 2},
                                   {"egg", "dance", "bob", "air", "smart", "terry", "laugh", "kite"});
  slice_case.AddInput<int64_t>("indices", {2, 1, 2}, {0, -1, -1, 0});
  slice_case.AddInput<std::string>("updates", {2, 1, 2}, {"air", "bob", "terry", "smart"});
  slice_case.AddOutput<std::string>("output", {2, 2, 2},
                                    {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"});
  slice_case.Run();

  // Case 2: 3x3 input with full-rank index tuples, so each update is a single
  // string. The tuples are (-1, -2), (1, 0) and (0, -2); the expected output
  // has elements [2][1], [1][0] and [0][1] replaced.
  OpTester element_case("ScatterND", 11);
  element_case.AddInput<std::string>("data", {3, 3},
                                     {"egg", "", "air", "", "terry", "smart", "laugh", "", "hop"});
  element_case.AddInput<int64_t>("indices", {3, 2}, {-1, -2, 1, 0, 0, -2});
  element_case.AddInput<std::string>("updates", {3}, {"kite", "bob", "dance"});
  element_case.AddOutput<std::string>("output", {3, 3},
                                      {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"});
  element_case.Run();
}
TEST(ScatterNDOpTest, ScatterND_slice_float_int64_t) {
OpTester test("ScatterND", 11);
test.AddInput<float>("data", {2,2}, {0.0f,0.1f,0.1f,0.1f});
@ -76,14 +101,14 @@ TEST(ScatterNDOpTest, ScatterND_slice_double_int64_t) {
TEST(ScatterNDOpTest, ScatterND_3tensor_int64) {
OpTester test1("ScatterND", 11);
test1.AddInput<int64_t>("data", {2,2,2}, {0LL,1LL,1LL,1LL,1LL,1LL,6LL,7LL});
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,1LL,0LL});
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,-1LL,0LL});
test1.AddInput<int64_t>("updates", {2,2}, {2LL,3LL,4LL,5LL});
test1.AddOutput<int64_t>("output", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL});
test1.Run();
OpTester test2("ScatterND", 11);
test2.AddInput<int8_t>("data", {2,2,2}, {0,0,2,3,4,0,6,7});
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,1,0,1});
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,-1,0,-1});
test2.AddInput<int8_t>("updates", {2}, {1,5});
test2.AddOutput<int8_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.Run();
@ -142,7 +167,7 @@ TEST(ScatterNDOpTest, ScatterND_batched_3tensor_int64) {
OpTester test2("ScatterND", 11);
test2.AddInput<uint32_t>("data", {2,2,2}, {0,0,2,0,4,0,0,7});
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0});
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,-1,-1,0,-1,0,1,-1,1,-1,0});
test2.AddInput<uint32_t>("updates", {2,2}, {1,5,3,6});
test2.AddOutput<uint32_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.Run();

View file

@ -658,6 +658,34 @@ def test_gradient_correctness():
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
                                     [[ 2, 3, 4, 4],[ 0, 1, 4, 4]]))
def test_scatternd_correctness(device, indices):
    """Compare ORTModule against PyTorch for an indexed assignment.

    The destination tensor has 5 rows, so the two parametrized index sets
    address the same rows: once with -1 for the last row and once with 4.
    NOTE(review): the indexed assignment is presumably exported as an ONNX
    ScatterND node -- confirm against the exported graph.

    Args:
        device: torch device string on which the models and tensors live.
        indices: nested list of row indices used as the assignment index.
    """
    class NeuralNetScatterND(torch.nn.Module):
        def __init__(self):
            super(NeuralNetScatterND, self).__init__()

        def forward(self, rerouted_output, dispatch_mask, expert_output):
            # In-place indexed assignment: writes into and returns the first argument.
            rerouted_output[dispatch_mask] = expert_output
            return rerouted_output

    pt_model = NeuralNetScatterND().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, rerouted_output, dispatch_mask, expert_output):
        prediction = model(rerouted_output, dispatch_mask, expert_output)
        return prediction

    rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device)
    dispatch_mask = torch.tensor(indices, device=device)
    expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device)

    # forward() mutates its first argument. Hand each model its own copy so the
    # ORT run starts from the same all-zero tensor as the PyTorch run, rather
    # than from a tensor the PyTorch run has already filled with the expected values.
    pt_prediction = run_step(pt_model, rerouted_output.clone(), dispatch_mask, expert_output)
    ort_prediction = run_step(ort_model, rerouted_output.clone(), dispatch_mask, expert_output)
    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)
@pytest.mark.parametrize("use_fp16", [False, True])
@pytest.mark.parametrize("input_requires_grad", [False, True])
def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):