Bugfix for Scatter and GatherElementsGrad (#7593)

* bugfix for scatter and gather elements grad

* resolve comments
This commit is contained in:
Vincent Wang 2021-05-07 14:02:26 +08:00 committed by GitHub
parent cea0ea1591
commit 0c91b643fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 54 additions and 5 deletions

View file

@ -261,7 +261,7 @@ Status GatherElements::ValidateInputShapes(const TensorShape& input_data_shape,
for (int64_t i = 0; i < indices_rank; ++i) {
// for all axes except the axis of interest,
// make sure that the corresponding 'indices' shape
// value if within bounds of the corresponding 'data' shape
// value is within bounds of the corresponding 'data' shape
if (i != axis) {
if (indices_shape[i] < 0 || indices_shape[i] > input_data_shape[i])
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

View file

@ -279,7 +279,9 @@ Status Scatter<EnabledDataTypes>::Compute(OpKernelContext* context) const {
}
for (size_t i = 0; i < input_dims.size(); ++i) {
if (input_dims[i] < indices_dims[i]) {
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
// value is within bounds of the corresponding 'data' shape.
if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
" is greater than input dim=", input_dims[i]);
}

View file

@ -143,7 +143,9 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
}
for (size_t i = 0; i < input_dims.size(); ++i) {
if (input_dims[i] < indices_dims[i]) {
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
// value is within bounds of the corresponding 'data' shape.
if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
" is greater than input dim=", input_dims[i]);
}

View file

@ -278,5 +278,21 @@ TEST(Scatter, SameUpdateWithoutAxis) {
scatter_same_updates_tests("ScatterElements", 11);
}
static void scatter_with_larger_indices_on_axis_tests(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("data", {1, 2}, {1.0f, 2.0f});
test.AddInput<int64_t>("indices", {1, 4}, {0, 0, 0, 0});
test.AddInput<float>("updates", {1, 4}, {3.0f, 3.0f, 3.0f, 3.0f});
test.AddOutput<float>("y", {1, 2}, {3.0f, 2.0f});
test.Run();
}
TEST(Scatter, LargerIndicesOnAxis) {
scatter_with_larger_indices_on_axis_tests("Scatter", 9);
scatter_with_larger_indices_on_axis_tests("ScatterElements", 11);
}
} // namespace test
} // namespace onnxruntime

View file

@ -2385,6 +2385,20 @@ TEST(GradientCheckerTest, GatherElementsGrad) {
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
{
// GatherElementsGradWithLargerIndiceOnAxis
TensorInfo data_info({2, 2}, true);
TensorInfo indice_info({2, 4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4}, {1, 1, 1, 1, 1, 1, 1, 1}};
TensorInfo y_info({2, 4}, true);
int64_t axis = 1;
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
{MakeAttribute("axis", axis)});
EXPECT_IS_TINY(max_error);
}
}
TEST(GradientCheckerTest, TopKGrad) {

View file

@ -222,6 +222,17 @@ TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) {
test.Run();
}
TEST(GatherElementsGrad, LargerIndicesOnAxis) {
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("dY", {1, 4}, {1.1f, 2.2f, 3.3f, 4.4f});
std::vector<int64_t> data_shape = {1, 2};
test.AddInput<int64_t>("data_shape", {2}, data_shape);
test.AddInput<int64_t>("indices", {1, 4}, {0, 1, 0, 1});
test.AddOutput<float>("dX", {1, 2}, {4.4f, 6.6f});
test.Run();
}
} // namespace test
} // namespace cuda
} // namespace onnxruntime

View file

@ -64,7 +64,9 @@ Status GatherElementsGrad::Compute(OpKernelContext* context) const {
}
for (size_t i = 0; i < output_dims.size(); ++i) {
if (output_dims[i] < indices_dims[i]) {
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
// value is within bounds of the corresponding 'data' shape.
if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
" is greater than Output dim=", output_dims[i]);
}

View file

@ -108,7 +108,9 @@ Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const {
}
for (size_t i = 0; i < output_dims.size(); ++i) {
if (output_dims[i] < indices_dims[i]) {
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
// value is within bounds of the corresponding 'data' shape.
if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
" is greater than Output dim=", output_dims[i]);
}