From 90448f0128d8090a07325be24bd513c4d14bed6d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 24 Jan 2025 21:24:02 -0800 Subject: [PATCH] Output of nonzero is transposed, fix fake tensor (#144695) Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695 Approved by: https://github.com/bobrenjc93, https://github.com/albanD --- .../pr_time_benchmarks/expected_results.csv | 2 +- test/inductor/test_auto_functionalize.py | 32 +++++++++---------- test/inductor/test_unbacked_symints.py | 30 ++++++++--------- test/test_fake_tensor.py | 11 +++++++ torch/_dynamo/debug_utils.py | 6 ++-- .../_aot_autograd/runtime_wrappers.py | 18 +++++++++-- torch/_prims_common/__init__.py | 8 +++-- torch/_subclasses/fake_impls.py | 2 +- 8 files changed, 68 insertions(+), 41 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 73a8582c84b..94ab0123bae 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,943000000,0.015 diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index b3796620b66..5eb60cdc1dc 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -1406,17 +1406,17 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): """\ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1) - nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None + nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - _to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None + _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] - getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None - slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None + slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None return (alias_1, slice_2)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): """\ def forward(self, arg0_1: "f32[2][1]cpu"): clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1) - nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None + nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None - _to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None + _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] - getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None - slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None + slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None return (alias_1, slice_2)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"): graph_inductor, """\ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): - nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1) + nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1) sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None + convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) - alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type) + alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type) foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None - slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None + slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None return (arg1_1, slice_2)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): graph_inductor, """\ def forward(self, arg0_1: "f32[2][1]cpu"): - nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1) + nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1) sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None - convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None + convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1) - alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type) + alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type) foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None - slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None + slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None return (arg0_1, slice_2)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 3de13a6703b..5289b170f6b 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -7,19 +7,14 @@ import torch from torch._dynamo import config as dynamo_config from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase as InductorTestCase -from torch._inductor.utils import is_big_gpu from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, + skipCPUIf, skipGPUIf, ) -from torch.testing._internal.common_utils import IS_LINUX, parametrize -from torch.testing._internal.inductor_utils import ( - GPU_TYPE, - HAS_CUDA, - HAS_GPU, - requires_gpu, -) +from torch.testing._internal.common_utils import parametrize +from torch.testing._internal.inductor_utils import HAS_GPU class TestUnbackedSymints(InductorTestCase): @@ -127,7 +122,7 @@ class TestUnbackedSymints(InductorTestCase): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) - @requires_gpu() + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) def test_triton_kernel_grid(self, device): if device == "cpu": @@ -163,6 +158,7 @@ class TestUnbackedSymints(InductorTestCase): torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @inductor_config.patch({"max_autotune": True}) @dynamo_config.patch({"capture_scalar_outputs": True}) def test_equivalent_backed_unbacked(self, device): @@ -195,7 +191,8 @@ class TestUnbackedSymints(InductorTestCase): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) - @requires_gpu() + @skipCPUIf(True, "precision not good enough on CPU") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) def test_vertical_pointwise_reduction_fusion(self, device): # reset in case we run both cpu and cuda tests @@ -214,9 +211,9 @@ class TestUnbackedSymints(InductorTestCase): return pointwise, reduction example_inputs = ( - torch.randn(32, 16).to(GPU_TYPE), - torch.randn(1, 16).to(GPU_TYPE), - torch.tensor(32).to(GPU_TYPE), + torch.randn(32, 16, device=device), + torch.randn(1, 16, device=device), + torch.tensor(32, device=device), ) actual = torch.compile(fn, fullgraph=True)(*example_inputs) @@ -224,6 +221,7 @@ class TestUnbackedSymints(InductorTestCase): torch.testing.assert_close(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) @parametrize( "torch_fn", [torch.mm, torch.bmm, torch.addmm], name_fn=lambda fn: fn.__name__ @@ -262,6 +260,7 @@ class TestUnbackedSymints(InductorTestCase): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_unbacked_range_tree_divisor(self, device): def fn(x, num): @@ -279,6 +278,7 @@ class TestUnbackedSymints(InductorTestCase): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) def test_unbacked_masked_scatter(self, device): def fn(value, mask): @@ -294,6 +294,7 @@ class TestUnbackedSymints(InductorTestCase): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) @parametrize("dynamic", [False, True, None]) def test_unbacked_slice_on_subclass(self, device, dynamic): @@ -388,5 +389,4 @@ instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) if __name__ == "__main__": from torch._inductor.test_case import run_tests - if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()): - run_tests() + run_tests() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 2577066ef18..b81e6861f4d 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1597,6 +1597,17 @@ class FakeTensorPropTest(TestCase): self.assertIsNot(u0, u1) self.assertTrue(statically_known_true(u0 == u1)) + def test_nonzero_stride(self): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + with fake_mode: + value = torch.ones(5) + fake_r = value.nonzero() + + r = torch.ones(5).nonzero() + + self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous()) + def test_torch_load_with_fake_mode(self): class TheModelClass(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 23de50aaf46..1e849136f38 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -678,14 +678,16 @@ class InputWriter: return v def tensor(self, name, t) -> None: - from torch.fx.experimental.symbolic_shapes import statically_known_true + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq storage = self.storage( t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device ) args = [] # NB: this is positional, must come first - if _stride_or_default(None, shape=t.shape) != t.stride(): + if not statically_known_true( + sym_eq(_stride_or_default(None, shape=t.shape), t.stride()) + ): args.append(str(tuple(t.stride()))) if _dtype_or_default(None) != t.dtype: args.append(f"dtype={t.dtype!r}") diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index b45a6ccdfa9..954865c76b1 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -520,7 +520,8 @@ class FakifiedOutWrapper(CompilerWrapper): out_metas: list[torch.Tensor] = field(default_factory=list) # TracingContext.fwd_output_strides # Generated from actually doing compile - fwd_output_strides: Optional[list[list[int]]] = None + # NB: an entry is None if it's not a Tensor + fwd_output_strides: Optional[list[Optional[list[int]]]] = None needs_post_compile: bool = True def pre_compile( @@ -551,12 +552,23 @@ class FakifiedOutWrapper(CompilerWrapper): for i in range(len(out)): if not isinstance(out[i], Tensor): continue + strides = fwd_output_strides[i] + # fwd_output_strides is best effort by Inductor. When an output + # Tensor has unbacked SymInts, Inductor may sometimes be unable + # to compute what the output stride would be. If Inductor doesn't + # have any clear direction on the layout, we don't have to run + # as_strided. To repro without this, run: + # + # python test/distributed/test_dynamo_distributed.py + # TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding + if strides is None: + continue if all( statically_known_true(s1 == s2) - for s1, s2 in zip(out[i].stride(), fwd_output_strides[i]) + for s1, s2 in zip(out[i].stride(), strides) ): continue - out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) + out[i] = out[i].as_strided(out[i].shape, strides) return out # To be called post compile diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index bbe032bf2f9..57bc423cf3d 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -315,14 +315,16 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool: if a.ndim != 5: return False + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + expected_stride = 1 for idx in (1, 4, 3, 2, 0): length = a.shape[idx] - if length == 1: + if guard_size_oblivious(length == 1): continue stride = a.stride()[idx] - if stride != expected_stride: + if guard_size_oblivious(stride != expected_stride): return False expected_stride *= length @@ -436,7 +438,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: if guard_size_oblivious(length == 1): continue - if stride != expected_stride: + if guard_size_oblivious(stride != expected_stride): return False expected_stride *= length diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 61228eab181..dcada2e845c 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -463,7 +463,7 @@ def nonzero(fake_mode, func, arg): arg.nonzero_memo = nnz - return arg.new_empty((nnz, arg.dim()), dtype=torch.int64) + return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64) @register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)