mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-14 01:13:40 +00:00
ScatterND supports negative indices (#9739)
* ScatterND supports negative indices
This commit is contained in:
parent
c161813217
commit
175acf08f4
4 changed files with 81 additions and 11 deletions
|
|
@ -131,7 +131,6 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
|
|||
element_counts[i] = input_strides[i];
|
||||
}
|
||||
|
||||
int64_t err_indice = 0;
|
||||
p.element_bytes = input_tensor->DataType()->Size();
|
||||
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
|
||||
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
|
||||
|
|
@ -150,13 +149,23 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
|
|||
for (int64_t i = 0; i < offset_count; ++i) {
|
||||
for (int64_t j = 0; j < last_indice_dimension; ++j) {
|
||||
auto indice = *(indice_offset + i * last_indice_dimension + j);
|
||||
if (indice < 0 || indice >= input_shape[j]) {
|
||||
err_indice = indice;
|
||||
|
||||
if (indice >= 0) {
|
||||
if (indice >= input_shape[j]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
|
||||
}
|
||||
} else {
|
||||
if (indice < -input_shape[j]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
|
||||
} else {
|
||||
indice += input_shape[j];
|
||||
}
|
||||
}
|
||||
|
||||
p.element_offsets[i] += indice * element_counts[j];
|
||||
}
|
||||
}
|
||||
return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScatterND::Compute(OpKernelContext* context) const {
|
||||
|
|
|
|||
|
|
@ -34,11 +34,19 @@ __global__ void _ScatterNDKernel(
|
|||
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
|
||||
// is hard. This is the approach taken by other frameworks for out of bound indices
|
||||
// in their corresponding GPU backends as well.
|
||||
if (index < 0)
|
||||
index = 0;
|
||||
// index >= -dim_value && index < dim_value
|
||||
|
||||
else if (index >= dim_value)
|
||||
index = dim_value - 1;
|
||||
if (index >= 0) {
|
||||
if (index >= dim_value) {
|
||||
index = dim_value - 1;
|
||||
}
|
||||
} else {
|
||||
if (index < -dim_value) {
|
||||
index = 0;
|
||||
} else {
|
||||
index += dim_value;
|
||||
}
|
||||
}
|
||||
|
||||
data_offset += (index * element_count_dim);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,15 @@ TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Scatters two scalars into a 2x2 int64 matrix, addressing the second target
// with negative indices on both axes.
TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64_neg_indices) {
  OpTester tester("ScatterND", 11);

  // 2x2 input:  [[1, 1],
  //              [2, 2]]
  tester.AddInput<int64_t>("data", {2, 2}, {1LL, 1LL, 2LL, 2LL});

  // Two index tuples: (0, 0) and (-1, -1). For this 2x2 input the expected
  // output below places the second update at element (1, 1).
  tester.AddInput<int64_t>("indices", {2, 2}, {0LL, 0LL, -1LL, -1LL});

  // One update value per index tuple.
  tester.AddInput<int64_t>("updates", {2}, {0LL, 3LL});

  // Expected: [[0, 1],
  //            [2, 3]]
  tester.AddOutput<int64_t>("output", {2, 2}, {0LL, 1LL, 2LL, 3LL});

  tester.Run();
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
|
||||
OpTester test1("ScatterND", 11);
|
||||
test1.AddInput<std::string>("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"});
|
||||
|
|
@ -55,6 +64,22 @@ TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
|
|||
test2.Run();
|
||||
}
|
||||
|
||||
// String-typed ScatterND with negative indices, covering both slice updates
// (index tuples shorter than the data rank) and single-element updates.
TEST(ScatterNDOpTest, ScatterND_matrice_string_int64_neg_indices) {
  // Case 1: 2x2x2 data, each 2-element index tuple selects a row of two strings.
  // The expected output rewrites rows (0, 1) and (1, 0), i.e. tuples (0, -1)
  // and (-1, 0) with the negative entries counted from the end of the axis.
  OpTester slice_case("ScatterND", 11);
  slice_case.AddInput<std::string>("data", {2, 2, 2},
                                   {"egg", "dance", "bob", "air", "smart", "terry", "laugh", "kite"});
  slice_case.AddInput<int64_t>("indices", {2, 1, 2}, {0, -1, -1, 0});
  slice_case.AddInput<std::string>("updates", {2, 1, 2}, {"air", "bob", "terry", "smart"});
  slice_case.AddOutput<std::string>("output", {2, 2, 2},
                                    {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"});
  slice_case.Run();

  // Case 2: 3x3 data, each 2-element index tuple selects exactly one string.
  // The expected output fills (2, 1), (1, 0) and (0, 1), matching tuples
  // (-1, -2), (1, 0) and (0, -2).
  OpTester element_case("ScatterND", 11);
  element_case.AddInput<std::string>("data", {3, 3},
                                     {"egg", "", "air", "", "terry", "smart", "laugh", "", "hop"});
  element_case.AddInput<int64_t>("indices", {3, 2}, {-1, -2, 1, 0, 0, -2});
  element_case.AddInput<std::string>("updates", {3}, {"kite", "bob", "dance"});
  element_case.AddOutput<std::string>("output", {3, 3},
                                      {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"});
  element_case.Run();
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_slice_float_int64_t) {
|
||||
OpTester test("ScatterND", 11);
|
||||
test.AddInput<float>("data", {2,2}, {0.0f,0.1f,0.1f,0.1f});
|
||||
|
|
@ -76,14 +101,14 @@ TEST(ScatterNDOpTest, ScatterND_slice_double_int64_t) {
|
|||
TEST(ScatterNDOpTest, ScatterND_3tensor_int64) {
|
||||
OpTester test1("ScatterND", 11);
|
||||
test1.AddInput<int64_t>("data", {2,2,2}, {0LL,1LL,1LL,1LL,1LL,1LL,6LL,7LL});
|
||||
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,1LL,0LL});
|
||||
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,-1LL,0LL});
|
||||
test1.AddInput<int64_t>("updates", {2,2}, {2LL,3LL,4LL,5LL});
|
||||
test1.AddOutput<int64_t>("output", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL});
|
||||
test1.Run();
|
||||
|
||||
OpTester test2("ScatterND", 11);
|
||||
test2.AddInput<int8_t>("data", {2,2,2}, {0,0,2,3,4,0,6,7});
|
||||
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,1,0,1});
|
||||
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,-1,0,-1});
|
||||
test2.AddInput<int8_t>("updates", {2}, {1,5});
|
||||
test2.AddOutput<int8_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
|
||||
test2.Run();
|
||||
|
|
@ -142,7 +167,7 @@ TEST(ScatterNDOpTest, ScatterND_batched_3tensor_int64) {
|
|||
|
||||
OpTester test2("ScatterND", 11);
|
||||
test2.AddInput<uint32_t>("data", {2,2,2}, {0,0,2,0,4,0,0,7});
|
||||
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0});
|
||||
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,-1,-1,0,-1,0,1,-1,1,-1,0});
|
||||
test2.AddInput<uint32_t>("updates", {2,2}, {1,5,3,6});
|
||||
test2.AddOutput<uint32_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
|
||||
test2.Run();
|
||||
|
|
|
|||
|
|
@ -658,6 +658,34 @@ def test_gradient_correctness():
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
@pytest.mark.parametrize("indices", ([[2, 3, -1, -1], [0, 1, -1, -1]],
                                     [[2, 3, 4, 4], [0, 1, 4, 4]]))
def test_scatternd_correctness(device, indices):
    """Compare ORTModule against PyTorch for an indexed in-place assignment.

    The two ``indices`` parametrizations address the same rows of the
    5-row target tensor: once with negative indices (-1) and once with the
    equivalent non-negative index (4).
    """
    class NeuralNetScatterND(torch.nn.Module):
        def __init__(self):
            super(NeuralNetScatterND, self).__init__()

        def forward(self, rerouted_output, dispatch_mask, expert_output):
            # Writes into `rerouted_output` in place and returns that same tensor.
            rerouted_output[dispatch_mask] = expert_output
            return rerouted_output

    pt_model = NeuralNetScatterND().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, rerouted_output, dispatch_mask, expert_output):
        prediction = model(rerouted_output, dispatch_mask, expert_output)
        return prediction

    rerouted_output = torch.tensor([[0.], [0.], [0.], [0.], [0.]], device=device)
    dispatch_mask = torch.tensor(indices, device=device)
    expert_output = torch.tensor([[[0.3817], [0.9625], [0.9625], [0.9625]],
                                  [[0.3817], [0.9625], [0.9625], [0.9625]]], device=device)

    # `forward` mutates its first argument, so each model gets its own copy of
    # the zero-initialised tensor. Sharing one tensor would feed the ORT model
    # the already-scattered result of the PyTorch run (and alias pt_prediction
    # to the ORT input), letting a scatter that drops writes go unnoticed.
    pt_prediction = run_step(pt_model, rerouted_output.clone(), dispatch_mask, expert_output)
    ort_prediction = run_step(ort_model, rerouted_output.clone(), dispatch_mask, expert_output)
    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fp16", [False, True])
|
||||
@pytest.mark.parametrize("input_requires_grad", [False, True])
|
||||
def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):
|
||||
|
|
|
|||
Loading…
Reference in a new issue