mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fix gather issue when index is shape of n by 1 (#99709)
Fix https://github.com/pytorch/pytorch/issues/99595 When the index is shape of {N, 1}, it will also have strides of {1, 0}, which is the same as an expanded tensor (e.g. shape of {5, 5} and strides {1, 0}), leading to wrong output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99709 Approved by: https://github.com/XiaobingSuper, https://github.com/ezyang
This commit is contained in:
parent
e9e5ffe83e
commit
4c9d660733
2 changed files with 19 additions and 0 deletions
|
|
@ -69,6 +69,14 @@ static inline bool can_use_expanded_index_path(
|
|||
return false;
|
||||
}
|
||||
|
||||
// allow only different size on dim 0 for src and index
|
||||
// https://github.com/pytorch/pytorch/issues/99595
|
||||
for (const auto dim : c10::irange(1, index.dim())) {
|
||||
if (src.size(dim) != index.size(dim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_scatter_like) {
|
||||
// using `spmm` for scatter would require sorting on index,
|
||||
// this is only perf beneficial when the inner dimension, aka, `channels`
|
||||
|
|
|
|||
|
|
@ -314,6 +314,17 @@ class TestScatterGather(TestCase):
|
|||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
def test_gather_expanded_index(self, device, dtype):
|
||||
# Test when index is [N, 1], which would have stride [1, 0]
|
||||
# should be excluded from the fast path when index is expanded
|
||||
input = torch.arange(25).view(5, 5)
|
||||
input2 = input.to(dtype=dtype)
|
||||
|
||||
idx = torch.arange(5).view(5, 1)
|
||||
out = torch.gather(input, 0, idx)
|
||||
out2 = torch.gather(input2, 0, idx)
|
||||
|
||||
self.assertEqual(out.to(dtype=dtype), out2)
|
||||
|
||||
def helper(input_size, idx_size):
|
||||
input = torch.randn(input_size, device=device).to(dtype=dtype)
|
||||
input2 = input.clone()
|
||||
|
|
|
|||
Loading…
Reference in a new issue