From 0c0583254eef65c12bca6d76fbde26ea6832ccfe Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Wed, 15 Jan 2025 15:20:22 -0800 Subject: [PATCH] [inductor] fix index.Tensor fallback (#144736) The original issue is we see accuracy problem in a meta internal model [meta internal link](https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065/). The debugging is hard but the root cause is relatively simple. The root cause is that the model has mix-device inputs for index.Tensor which causes Inductor to fallback. And the meta kernel for index.Tensor returns a tensor with inconsistent strides to the eager kernel. The following code snippet ``` import torch from torch._subclasses import FakeTensorMode device = "cuda" x = torch.randn((24, 16, 32, 32), device=device).to(memory_format=torch.channels_last) x = x.view(2, 12, 16, 32, 32) i1 = torch.arange(2).unsqueeze(-1) i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3] print(f"Eager stride: {x[i1, i2].stride()}") mode = FakeTensorMode() with mode: f_x = mode.from_tensor(x) f_i1 = mode.from_tensor(i1) f_i2 = mode.from_tensor(i2) f_out = f_x[f_i1, f_i2] print(f"Meta stride: {f_out.stride()}") ``` would output: ``` Eager stride: (49152, 16384, 1, 512, 16) Meta stride: (49152, 16384, 1024, 32, 1) ``` In this PR, I fix the problem to run eager kernel to get the index.Tensor fallback's output layout. A better solution would be to change meta/eager kernel implementation so that their output layout matches. But I'm not sure how to properly do that. In the index.Tensor meta kernel, we always produce dense output: https://github.com/pytorch/pytorch/blob/6d56277682715e56cfdfcaff6f770acebda966d7/torch/_meta_registrations.py#L3184 . While the eager kernel seems to leverage TensorIteratorBase to decide some dimension permutation: https://github.com/pytorch/pytorch/blob/6d56277682715e56cfdfcaff6f770acebda966d7/aten/src/ATen/TensorIterator.cpp#L232-L308 . We can duplicate this logic to the meta kernel implementation if we really want meta matches eager. I can follow up on this if people have strong opinion to do this. And here is an issue https://github.com/pytorch/pytorch/issues/144717 for asserting size/strides for fallback kernels. With that, the issue debugged here would be much easier to root cause. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144736 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 41 +++++++++++++++++++++++++++++ test/test_meta.py | 18 +++++++++++++ torch/_meta_registrations.py | 33 ++++++++++++++++++++++- 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ddbc38e52da..7dac6320836 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -37,6 +37,7 @@ from torch._dynamo.testing import ( CompileCounterWithBackend, expectedFailureCodegenDynamic, rand_strided, + reset_rng_state, same, skipIfPy312, ) @@ -12269,6 +12270,46 @@ class CommonTemplate: with self.assertRaisesRegex(RuntimeError, "Output size is too small"): _ = torch.compile(model)(inputs) + @requires_gpu() + @config.patch(fallback_random=True) + @unittest.skipIf( + config.cpp_wrapper, + "cpp wrapper does not support sort properly: https://gist.github.com/shunting314/e58f637f9972f1ad1a033d73cee6e42a", + ) + def test_mix_device_index(self): + """ + A tiny repro for this meta internal issue: https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065 + whose root cause is Inductor having wrong assumption of index.Tensor's output + stride. + """ + image_latent = ( + torch.randn((24, 16, 32, 32), device=GPU_TYPE) + .to(memory_format=torch.channels_last) + .view(2, 12, 16, 32, 32) + ) + + def f(image_latent): + indices = torch.argsort(torch.rand(2, 12), dim=-1) + + tar_latent = image_latent[torch.arange(2).unsqueeze(-1), indices[:, :3]] + + # The original model uses einops. In this unit test, we use view op directly + # to avoid importing einops + # tar_latent_rearranged = einops.rearrange( + # tar_latent, "b n c h w -> (b n) c h w" + # ) + tar_latent_rearranged = tar_latent.view(-1, *tar_latent.size()[2:]) + + return tar_latent_rearranged + + reset_rng_state() + ref = f(image_latent) + opt_f = torch.compile(f) + reset_rng_state() + act = opt_f(image_latent) + + torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) + @dataclasses.dataclass class TestFailure: diff --git a/test/test_meta.py b/test/test_meta.py index b2f322740b8..8a6e29362e7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1805,6 +1805,24 @@ class TestMeta(TestCase): self.assertEqual(nz.stride(), torch.Size([1, 24])) + def test_stride_for_index_Tensor(self): + from torch._subclasses import FakeTensorMode + x = torch.randn((24, 16, 32, 32)).to(memory_format=torch.channels_last) + x = x.view(2, 12, 16, 32, 32) + + i1 = torch.arange(2).unsqueeze(-1) + i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3] + + out = x[i1, i2] + + mode = FakeTensorMode() + with mode: + f_x = mode.from_tensor(x) + f_i1 = mode.from_tensor(i1) + f_i2 = mode.from_tensor(i2) + f_out = f_x[f_i1, f_i2] + + self.assertEqual(out.stride(), f_out.stride()) instantiate_device_type_tests(TestMeta, globals()) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d2eacac952d..aeb2a4a73bc 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3260,7 +3260,38 @@ def meta_index_Tensor(self, indices): before_shape.append(self.shape[dim]) else: replacement_shape = list(index.shape) - return self.new_empty(before_shape + replacement_shape + after_shape) + + def _restride_src(self): + """ + This follows restride_src in TensorAdvancedIndexing.cpp + """ + shape = before_shape + replacement_shape + after_shape + strides = list(self.stride()) + strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len( + replacement_shape + ) + return self.as_strided(shape, strides) + + out = self.new_empty(before_shape + replacement_shape + after_shape) + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(self.numel() == 0): + # No need to worry about the output strides if self is empty. + return out + + # Try to follow eager to decide the output stride based on self. + # Note that perm here is the reverse of the 'perm_' decided by + # TensorIteratorBase::reorder_dimensions + restrided_self = _restride_src(self) + perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self) + + # Follow TensorIteratorBase::allocate_or_resize_outputs + if list(perm) != list(range(len(perm))): + perm_shape = utils.apply_perm(out.shape, perm) + new_stride = utils.make_contiguous_strides_for(perm_shape) + new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm)) + out = out.as_strided(out.size(), new_stride) + return out @register_meta([aten.convolution_backward.default])