[inductor] add size-asserts for fallback ops (#145904)

Fix https://github.com/pytorch/pytorch/issues/144717

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145904
Approved by: https://github.com/jansel
Shunting Zhang 2025-02-06 12:13:04 -08:00 committed by PyTorch MergeBot
parent b60f630de8
commit bc0191802f
9 changed files with 141 additions and 5 deletions
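
For context: with this change, Inductor's generated Python wrapper guards every output of a multi-output fallback kernel with a size/stride assert. A rough sketch of the generated code (illustrative, not verbatim; buffer names and shapes follow the new test below):

    # Excerpt of Inductor wrapper output for x.sort() with x of shape (16, 32):
    buf1, buf2 = torch.ops.aten.sort.default(arg0_1)  # fallback kernel
    assert_size_stride(buf1, (16, 32), (32, 1))  # sorted values
    assert_size_stride(buf2, (16, 32), (32, 1))  # indices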

View file

@@ -419,6 +419,7 @@ class CondTests(TestCase):

     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
+    @torch._inductor.config.patch(size_asserts=False)
     def test_cond_unbacked_symint_inner(self, device):
         class Model(torch.nn.Module):
             def forward(self, p, a):

View file

@@ -41,6 +41,7 @@ from torch._dynamo.testing import (
 )
 from torch._dynamo.utils import ifdynstaticdefault
 from torch._guards import CompileContext, CompileId
+from torch._inductor import lowering
 from torch._inductor.aoti_eager import (
     aoti_compile_with_persistent_cache,
     aoti_eager_cache_dir,
@@ -903,6 +904,11 @@ class skip_if_cpp_wrapper:
         return wrapper


+def is_dynamic_shape_enabled():
+    # What's the best way to decide this?
+    return not torch._dynamo.config.assume_static_by_default
+
+
 @instantiate_parametrized_tests
 class CommonTemplate:
     def is_dtype_supported(self, dtype: torch.dtype) -> bool:
@@ -5844,6 +5850,8 @@ class CommonTemplate:
             (torch.randn([8, 16, 8, 8]),),
         )

+    # Disable size_asserts for this test due to https://github.com/pytorch/pytorch/issues/145963
+    @config.patch(size_asserts=os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS") == "1")
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     def test_nonzero_unbacked_refinement(self):
         def fn(x):
@@ -12430,11 +12438,41 @@ class CommonTemplate:
         reset_rng_state()
         ref = f(image_latent)

         opt_f = torch.compile(f)
+        code = run_and_get_triton_code(opt_f, image_latent)
         reset_rng_state()
         act = opt_f(image_latent)

         torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)

+        if is_dynamic_shape_enabled():
+            size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.."  # noqa: B950
+        else:
+            size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
+        FileCheck().check_regex(size_assert_pattern).run(code)
+
+    @lowering.force_fallback(aten.sort.default)
+    @unittest.skipIf(
+        config.cpp_wrapper,
+        "Inductor does not generate size/stride asserts for cpp_wrapper",
+    )
+    def test_size_asserts_for_multi_output_fallback(self):
+        @torch.compile
+        def f(x):
+            return x.sort()
+
+        x = torch.randn(16, 32, device=self.device)
+        code = run_and_get_triton_code(f, x)
+        if is_dynamic_shape_enabled():
+            FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check(
+                "assert_size_stride(buf2, (s0, s1), (s1, 1))"
+            ).run(code)
+        else:
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1))"
+            ).run(code)
+
     @requires_cuda
     @config.patch(use_fast_math=True)
     def test_prepare_softmax_with_fast_math(self):

View file

@@ -41,7 +41,13 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_CUDA,
+    HAS_XPU,
+    maybe_skip_size_asserts,
+)
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
@@ -1103,7 +1109,9 @@ class TestInductorOpInfo(TestCase):
                     {"assert_equal": False},
                 ),
             )
-        return ((contextlib.nullcontext, {}),)
+
+        ctx = functools.partial(maybe_skip_size_asserts, op)
+        return ((ctx, {}),)

        try:
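
(`functools.partial` is used here so the returned factory, like `contextlib.nullcontext`, can be called with no positional arguments. Presumably the harness instantiates each (factory, kwargs) pair along these lines; a hypothetical sketch of that calling convention, not the actual harness code:)

    import contextlib
    import functools

    def run_under_contexts(context_specs, fn):
        # Enter every (factory, kwargs) context, then run the test body.
        with contextlib.ExitStack() as stack:
            for factory, kwargs in context_specs:
                stack.enter_context(factory(**kwargs))
            return fn()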

View file

@@ -72,6 +72,7 @@ from torch.testing._internal.common_utils import (
     TestCase,
     unMarkDynamoStrictTest,
 )
+from torch.testing._internal.inductor_utils import maybe_skip_size_asserts
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
@@ -1149,7 +1150,8 @@ class TestCommon(TestCase):
         sample = first_sample(self, op.sample_inputs(device, dtype))

         # Call op to get prototype for out arguments
-        expect = op(sample.input, *sample.args, **sample.kwargs)
+        with maybe_skip_size_asserts(op):
+            expect = op(sample.input, *sample.args, **sample.kwargs)
         any_requires_grad = False

         def set_requires_grad(x):
@@ -1170,7 +1172,7 @@ class TestCommon(TestCase):
                 "functions with out=... arguments don't support automatic "
                 "differentiation, but one of the arguments requires grad."
             )
-            with self.assertRaises(RuntimeError, msg=msg):
+            with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op):
                 op(sample.input, *sample.args, **sample.kwargs, out=out)

     @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))

View file

@@ -6834,6 +6834,7 @@ class MultiOutput(ExternKernel):
             self.get_name(),
             self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
         )
+        self.codegen_size_asserts(wrapper)

     def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None:  # type: ignore[no-untyped-def]
         super().__init__(None, layout, [input], ())

View file

@@ -1,4 +1,7 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
+import contextlib
 import dataclasses
 import functools
 import itertools
@@ -6827,3 +6830,23 @@ from . import jagged_lowerings


 jagged_lowerings.register_jagged_ops()
+
+
+@contextlib.contextmanager
+def force_fallback(op: torch._ops.OpOverload):
+    """
+    A context manager that forces an op to use the fallback lowering.
+    Used in unit tests for FallbackKernel.
+    """
+    assert isinstance(
+        op, torch._ops.OpOverload
+    ), "Only OpOverload is supported, to make the cleanup easier"
+    old_handler = lowerings.get(op)
+    try:
+        register_lowering(op)(fallback_handler(op))
+        yield
+    finally:
+        if old_handler:
+            lowerings[op] = old_handler
+        else:
+            lowerings.pop(op)
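
(Since `force_fallback` is built with `contextlib.contextmanager`, it works both as a decorator, as in the new test above, and as a plain `with` block. A minimal sketch, using the same `aten.sort` example as the test:)

    import torch
    from torch._inductor import lowering

    # Temporarily route aten.sort through FallbackKernel so the generated
    # wrapper contains size/stride asserts for its outputs.
    with lowering.force_fallback(torch.ops.aten.sort.default):
        compiled = torch.compile(lambda x: x.sort())
        compiled(torch.randn(16, 32))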

View file

@@ -782,6 +782,13 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
     PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
     return nullptr;
   }
+
+  // We may add the size/stride assert at compile time due to unbacked symint,
+  // but at runtime, the tensor can be empty.
+  if (tensor.numel() == 0) {
+    Py_RETURN_TRUE;
+  }
+
   std::stringstream msg;
   int num_errors = 0;
   for (auto i : c10::irange(ndim)) {
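
(The early return is sound because the strides of a zero-element tensor are unobservable: no kernel can ever dereference them, so a compile-time stride choice made under unbacked symints can legitimately differ at runtime. A small illustration of why a strict stride check would be too strong here:)

    import torch

    # Same size, zero elements, different strides; both are valid and
    # indistinguishable in practice, so asserting strides would spuriously fail.
    a = torch.empty_strided((0, 5), (5, 1))
    b = torch.empty_strided((0, 5), (123, 1))
    assert a.numel() == b.numel() == 0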

View file

@@ -3134,6 +3134,15 @@ class TestCase(expecttest.TestCase):
             # Is the class strict and compiling?
             strict_default = False
             should_reset_dynamo = False
+            # We disable size_asserts for test_ops since some tests fail due
+            # to a mismatch between the strides returned by eager kernels and
+            # those returned by meta kernels. Only some ops have this problem,
+            # but since the tests in test_ops.py are parametrized, it's hard
+            # to disable the asserts specifically for the affected ops.
+            # It's not a big deal since these problems are also caught by
+            # test_torchinductor_opinfo.py.
+            should_disable_size_asserts = False
+
             if compiled:
                 try:
                     path = inspect.getfile(type(test_cls))
@@ -3145,6 +3154,9 @@ class TestCase(expecttest.TestCase):
                     from .dynamo_test_failures import FIXME_inductor_non_strict
                     strict_default = filename not in FIXME_inductor_non_strict
                     should_reset_dynamo = True
+
+                    if filename == "test_ops":
+                        should_disable_size_asserts = True
                 else:
                     strict_default = True
             # inspect.getfile can fail with these
@@ -3177,7 +3189,14 @@ class TestCase(expecttest.TestCase):
                 suppress_errors = not strict_mode
             else:
                 suppress_errors = torch._dynamo.config.suppress_errors
-            with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors):
+
+            maybe_disable_size_asserts = (
+                torch._inductor.config.patch(size_asserts=False)
+                if should_disable_size_asserts
+                else contextlib.nullcontext()
+            )
+
+            with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
                 if TEST_WITH_AOT_EAGER:
                     super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
                 elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
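
(The `config.patch(...) if cond else contextlib.nullcontext()` idiom used here keeps the `with` statement unconditional; the same pattern works anywhere a patch must be applied selectively. A generic sketch, not part of the patch:)

    import contextlib
    import torch

    def call_with_optional_size_asserts(fn, disable_size_asserts):
        # Pick the context object up front so the `with` line stays simple.
        ctx = (
            torch._inductor.config.patch(size_asserts=False)
            if disable_size_asserts
            else contextlib.nullcontext()
        )
        with ctx:
            return fn()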

View file

@@ -5,6 +5,7 @@ import torch
 import re
 import unittest
 import functools
+import contextlib
 import os
 from subprocess import CalledProcessError
 import sys
@@ -140,3 +141,39 @@ IS_H100 = LazyVal(
 )

 IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
+
+
+def maybe_skip_size_asserts(op):
+    """
+    For certain ops, the meta and eager implementations return different
+    strides. This causes the size/stride asserts to fail. Skip adding
+    those asserts for now.
+    """
+    if (
+        op.aten_name
+        in (
+            "fft_hfftn",
+            "fft_hfft",
+            "fft_hfft2",
+            "fft_ihfftn",
+            "fft_fft",
+            "fft_fft2",
+            "fft_fftn",
+            "fft_ifft",
+            "fft_ifft2",
+            "fft_ifftn",
+            "fft_irfft",
+            "fft_irfft2",
+            "fft_irfftn",
+            "fft_ihfft",
+            "fft_ihfft2",
+            "fft_rfft",
+            "fft_rfft2",
+            "fft_rfftn",
+            "linalg_eig",
+            "linalg_eigvals",
+        )
+        and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
+    ):
+        return torch._inductor.config.patch(size_asserts=False)
+    else:
+        return contextlib.nullcontext()
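
(As used in the test_ops.py hunk above, the helper wraps individual op invocations. A minimal sketch; `FakeOp` is a hypothetical stand-in for an OpInfo entry, of which only `aten_name` is consulted:)

    import torch
    from torch.testing._internal.inductor_utils import maybe_skip_size_asserts

    class FakeOp:
        aten_name = "fft_fft"  # one of the ops with eager/meta stride mismatches

        def __call__(self, x):
            return torch.compile(torch.fft.fft)(x)

    op = FakeOp()
    with maybe_skip_size_asserts(op):  # size_asserts patched off for this op
        op(torch.randn(8))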