mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] add size-asserts for fallback ops (#145904)
Fix https://github.com/pytorch/pytorch/issues/144717 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145904 Approved by: https://github.com/jansel
This commit is contained in:
parent
b60f630de8
commit
bc0191802f
9 changed files with 141 additions and 5 deletions
|
|
@ -419,6 +419,7 @@ class CondTests(TestCase):
|
|||
|
||||
@requires_gpu
|
||||
@parametrize("device", ["cpu", GPU_TYPE])
|
||||
@torch._inductor.config.patch(size_asserts=False)
|
||||
def test_cond_unbacked_symint_inner(self, device):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, p, a):
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from torch._dynamo.testing import (
|
|||
)
|
||||
from torch._dynamo.utils import ifdynstaticdefault
|
||||
from torch._guards import CompileContext, CompileId
|
||||
from torch._inductor import lowering
|
||||
from torch._inductor.aoti_eager import (
|
||||
aoti_compile_with_persistent_cache,
|
||||
aoti_eager_cache_dir,
|
||||
|
|
@ -903,6 +904,11 @@ class skip_if_cpp_wrapper:
|
|||
return wrapper
|
||||
|
||||
|
||||
def is_dynamic_shape_enabled():
|
||||
# What's the best way to decide this?
|
||||
return not torch._dynamo.config.assume_static_by_default
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class CommonTemplate:
|
||||
def is_dtype_supported(self, dtype: torch.dtype) -> bool:
|
||||
|
|
@ -5844,6 +5850,8 @@ class CommonTemplate:
|
|||
(torch.randn([8, 16, 8, 8]),),
|
||||
)
|
||||
|
||||
# Disable size_asserts for this test due to https://github.com/pytorch/pytorch/issues/145963
|
||||
@config.patch(size_asserts=os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS") == "1")
|
||||
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
|
||||
def test_nonzero_unbacked_refinement(self):
|
||||
def fn(x):
|
||||
|
|
@ -12430,11 +12438,41 @@ class CommonTemplate:
|
|||
reset_rng_state()
|
||||
ref = f(image_latent)
|
||||
opt_f = torch.compile(f)
|
||||
|
||||
code = run_and_get_triton_code(opt_f, image_latent)
|
||||
reset_rng_state()
|
||||
act = opt_f(image_latent)
|
||||
|
||||
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
|
||||
|
||||
if is_dynamic_shape_enabled():
|
||||
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950
|
||||
else:
|
||||
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
|
||||
FileCheck().check_regex(size_assert_pattern).run(code)
|
||||
|
||||
@lowering.force_fallback(aten.sort.default)
@unittest.skipIf(
    config.cpp_wrapper,
    "Inductor does not generate size/stride asserts for cpp_wrapper",
)
def test_size_asserts_for_multi_output_fallback(self):
    # Force aten.sort through the FallbackKernel path so we can verify
    # that the generated wrapper emits a size/stride assert for *every*
    # output buffer of a multi-output fallback op.
    @torch.compile
    def f(x):
        return x.sort()

    x = torch.randn(16, 32, device=self.device)
    code = run_and_get_triton_code(f, x)

    # With dynamic shapes the asserts are expressed in symbolic sizes;
    # otherwise they use the concrete 16x32 sizes/strides.
    if is_dynamic_shape_enabled():
        expected = [
            "assert_size_stride(buf1, (s0, s1), (s1, 1))",
            "assert_size_stride(buf2, (s0, s1), (s1, 1))",
        ]
    else:
        expected = [
            "assert_size_stride(buf1, (16, 32), (32, 1))",
            "assert_size_stride(buf2, (16, 32), (32, 1))",
        ]
    checker = FileCheck()
    for pattern in expected:
        checker = checker.check(pattern)
    checker.run(code)
|
||||
|
||||
@requires_cuda
|
||||
@config.patch(use_fast_math=True)
|
||||
def test_prepare_softmax_with_fast_math(self):
|
||||
|
|
|
|||
|
|
@ -41,7 +41,13 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_ASAN,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_CPU,
|
||||
HAS_CUDA,
|
||||
HAS_XPU,
|
||||
maybe_skip_size_asserts,
|
||||
)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
|
@ -1103,7 +1109,9 @@ class TestInductorOpInfo(TestCase):
|
|||
{"assert_equal": False},
|
||||
),
|
||||
)
|
||||
return ((contextlib.nullcontext, {}),)
|
||||
|
||||
ctx = functools.partial(maybe_skip_size_asserts, op)
|
||||
return ((ctx, {}),)
|
||||
|
||||
try:
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ from torch.testing._internal.common_utils import (
|
|||
TestCase,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import maybe_skip_size_asserts
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
|
@ -1149,7 +1150,8 @@ class TestCommon(TestCase):
|
|||
sample = first_sample(self, op.sample_inputs(device, dtype))
|
||||
|
||||
# Call op to get prototype for out arguments
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
with maybe_skip_size_asserts(op):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
any_requires_grad = False
|
||||
|
||||
def set_requires_grad(x):
|
||||
|
|
@ -1170,7 +1172,7 @@ class TestCommon(TestCase):
|
|||
"functions with out=... arguments don't support automatic "
|
||||
"differentiation, but one of the arguments requires grad."
|
||||
)
|
||||
with self.assertRaises(RuntimeError, msg=msg):
|
||||
with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op):
|
||||
op(sample.input, *sample.args, **sample.kwargs, out=out)
|
||||
|
||||
@ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
|
||||
|
|
|
|||
|
|
@ -6834,6 +6834,7 @@ class MultiOutput(ExternKernel):
|
|||
self.get_name(),
|
||||
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
|
||||
)
|
||||
self.codegen_size_asserts(wrapper)
|
||||
|
||||
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(None, layout, [input], ())
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
|
|
@ -6827,3 +6830,23 @@ from . import jagged_lowerings
|
|||
|
||||
|
||||
jagged_lowerings.register_jagged_ops()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def force_fallback(op: torch._ops.OpOverload):
    """
    A context manager to force fallback an op. Used in unit test
    for FallbackKernel.

    Temporarily replaces the registered lowering for ``op`` with a
    fallback handler; on exit the previous lowering (if any) is
    restored, otherwise the temporary registration is removed.
    """
    assert isinstance(
        op, torch._ops.OpOverload
    ), "Only OpOverload to make the clean up easier"
    old_handler = lowerings.get(op)
    try:
        register_lowering(op)(fallback_handler(op))
        yield
    finally:
        # Use an explicit None check: None is the sentinel for "no
        # previous lowering"; a falsy-but-valid handler must still be
        # restored rather than popped.
        if old_handler is not None:
            lowerings[op] = old_handler
        else:
            # pop with a default so that, if register_lowering failed
            # before inserting the fallback, we don't mask the original
            # exception with a KeyError here.
            lowerings.pop(op, None)
|
||||
|
|
|
|||
|
|
@ -782,6 +782,13 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
|
|||
PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// We may add the size/stride assert at compile time due to unbacked symint,
|
||||
// but at runtime, the tensor can be empty.
|
||||
if (tensor.numel() == 0) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
std::stringstream msg;
|
||||
int num_errors = 0;
|
||||
for (auto i : c10::irange(ndim)) {
|
||||
|
|
|
|||
|
|
@ -3134,6 +3134,15 @@ class TestCase(expecttest.TestCase):
|
|||
# Is the class strict and compiling?
|
||||
strict_default = False
|
||||
should_reset_dynamo = False
|
||||
|
||||
# We disable size_asserts for test_ops since some tests fail
|
||||
# due to mismatch of strides returned from eager v.s. meta kernels
|
||||
# Only some of the ops has this problem, but since tests in
|
||||
# test_op.py are parametrized, it's hard to do this specifically
|
||||
# for the affected ops.
|
||||
# It's not a big deal since these problems are captured by
|
||||
# test_torchinductor_opinfo.py as well.
|
||||
should_disable_size_asserts = False
|
||||
if compiled:
|
||||
try:
|
||||
path = inspect.getfile(type(test_cls))
|
||||
|
|
@ -3145,6 +3154,9 @@ class TestCase(expecttest.TestCase):
|
|||
from .dynamo_test_failures import FIXME_inductor_non_strict
|
||||
strict_default = filename not in FIXME_inductor_non_strict
|
||||
should_reset_dynamo = True
|
||||
|
||||
if filename == "test_ops":
|
||||
should_disable_size_asserts = True
|
||||
else:
|
||||
strict_default = True
|
||||
# inspect.getfile can fail with these
|
||||
|
|
@ -3177,7 +3189,14 @@ class TestCase(expecttest.TestCase):
|
|||
suppress_errors = not strict_mode
|
||||
else:
|
||||
suppress_errors = torch._dynamo.config.suppress_errors
|
||||
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors):
|
||||
|
||||
maybe_disable_size_asserts = (
|
||||
torch._inductor.config.patch(size_asserts=False)
|
||||
if should_disable_size_asserts
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
|
||||
if TEST_WITH_AOT_EAGER:
|
||||
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
|
||||
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import torch
|
|||
import re
|
||||
import unittest
|
||||
import functools
|
||||
import contextlib
|
||||
import os
|
||||
from subprocess import CalledProcessError
|
||||
import sys
|
||||
|
|
@ -140,3 +141,39 @@ IS_H100 = LazyVal(
|
|||
)
|
||||
|
||||
IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
|
||||
|
||||
def maybe_skip_size_asserts(op):
|
||||
"""
|
||||
For certain ops, there meta and eager implementation returns differents
|
||||
strides. This cause size/strides assert fail. Skip adding those
|
||||
asserts for now.
|
||||
"""
|
||||
if (
|
||||
op.aten_name
|
||||
in (
|
||||
"fft_hfftn",
|
||||
"fft_hfft",
|
||||
"fft_hfft2",
|
||||
"fft_ihfftn",
|
||||
"fft_fft",
|
||||
"fft_fft2",
|
||||
"fft_fftn",
|
||||
"fft_ifft",
|
||||
"fft_ifft2",
|
||||
"fft_ifftn",
|
||||
"fft_irfft",
|
||||
"fft_irfft2",
|
||||
"fft_irfftn",
|
||||
"fft_ihfft",
|
||||
"fft_ihfft2",
|
||||
"fft_rfft",
|
||||
"fft_rfft2",
|
||||
"fft_rfftn",
|
||||
"linalg_eig",
|
||||
"linalg_eigvals",
|
||||
)
|
||||
and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
|
||||
):
|
||||
return torch._inductor.config.patch(size_asserts=False)
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
|
|
|||
Loading…
Reference in a new issue