diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ddbc38e52da..7dac6320836 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -37,6 +37,7 @@ from torch._dynamo.testing import ( CompileCounterWithBackend, expectedFailureCodegenDynamic, rand_strided, + reset_rng_state, same, skipIfPy312, ) @@ -12269,6 +12270,46 @@ class CommonTemplate: with self.assertRaisesRegex(RuntimeError, "Output size is too small"): _ = torch.compile(model)(inputs) + @requires_gpu() + @config.patch(fallback_random=True) + @unittest.skipIf( + config.cpp_wrapper, + "cpp wrapper does not support sort properly: https://gist.github.com/shunting314/e58f637f9972f1ad1a033d73cee6e42a", + ) + def test_mix_device_index(self): + """ + A tiny repro for this meta internal issue: https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065 + whose root cause is Inductor having a wrong assumption about index.Tensor's output + stride. + """ + image_latent = ( + torch.randn((24, 16, 32, 32), device=GPU_TYPE) + .to(memory_format=torch.channels_last) + .view(2, 12, 16, 32, 32) + ) + + def f(image_latent): + indices = torch.argsort(torch.rand(2, 12), dim=-1) + + tar_latent = image_latent[torch.arange(2).unsqueeze(-1), indices[:, :3]] + + # The original model uses einops. 
In this unit test, we use view op directly + # to avoid importing einops + # tar_latent_rearranged = einops.rearrange( + # tar_latent, "b n c h w -> (b n) c h w" + # ) + tar_latent_rearranged = tar_latent.view(-1, *tar_latent.size()[2:]) + + return tar_latent_rearranged + + reset_rng_state() + ref = f(image_latent) + opt_f = torch.compile(f) + reset_rng_state() + act = opt_f(image_latent) + + torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) + @dataclasses.dataclass class TestFailure: diff --git a/test/test_meta.py b/test/test_meta.py index b2f322740b8..8a6e29362e7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1805,6 +1805,24 @@ class TestMeta(TestCase): self.assertEqual(nz.stride(), torch.Size([1, 24])) + def test_stride_for_index_Tensor(self): + from torch._subclasses import FakeTensorMode + x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last) + x = x.view(2, 12, 16, 32, 32) + + i1 = torch.arange(2).unsqueeze(-1) + i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3] + + out = x[i1, i2] + + mode = FakeTensorMode() + with mode: + f_x = mode.from_tensor(x) + f_i1 = mode.from_tensor(i1) + f_i2 = mode.from_tensor(i2) + f_out = f_x[f_i1, f_i2] + + self.assertEqual(out.stride(), f_out.stride()) instantiate_device_type_tests(TestMeta, globals()) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d2eacac952d..aeb2a4a73bc 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3260,7 +3260,38 @@ def meta_index_Tensor(self, indices): before_shape.append(self.shape[dim]) else: replacement_shape = list(index.shape) - return self.new_empty(before_shape + replacement_shape + after_shape) + + def _restride_src(self): + """ + This follows restride_src in TensorAdvancedIndexing.cpp + """ + shape = before_shape + replacement_shape + after_shape + strides = list(self.stride()) + strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len( + replacement_shape + ) + return 
self.as_strided(shape, strides) + + out = self.new_empty(before_shape + replacement_shape + after_shape) + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(self.numel() == 0): + # No need to worry about the output strides if self is empty. + return out + + # Try to follow eager to decide the output stride based on self. + # Note that perm here is the reverse of the 'perm_' decided by + # TensorIteratorBase::reorder_dimensions + restrided_self = _restride_src(self) + perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self) + + # Follow TensorIteratorBase::allocate_or_resize_outputs + if list(perm) != list(range(len(perm))): + perm_shape = utils.apply_perm(out.shape, perm) + new_stride = utils.make_contiguous_strides_for(perm_shape) + new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm)) + out = out.as_strided(out.size(), new_stride) + return out @register_meta([aten.convolution_backward.default])