Output of nonzero is transposed, fix fake tensor (#144695)

Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
This commit is contained in:
Edward Z. Yang 2025-01-24 21:24:02 -08:00 committed by PyTorch MergeBot
parent 76bec878da
commit 90448f0128
8 changed files with 68 additions and 41 deletions

View file

@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,943000000,0.015

1 add_loop_eager compile_time_instruction_count 3066000000 0.015
18
19
20
21
22
23
24

View file

@@ -1406,17 +1406,17 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg1_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1)
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg0_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

View file

@@ -7,19 +7,14 @@ import torch
from torch._dynamo import config as dynamo_config
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import is_big_gpu
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
skipCPUIf,
skipGPUIf,
)
from torch.testing._internal.common_utils import IS_LINUX, parametrize
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_CUDA,
HAS_GPU,
requires_gpu,
)
from torch.testing._internal.common_utils import parametrize
from torch.testing._internal.inductor_utils import HAS_GPU
class TestUnbackedSymints(InductorTestCase):
@@ -127,7 +122,7 @@ class TestUnbackedSymints(InductorTestCase):
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
@requires_gpu()
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_triton_kernel_grid(self, device):
if device == "cpu":
@@ -163,6 +158,7 @@ class TestUnbackedSymints(InductorTestCase):
torch.testing.assert_close(actual, expected)
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@inductor_config.patch({"max_autotune": True})
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_equivalent_backed_unbacked(self, device):
@@ -195,7 +191,8 @@ class TestUnbackedSymints(InductorTestCase):
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
@requires_gpu()
@skipCPUIf(True, "precision not good enough on CPU")
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_vertical_pointwise_reduction_fusion(self, device):
# reset in case we run both cpu and cuda tests
@@ -214,9 +211,9 @@ class TestUnbackedSymints(InductorTestCase):
return pointwise, reduction
example_inputs = (
torch.randn(32, 16).to(GPU_TYPE),
torch.randn(1, 16).to(GPU_TYPE),
torch.tensor(32).to(GPU_TYPE),
torch.randn(32, 16, device=device),
torch.randn(1, 16, device=device),
torch.tensor(32, device=device),
)
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
@@ -224,6 +221,7 @@ class TestUnbackedSymints(InductorTestCase):
torch.testing.assert_close(actual, expected)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch({"capture_scalar_outputs": True})
@parametrize(
"torch_fn", [torch.mm, torch.bmm, torch.addmm], name_fn=lambda fn: fn.__name__
@@ -262,6 +260,7 @@ class TestUnbackedSymints(InductorTestCase):
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked_range_tree_divisor(self, device):
def fn(x, num):
@@ -279,6 +278,7 @@ class TestUnbackedSymints(InductorTestCase):
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_unbacked_masked_scatter(self, device):
def fn(value, mask):
@@ -294,6 +294,7 @@ class TestUnbackedSymints(InductorTestCase):
expected = fn(*example_inputs)
torch.testing.assert_close(actual, expected)
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch({"capture_scalar_outputs": True})
@parametrize("dynamic", [False, True, None])
def test_unbacked_slice_on_subclass(self, device, dynamic):
@@ -388,5 +389,4 @@ instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()):
run_tests()
run_tests()

View file

@@ -1597,6 +1597,17 @@ class FakeTensorPropTest(TestCase):
self.assertIsNot(u0, u1)
self.assertTrue(statically_known_true(u0 == u1))
def test_nonzero_stride(self):
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
with fake_mode:
value = torch.ones(5)
fake_r = value.nonzero()
r = torch.ones(5).nonzero()
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
def test_torch_load_with_fake_mode(self):
class TheModelClass(torch.nn.Module):
def __init__(self) -> None:

View file

@@ -678,14 +678,16 @@ class InputWriter:
return v
def tensor(self, name, t) -> None:
from torch.fx.experimental.symbolic_shapes import statically_known_true
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
storage = self.storage(
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
)
args = []
# NB: this is positional, must come first
if _stride_or_default(None, shape=t.shape) != t.stride():
if not statically_known_true(
sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
):
args.append(str(tuple(t.stride())))
if _dtype_or_default(None) != t.dtype:
args.append(f"dtype={t.dtype!r}")

View file

@@ -520,7 +520,8 @@ class FakifiedOutWrapper(CompilerWrapper):
out_metas: list[torch.Tensor] = field(default_factory=list)
# TracingContext.fwd_output_strides
# Generated from actually doing compile
fwd_output_strides: Optional[list[list[int]]] = None
# NB: an entry is None if it's not a Tensor
fwd_output_strides: Optional[list[Optional[list[int]]]] = None
needs_post_compile: bool = True
def pre_compile(
@@ -551,12 +552,23 @@ class FakifiedOutWrapper(CompilerWrapper):
for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
strides = fwd_output_strides[i]
# fwd_output_strides is best effort by Inductor. When an output
# Tensor has unbacked SymInts, Inductor may sometimes be unable
# to compute what the output stride would be. If Inductor doesn't
# have any clear direction on the layout, we don't have to run
# as_strided. To repro without this, run:
#
# python test/distributed/test_dynamo_distributed.py
# TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
if strides is None:
continue
if all(
statically_known_true(s1 == s2)
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
for s1, s2 in zip(out[i].stride(), strides)
):
continue
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
out[i] = out[i].as_strided(out[i].shape, strides)
return out
# To be called post compile

View file

@@ -315,14 +315,16 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool:
if a.ndim != 5:
return False
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
expected_stride = 1
for idx in (1, 4, 3, 2, 0):
length = a.shape[idx]
if length == 1:
if guard_size_oblivious(length == 1):
continue
stride = a.stride()[idx]
if stride != expected_stride:
if guard_size_oblivious(stride != expected_stride):
return False
expected_stride *= length
@@ -436,7 +438,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
if guard_size_oblivious(length == 1):
continue
if stride != expected_stride:
if guard_size_oblivious(stride != expected_stride):
return False
expected_stride *= length

View file

@@ -463,7 +463,7 @@ def nonzero(fake_mode, func, arg):
arg.nonzero_memo = nnz
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)
@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)