fix gather issue when index is shape of n by 1 (#99709)

Fix https://github.com/pytorch/pytorch/issues/99595

When the index is shape of {N, 1}, it will also have strides of {1, 0}, which is the same as an expanded tensor (e.g. shape of {5, 5} and strides {1, 0}), leading to wrong output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99709
Approved by: https://github.com/XiaobingSuper, https://github.com/ezyang
This commit is contained in:
mingfeima 2023-04-23 10:04:55 +08:00 committed by PyTorch MergeBot
parent e9e5ffe83e
commit 4c9d660733
2 changed files with 19 additions and 0 deletions

View file

@ -69,6 +69,14 @@ static inline bool can_use_expanded_index_path(
return false;
}
// allow only different size on dim 0 for src and index
// https://github.com/pytorch/pytorch/issues/99595
for (const auto dim : c10::irange(1, index.dim())) {
if (src.size(dim) != index.size(dim)) {
return false;
}
}
if (is_scatter_like) {
// using `spmm` for scatter would require sorting on index,
// this is only perf beneficial when the inner dimension, aka, `channels`

View file

@ -314,6 +314,17 @@ class TestScatterGather(TestCase):
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_gather_expanded_index(self, device, dtype):
# Test when index is [N, 1], which would have stride [1, 0]
# should be excluded from the fast path when index ix expanded
input = torch.arange(25).view(5, 5)
input2 = input.to(dtype=dtype)
idx = torch.arange(5).view(5, 1)
out = torch.gather(input, 0, idx)
out2 = torch.gather(input2, 0, idx)
self.assertEqual(out.to(dtype=dtype), out2)
def helper(input_size, idx_size):
input = torch.randn(input_size, device=device).to(dtype=dtype)
input2 = input.clone()