Add _assert_scalar and teach Inductor to codegen it (#114148)

Inductor codegen for `_assert_async` is currently disabled because we don't really understand how to codegen `scalar_to_tensor` on a Sympy expression. I initially tried to see if I could get this to work, but I ran into a strange problem involving stride sorting, so I decided to fix it properly by not going through a tensor at all.

So we introduce an `_assert_scalar` which takes a scalar as an argument, avoiding needing to turn a SymBool into a tensor before asserting on it. I also add `_functional_assert_scalar` for good luck, although this doesn't do anything right now because https://github.com/pytorch/pytorch/pull/104203 still hasn't been landed.

I need to customize the codegen for this operator, so I decide to directly implement it in Inductor, rather than trying to treat it as a generic ExternKernel. This leads to the new AssertScalar IR node. This is written carefully so that it doesn't get DCE'd by Inductor.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114148
Approved by: https://github.com/jansel
This commit is contained in:
Edward Z. Yang 2024-01-09 11:09:23 -08:00 committed by PyTorch MergeBot
parent d2033a0639
commit b6028acfa4
12 changed files with 109 additions and 7 deletions

View file

@ -22,6 +22,8 @@
#include <ATen/ops/_aminmax_native.h>
#include <ATen/ops/_assert_async_native.h>
#include <ATen/ops/_functional_assert_async_native.h>
#include <ATen/ops/_assert_scalar_native.h>
#include <ATen/ops/_functional_assert_scalar_native.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/allclose_native.h>
@ -421,6 +423,15 @@ void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
}
// Check that `scalar` (typically a SymBool coming from the symbolic shape
// machinery) is truthy; otherwise raise with `assert_msg`, falling back to a
// generic message when the caller passed an empty one. Unlike _assert_async,
// this never materializes the condition as a tensor.
void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
  TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg.empty() ? "Assertion is failed" : assert_msg);
}
// Functional variant of _assert_scalar: performs the identical check, then
// returns a fresh clone of `dep_token`. NOTE(review): the token in/out
// pattern mirrors _functional_assert_async.msg — presumably it exists so
// functionalization can thread a data dependency through the assert to keep
// it ordered in the graph; confirm against the functionalization pass.
// The assert runs BEFORE the clone, so on failure no token is produced.
Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
  _assert_scalar(scalar, assert_msg);
  return dep_token.clone();
}
Tensor _functional_assert_async_msg_cpu(
const Tensor& self,
c10::string_view assert_msg,

View file

@ -175,6 +175,14 @@
CPU: _assert_async_msg_cpu
CUDA: _assert_async_msg_cuda
- func: _assert_scalar(Scalar self, str assert_msg) -> ()
dispatch:
CompositeExplicitAutograd: _assert_scalar
- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
dispatch:
CompositeExplicitAutograd: _functional_assert_scalar
- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
dispatch:
CPU: _functional_assert_async_msg_cpu

View file

@ -3784,7 +3784,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_deferred_runtime_asserts(self):
@torch.compile(backend="aot_eager", fullgraph=True)
@torch.compile(fullgraph=True)
def f(x):
y = x.item()
torch._check_is_size(y)

View file

@ -334,6 +334,7 @@ aten::_foreach_zero
aten::_foreach_zero.out
aten::_foreach_zero_
aten::_functional_assert_async.msg
aten::_functional_assert_scalar
aten::_functional_sym_constrain_range
aten::_functional_sym_constrain_range_for_size
aten::_fused_adam

View file

@ -1254,7 +1254,10 @@ def forward(self, arg_0):
ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Deferred runtime assertion failed -i0 <= 0"):
with self.assertRaisesRegex(
RuntimeError,
r"(Deferred runtime assertion failed -i0 <= 0|_local_scalar_dense is outside of inline constraint \[0, inf\])"
):
_ = ep(torch.tensor(-1), torch.randn(4, 5))
self.assertTrue(

View file

@ -1336,6 +1336,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_incompatible_cudagraph_ops_item(self):
    # .item() produces aten._local_scalar_dense (and, with deferred runtime
    # asserts, aten._assert_scalar), both of which appear on the cudagraph
    # forbidden-op list — so "reduce-overhead" mode must still produce
    # correct results for this function (presumably by not capturing it in a
    # CUDA graph; confirm against has_incompatible_cudagraph_ops).
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x.item()

    # Two calls with different inputs: guards against a stale value being
    # replayed from the first compiled/captured execution.
    self.assertEqual(foo(torch.tensor(3.0, device="cuda")), 3.0)
    self.assertEqual(foo(torch.tensor(6.0, device="cuda")), 6.0)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_incompatible_cudagraph_ops_nonzero(self):
    # nonzero() has a dynamic output shape (torch.Tag.dynamic_output_shape),
    # which the cudagraph compatibility check rejects — so "reduce-overhead"
    # mode must still return correct results for this function.
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x.nonzero()

    # Two inputs with different numbers of nonzeros: verifies the output
    # shape genuinely tracks the data rather than being replayed.
    self.assertEqual(
        foo(torch.tensor([1, 0, 2], device="cuda")), torch.tensor([[0], [2]])
    )
    self.assertEqual(
        foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
    )
def test_storage_access_error(self):
x = torch.rand([4], device="cuda")
torch._C._set_storage_access_error_msg(x, "custom error msg")

View file

@ -1345,15 +1345,12 @@ class OutputGraph(Checkpointable[OutputGraphState]):
res = sympy_interp(
PythonReferenceAnalysis, symbol_to_proxy, ra.expr
).node
res2 = self.graph.call_function(
torch.ops.aten.scalar_tensor.default, (res,)
)
self.graph.call_function(
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
# TODO: use ra.msg here, but it's pretty
# useless right now
(
res2,
res,
f"Deferred runtime assertion failed {ra.expr}",
),
)

View file

@ -127,6 +127,7 @@ def validate_ir(node_or_nodes):
(
torch._inductor.ir.ExpandView,
DynamicScalar,
AssertScalar,
TensorBox,
sympy.Symbol,
sympy.logic.boolalg.Boolean,
@ -4391,6 +4392,45 @@ class DynamicScalar(ExternKernel):
wrapper.codegen_dynamic_scalar(self)
class AssertScalar(ExternKernel):
    """
    The result of a call to aten._assert_scalar: codegens an inline Python
    ``if not <cond>: raise RuntimeError(msg)`` check on a sympy scalar
    expression, without routing the condition through a tensor.

    The node produces no real value (NoneLayout); it survives scheduling
    because it reports has_side_effects() and is explicitly registered as a
    buffer by the lowering.
    """

    def get_reads(self):
        # The condition is a sympy expression over (unbacked) symbols, not a
        # buffer read — so there are no buffer dependencies to report.
        return ()

    def should_allocate(self):
        # Produces no output storage.
        return False

    def __init__(self, scalar, msg):
        super().__init__(
            # Buffer(name, layout)
            None,
            NoneLayout(torch.device("cpu")),  # type: ignore[arg-type]
            # InputsKernel(inputs)
            [],
        )  # type: ignore[arg-type]
        # scalar: sympy expression to assert on; msg: error text on failure.
        self.scalar = scalar
        self.msg = msg

    def has_side_effects(self):
        # Raising on failure is the whole point — prevents DCE.
        return True

    def get_unbacked_symbol_uses(self):
        # Keeps the defining _local_scalar_dense (etc.) alive: the assert
        # uses whatever unbacked symbols appear in the expression.
        return free_unbacked_symbols(self.scalar)

    def codegen(self, wrapper):
        # Only the Python wrapper can emit this check today.
        assert not V.graph.cpp_wrapper, "NYI"
        wrapper.writeline(
            f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:"
        )
        wrapper.writeline(f"    raise RuntimeError({repr(self.msg)})")
        # No one should ever use this buffer, but for uniformity
        # define the variable and assign it None
        wrapper.writeline(f"{self.get_name()} = None")
@dataclasses.dataclass
class ExternKernelNode:
name: str

View file

@ -2563,6 +2563,14 @@ def _local_scalar_dense(data):
return sym
@register_lowering(aten._assert_scalar)
def _assert_scalar(data, msg):
    # Lower to a dedicated IR node. Nothing downstream consumes its (None)
    # result, so the buffer must be registered explicitly — otherwise the
    # scheduler would treat it as dead and eliminate the assertion.
    assert_node = ir.AssertScalar(data, msg)
    assert_node.name = V.graph.register_buffer(assert_node)
    return assert_node
def _full(fill_value, device, dtype, size):
value = fill_value
if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):

View file

@ -586,6 +586,11 @@ def has_incompatible_cudagraph_ops(gm):
"run_and_save_rng_state",
"run_with_rng_state",
"aten._local_scalar_dense",
# Technically, it's not necessary to ban this, because an
# assert_scalar with constant arguments can be validly run
# with CUDA graphs, but the operator is also pointless with
# constant arguments, so might as well ban
"aten._assert_scalar",
}
if torch.are_deterministic_algorithms_enabled():
forbidden_set.update(
@ -606,6 +611,11 @@ def has_incompatible_cudagraph_ops(gm):
for node in gm.graph.nodes:
if str(node.target) in forbidden_set:
return True
if hasattr(node.target, "tags"):
if torch.Tag.dynamic_output_shape in node.target.tags:
return True
if torch.Tag.data_dependent_output in node.target.tags:
return True
return False

View file

@ -41,6 +41,7 @@ _side_effectful_functions: Set[Callable] = {
torch._assert,
torch._assert_async,
_ops.aten._assert_async.msg,
_ops.aten._assert_scalar.default,
_ops.aten.copy_.default,
_ops.aten.sym_constrain_range.default,
_ops.aten.sym_constrain_range_for_size.default,

View file

@ -51,6 +51,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_assert_async", # no return
"_assert_async.msg", # no return
"_cslt_sparse_mm_search", # returns an int
"_assert_scalar", # no return
"_dimI", # returns an int
"_dimV", # returns an int
"_has_same_storage_numel", # returns a boolean