mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Output of nonzero is transposed, fix fake tensor (#144695)
Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
This commit is contained in:
parent
76bec878da
commit
90448f0128
8 changed files with 68 additions and 41 deletions
|
|
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
|
|||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,943000000,0.015
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -1406,17 +1406,17 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
|
|||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
return (alias_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
|||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
|
||||
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
|
||||
alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
return (alias_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
|
|||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1)
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
|
||||
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
return (arg1_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
|||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1)
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
|
||||
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1)
|
||||
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
return (arg0_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
|
|||
|
|
@ -7,19 +7,14 @@ import torch
|
|||
from torch._dynamo import config as dynamo_config
|
||||
from torch._inductor import config as inductor_config
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.utils import is_big_gpu
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
skipCPUIf,
|
||||
skipGPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_LINUX, parametrize
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_CUDA,
|
||||
HAS_GPU,
|
||||
requires_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
class TestUnbackedSymints(InductorTestCase):
|
||||
|
|
@ -127,7 +122,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@requires_gpu()
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
def test_triton_kernel_grid(self, device):
|
||||
if device == "cpu":
|
||||
|
|
@ -163,6 +158,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@inductor_config.patch({"max_autotune": True})
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
def test_equivalent_backed_unbacked(self, device):
|
||||
|
|
@ -195,7 +191,8 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@requires_gpu()
|
||||
@skipCPUIf(True, "precision not good enough on CPU")
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
def test_vertical_pointwise_reduction_fusion(self, device):
|
||||
# reset in case we run both cpu and cuda tests
|
||||
|
|
@ -214,9 +211,9 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
return pointwise, reduction
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(32, 16).to(GPU_TYPE),
|
||||
torch.randn(1, 16).to(GPU_TYPE),
|
||||
torch.tensor(32).to(GPU_TYPE),
|
||||
torch.randn(32, 16, device=device),
|
||||
torch.randn(1, 16, device=device),
|
||||
torch.tensor(32, device=device),
|
||||
)
|
||||
|
||||
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
|
||||
|
|
@ -224,6 +221,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
torch.testing.assert_close(actual, expected)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
@parametrize(
|
||||
"torch_fn", [torch.mm, torch.bmm, torch.addmm], name_fn=lambda fn: fn.__name__
|
||||
|
|
@ -262,6 +260,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_unbacked_range_tree_divisor(self, device):
|
||||
def fn(x, num):
|
||||
|
|
@ -279,6 +278,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
def test_unbacked_masked_scatter(self, device):
|
||||
def fn(value, mask):
|
||||
|
|
@ -294,6 +294,7 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
@parametrize("dynamic", [False, True, None])
|
||||
def test_unbacked_slice_on_subclass(self, device, dynamic):
|
||||
|
|
@ -388,5 +389,4 @@ instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
|
|||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()):
|
||||
run_tests()
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1597,6 +1597,17 @@ class FakeTensorPropTest(TestCase):
|
|||
self.assertIsNot(u0, u1)
|
||||
self.assertTrue(statically_known_true(u0 == u1))
|
||||
|
||||
def test_nonzero_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
fake_mode = FakeTensorMode(shape_env=shape_env)
|
||||
with fake_mode:
|
||||
value = torch.ones(5)
|
||||
fake_r = value.nonzero()
|
||||
|
||||
r = torch.ones(5).nonzero()
|
||||
|
||||
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
|
||||
|
||||
def test_torch_load_with_fake_mode(self):
|
||||
class TheModelClass(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -678,14 +678,16 @@ class InputWriter:
|
|||
return v
|
||||
|
||||
def tensor(self, name, t) -> None:
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
||||
|
||||
storage = self.storage(
|
||||
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
|
||||
)
|
||||
args = []
|
||||
# NB: this is positional, must come first
|
||||
if _stride_or_default(None, shape=t.shape) != t.stride():
|
||||
if not statically_known_true(
|
||||
sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
|
||||
):
|
||||
args.append(str(tuple(t.stride())))
|
||||
if _dtype_or_default(None) != t.dtype:
|
||||
args.append(f"dtype={t.dtype!r}")
|
||||
|
|
|
|||
|
|
@ -520,7 +520,8 @@ class FakifiedOutWrapper(CompilerWrapper):
|
|||
out_metas: list[torch.Tensor] = field(default_factory=list)
|
||||
# TracingContext.fwd_output_strides
|
||||
# Generated from actually doing compile
|
||||
fwd_output_strides: Optional[list[list[int]]] = None
|
||||
# NB: an entry is None if it's not a Tensor
|
||||
fwd_output_strides: Optional[list[Optional[list[int]]]] = None
|
||||
needs_post_compile: bool = True
|
||||
|
||||
def pre_compile(
|
||||
|
|
@ -551,12 +552,23 @@ class FakifiedOutWrapper(CompilerWrapper):
|
|||
for i in range(len(out)):
|
||||
if not isinstance(out[i], Tensor):
|
||||
continue
|
||||
strides = fwd_output_strides[i]
|
||||
# fwd_output_strides is best effort by Inductor. When an output
|
||||
# Tensor has unbacked SymInts, Inductor may sometimes be unable
|
||||
# to compute what the output stride would be. If Inductor doesn't
|
||||
# have any clear direction on the layout, we don't have to run
|
||||
# as_strided. To repro without this, run:
|
||||
#
|
||||
# python test/distributed/test_dynamo_distributed.py
|
||||
# TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
|
||||
if strides is None:
|
||||
continue
|
||||
if all(
|
||||
statically_known_true(s1 == s2)
|
||||
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
|
||||
for s1, s2 in zip(out[i].stride(), strides)
|
||||
):
|
||||
continue
|
||||
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
|
||||
out[i] = out[i].as_strided(out[i].shape, strides)
|
||||
return out
|
||||
|
||||
# To be called post compile
|
||||
|
|
|
|||
|
|
@ -315,14 +315,16 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool:
|
|||
if a.ndim != 5:
|
||||
return False
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
expected_stride = 1
|
||||
for idx in (1, 4, 3, 2, 0):
|
||||
length = a.shape[idx]
|
||||
if length == 1:
|
||||
if guard_size_oblivious(length == 1):
|
||||
continue
|
||||
|
||||
stride = a.stride()[idx]
|
||||
if stride != expected_stride:
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
|
|
@ -436,7 +438,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
|||
if guard_size_oblivious(length == 1):
|
||||
continue
|
||||
|
||||
if stride != expected_stride:
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
|
|
|
|||
|
|
@ -463,7 +463,7 @@ def nonzero(fake_mode, func, arg):
|
|||
|
||||
arg.nonzero_memo = nnz
|
||||
|
||||
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
|
||||
return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)
|
||||
|
||||
|
||||
@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
|
||||
|
|
|
|||
Loading…
Reference in a new issue