mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add _assert_scalar and teach Inductor to codegen it (#114148)
Inductor codegen for `_assert_async` is currently disabled because we don't really understand how to codegen `scalar_to_tensor` on a Sympy expression. I initially tried to see if I could get this to work, but I got into some weird problem involving stride sorting, so I decided to fix it properly by not going through a tensor. So we introduce an `_assert_scalar` which takes a scalar as an argument, avoiding needing to turn a SymBool into a tensor before asserting on it. I also add `_functional_assert_scalar` for good luck, although this doesn't do anything right now because https://github.com/pytorch/pytorch/pull/104203 still hasn't been landed. I need to customize the codegen for this operator, so I decide to directly implement it in Inductor, rather than trying to treat it as a generic ExternKernel. This leads to the new AssertScalar IR node. This is written carefully so that it doesn't get DCE'd by Inductor. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/114148 Approved by: https://github.com/jansel
This commit is contained in:
parent
d2033a0639
commit
b6028acfa4
12 changed files with 109 additions and 7 deletions
|
|
@ -22,6 +22,8 @@
|
|||
#include <ATen/ops/_aminmax_native.h>
|
||||
#include <ATen/ops/_assert_async_native.h>
|
||||
#include <ATen/ops/_functional_assert_async_native.h>
|
||||
#include <ATen/ops/_assert_scalar_native.h>
|
||||
#include <ATen/ops/_functional_assert_scalar_native.h>
|
||||
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
||||
#include <ATen/ops/_unique.h>
|
||||
#include <ATen/ops/allclose_native.h>
|
||||
|
|
@ -421,6 +423,15 @@ void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
|
|||
TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
|
||||
}
|
||||
|
||||
void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
|
||||
TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? assert_msg : "Assertion is failed");
|
||||
}
|
||||
|
||||
Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
|
||||
_assert_scalar(scalar, assert_msg);
|
||||
return dep_token.clone();
|
||||
}
|
||||
|
||||
Tensor _functional_assert_async_msg_cpu(
|
||||
const Tensor& self,
|
||||
c10::string_view assert_msg,
|
||||
|
|
|
|||
|
|
@ -175,6 +175,14 @@
|
|||
CPU: _assert_async_msg_cpu
|
||||
CUDA: _assert_async_msg_cuda
|
||||
|
||||
- func: _assert_scalar(Scalar self, str assert_msg) -> ()
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _assert_scalar
|
||||
|
||||
- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _functional_assert_scalar
|
||||
|
||||
- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
|
||||
dispatch:
|
||||
CPU: _functional_assert_async_msg_cpu
|
||||
|
|
|
|||
|
|
@ -3784,7 +3784,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_deferred_runtime_asserts(self):
|
||||
@torch.compile(backend="aot_eager", fullgraph=True)
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
y = x.item()
|
||||
torch._check_is_size(y)
|
||||
|
|
|
|||
|
|
@ -334,6 +334,7 @@ aten::_foreach_zero
|
|||
aten::_foreach_zero.out
|
||||
aten::_foreach_zero_
|
||||
aten::_functional_assert_async.msg
|
||||
aten::_functional_assert_scalar
|
||||
aten::_functional_sym_constrain_range
|
||||
aten::_functional_sym_constrain_range_for_size
|
||||
aten::_fused_adam
|
||||
|
|
|
|||
|
|
@ -1254,7 +1254,10 @@ def forward(self, arg_0):
|
|||
|
||||
ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Deferred runtime assertion failed -i0 <= 0"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"(Deferred runtime assertion failed -i0 <= 0|_local_scalar_dense is outside of inline constraint \[0, inf\])"
|
||||
):
|
||||
_ = ep(torch.tensor(-1), torch.randn(4, 5))
|
||||
|
||||
self.assertTrue(
|
||||
|
|
|
|||
|
|
@ -1336,6 +1336,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
|
||||
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_incompatible_cudagraph_ops_item(self):
|
||||
@torch.compile(mode="reduce-overhead")
|
||||
def foo(x):
|
||||
return x.item()
|
||||
|
||||
self.assertEqual(foo(torch.tensor(3.0, device="cuda")), 3.0)
|
||||
self.assertEqual(foo(torch.tensor(6.0, device="cuda")), 6.0)
|
||||
|
||||
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
||||
def test_incompatible_cudagraph_ops_nonzero(self):
|
||||
@torch.compile(mode="reduce-overhead")
|
||||
def foo(x):
|
||||
return x.nonzero()
|
||||
|
||||
self.assertEqual(
|
||||
foo(torch.tensor([1, 0, 2], device="cuda")), torch.tensor([[0], [2]])
|
||||
)
|
||||
self.assertEqual(
|
||||
foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
|
||||
)
|
||||
|
||||
def test_storage_access_error(self):
|
||||
x = torch.rand([4], device="cuda")
|
||||
torch._C._set_storage_access_error_msg(x, "custom error msg")
|
||||
|
|
|
|||
|
|
@ -1345,15 +1345,12 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
res = sympy_interp(
|
||||
PythonReferenceAnalysis, symbol_to_proxy, ra.expr
|
||||
).node
|
||||
res2 = self.graph.call_function(
|
||||
torch.ops.aten.scalar_tensor.default, (res,)
|
||||
)
|
||||
self.graph.call_function(
|
||||
torch.ops.aten._assert_async.msg,
|
||||
torch.ops.aten._assert_scalar.default,
|
||||
# TODO: use ra.msg here, but it's pretty
|
||||
# useless right now
|
||||
(
|
||||
res2,
|
||||
res,
|
||||
f"Deferred runtime assertion failed {ra.expr}",
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ def validate_ir(node_or_nodes):
|
|||
(
|
||||
torch._inductor.ir.ExpandView,
|
||||
DynamicScalar,
|
||||
AssertScalar,
|
||||
TensorBox,
|
||||
sympy.Symbol,
|
||||
sympy.logic.boolalg.Boolean,
|
||||
|
|
@ -4391,6 +4392,45 @@ class DynamicScalar(ExternKernel):
|
|||
wrapper.codegen_dynamic_scalar(self)
|
||||
|
||||
|
||||
class AssertScalar(ExternKernel):
|
||||
"""
|
||||
The result of a call to aten._assert_scalar
|
||||
"""
|
||||
|
||||
def get_reads(self):
|
||||
return ()
|
||||
|
||||
def should_allocate(self):
|
||||
return False
|
||||
|
||||
def __init__(self, scalar, msg):
|
||||
super().__init__(
|
||||
# Buffer(name, layotu)
|
||||
None,
|
||||
NoneLayout(torch.device("cpu")), # type: ignore[arg-type]
|
||||
# InputsKernel(inputs)
|
||||
[],
|
||||
) # type: ignore[arg-type]
|
||||
self.scalar = scalar
|
||||
self.msg = msg
|
||||
|
||||
def has_side_effects(self):
|
||||
return True
|
||||
|
||||
def get_unbacked_symbol_uses(self):
|
||||
return free_unbacked_symbols(self.scalar)
|
||||
|
||||
def codegen(self, wrapper):
|
||||
assert not V.graph.cpp_wrapper, "NYI"
|
||||
wrapper.writeline(
|
||||
f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:"
|
||||
)
|
||||
wrapper.writeline(f" raise RuntimeError({repr(self.msg)})")
|
||||
# No one should ever use this buffer, but for uniformity
|
||||
# define the variable and assign it None
|
||||
wrapper.writeline(f"{self.get_name()} = None")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExternKernelNode:
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -2563,6 +2563,14 @@ def _local_scalar_dense(data):
|
|||
return sym
|
||||
|
||||
|
||||
@register_lowering(aten._assert_scalar)
|
||||
def _assert_scalar(data, msg):
|
||||
buffer = ir.AssertScalar(data, msg)
|
||||
# This buffer isn't used by anyone (it returns None), so we must explicitly register it
|
||||
buffer.name = V.graph.register_buffer(buffer)
|
||||
return buffer
|
||||
|
||||
|
||||
def _full(fill_value, device, dtype, size):
|
||||
value = fill_value
|
||||
if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
|
||||
|
|
|
|||
|
|
@ -586,6 +586,11 @@ def has_incompatible_cudagraph_ops(gm):
|
|||
"run_and_save_rng_state",
|
||||
"run_with_rng_state",
|
||||
"aten._local_scalar_dense",
|
||||
# Technically, it's not necessary to ban this, because an
|
||||
# assert_scalar with constant arguments can be validly run
|
||||
# with CUDA graphs, but the operator is also pointless with
|
||||
# constant arguments, so might as well ban
|
||||
"aten._assert_scalar",
|
||||
}
|
||||
if torch.are_deterministic_algorithms_enabled():
|
||||
forbidden_set.update(
|
||||
|
|
@ -606,6 +611,11 @@ def has_incompatible_cudagraph_ops(gm):
|
|||
for node in gm.graph.nodes:
|
||||
if str(node.target) in forbidden_set:
|
||||
return True
|
||||
if hasattr(node.target, "tags"):
|
||||
if torch.Tag.dynamic_output_shape in node.target.tags:
|
||||
return True
|
||||
if torch.Tag.data_dependent_output in node.target.tags:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ _side_effectful_functions: Set[Callable] = {
|
|||
torch._assert,
|
||||
torch._assert_async,
|
||||
_ops.aten._assert_async.msg,
|
||||
_ops.aten._assert_scalar.default,
|
||||
_ops.aten.copy_.default,
|
||||
_ops.aten.sym_constrain_range.default,
|
||||
_ops.aten.sym_constrain_range_for_size.default,
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
|
|||
"_assert_async", # no return
|
||||
"_assert_async.msg", # no return
|
||||
"_cslt_sparse_mm_search", # returns an int
|
||||
"_assert_scalar", # no return
|
||||
"_dimI", # returns an int
|
||||
"_dimV", # returns an int
|
||||
"_has_same_storage_numel", # returns a boolean
|
||||
|
|
|
|||
Loading…
Reference in a new issue