mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Bugfix for Scatter and GatherElementsGrad (#7593)
* bugfix for scatter and gather elements grad * resolve comments
This commit is contained in:
parent
cea0ea1591
commit
0c91b643fe
8 changed files with 54 additions and 5 deletions
|
|
@ -261,7 +261,7 @@ Status GatherElements::ValidateInputShapes(const TensorShape& input_data_shape,
|
|||
for (int64_t i = 0; i < indices_rank; ++i) {
|
||||
// for all axes except the axis of interest,
|
||||
// make sure that the corresponding 'indices' shape
|
||||
// value if within bounds of the corresponding 'data' shape
|
||||
// value is within bounds of the corresponding 'data' shape
|
||||
if (i != axis) {
|
||||
if (indices_shape[i] < 0 || indices_shape[i] > input_data_shape[i])
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
|
|||
|
|
@ -279,7 +279,9 @@ Status Scatter<EnabledDataTypes>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < input_dims.size(); ++i) {
|
||||
if (input_dims[i] < indices_dims[i]) {
|
||||
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
|
||||
// value is within bounds of the corresponding 'data' shape.
|
||||
if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
|
||||
" is greater than input dim=", input_dims[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -143,7 +143,9 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < input_dims.size(); ++i) {
|
||||
if (input_dims[i] < indices_dims[i]) {
|
||||
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
|
||||
// value is within bounds of the corresponding 'data' shape.
|
||||
if (static_cast<int64_t>(i) != axis_ && input_dims[i] < indices_dims[i]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
|
||||
" is greater than input dim=", input_dims[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -278,5 +278,21 @@ TEST(Scatter, SameUpdateWithoutAxis) {
|
|||
scatter_same_updates_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
// Shared driver for the "indices larger than data along the scatter axis" case.
// op_name / op_version are forwarded to OpTester, so one body can exercise more
// than one operator registration.
// With axis = 1, 'data' is 1x2 while 'indices' and 'updates' are 1x4. Every
// index is 0, so the expected output overwrites only element 0 (3.0f) and
// leaves element 1 at its original 2.0f.
static void scatter_with_larger_indices_on_axis_tests(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
||||
test.AddInput<float>("data", {1, 2}, {1.0f, 2.0f});
|
||||
test.AddInput<int64_t>("indices", {1, 4}, {0, 0, 0, 0});
|
||||
test.AddInput<float>("updates", {1, 4}, {3.0f, 3.0f, 3.0f, 3.0f});
|
||||
test.AddOutput<float>("y", {1, 2}, {3.0f, 2.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
// Runs the larger-indices-on-axis scenario twice: once as "Scatter" at version 9
// and once as "ScatterElements" at version 11, using the same inputs and
// expected output for both.
TEST(Scatter, LargerIndicesOnAxis) {
|
||||
scatter_with_larger_indices_on_axis_tests("Scatter", 9);
|
||||
scatter_with_larger_indices_on_axis_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2385,6 +2385,20 @@ TEST(GradientCheckerTest, GatherElementsGrad) {
|
|||
{MakeAttribute("axis", axis)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
// GatherElementsGradWithLargerIndiceOnAxis
|
||||
TensorInfo data_info({2, 2}, true);
|
||||
TensorInfo indice_info({2, 4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4}, {1, 1, 1, 1, 1, 1, 1, 1}};
|
||||
|
||||
TensorInfo y_info({2, 4}, true);
|
||||
int64_t axis = 1;
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indice_info}, {y_info}, &max_error, x_datas,
|
||||
{MakeAttribute("axis", axis)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, TopKGrad) {
|
||||
|
|
|
|||
|
|
@ -222,6 +222,17 @@ TEST(GatherElementsGrad, SameUpdateWithoutAxisMLFloat16) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// GatherElementsGrad (kMSDomain, version 1) where 'indices' is longer than the
// data shape along axis 1: dY and indices are 1x4, data_shape is {1, 2}.
// Indices {0, 1, 0, 1} map two dY entries onto each output slot; the expected
// dX equals the per-index sums (1.1 + 3.3 = 4.4, 2.2 + 4.4 = 6.6).
TEST(GatherElementsGrad, LargerIndicesOnAxis) {
|
||||
onnxruntime::test::OpTester test("GatherElementsGrad", 1, kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
test.AddInput<float>("dY", {1, 4}, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||
std::vector<int64_t> data_shape = {1, 2};
|
||||
test.AddInput<int64_t>("data_shape", {2}, data_shape);
|
||||
test.AddInput<int64_t>("indices", {1, 4}, {0, 1, 0, 1});
|
||||
test.AddOutput<float>("dX", {1, 2}, {4.4f, 6.6f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -64,7 +64,9 @@ Status GatherElementsGrad::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < output_dims.size(); ++i) {
|
||||
if (output_dims[i] < indices_dims[i]) {
|
||||
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
|
||||
// value is within bounds of the corresponding 'data' shape.
|
||||
if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
|
||||
" is greater than Output dim=", output_dims[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < output_dims.size(); ++i) {
|
||||
if (output_dims[i] < indices_dims[i]) {
|
||||
// For all axes except the axis of interest, make sure that the corresponding 'indices' shape
|
||||
// value is within bounds of the corresponding 'data' shape.
|
||||
if (static_cast<int64_t>(i) != axis_ && output_dims[i] < indices_dims[i]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Indices dim=", indices_dims[i], " at pos=", i,
|
||||
" is greater than Output dim=", output_dims[i]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue