From b6028acfa46363c1d3262a1522741a06c307843f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 9 Jan 2024 11:09:23 -0800 Subject: [PATCH] Add _assert_scalar and teach Inductor to codegen it (#114148) Inductor codegen for `_assert_async` is currently disabled because we don't really understand how to codegen `scalar_to_tensor` on a Sympy expression. I initially tried to see if I could get this to work, but I got into some weird problem involving stride sorting, so I decided to fix it properly by not going through a tensor. So we introduce an `_assert_scalar` which takes a scalar as an argument, avoiding the need to turn a SymBool into a tensor before asserting on it. I also add `_functional_assert_scalar` for good luck, although this doesn't do anything right now because https://github.com/pytorch/pytorch/pull/104203 still hasn't been landed. I need to customize the codegen for this operator, so I decided to implement it directly in Inductor, rather than trying to treat it as a generic ExternKernel. This leads to the new AssertScalar IR node. This is written carefully so that it doesn't get DCE'd by Inductor. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/114148 Approved by: https://github.com/jansel --- aten/src/ATen/native/TensorCompare.cpp | 11 +++++ aten/src/ATen/native/native_functions.yaml | 8 ++++ test/dynamo/test_repros.py | 2 +- ...asDecompTest.test_has_decomposition.expect | 1 + test/export/test_export.py | 5 ++- test/inductor/test_cudagraph_trees.py | 22 ++++++++++ torch/_dynamo/output_graph.py | 7 +--- torch/_inductor/ir.py | 40 +++++++++++++++++++ torch/_inductor/lowering.py | 8 ++++ torch/_inductor/utils.py | 10 +++++ torch/fx/node.py | 1 + torchgen/native_function_generation.py | 1 + 12 files changed, 109 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 21d8b89707a..f7a2d0f7668 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include #include #include @@ -421,6 +423,15 @@ void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) { TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed"); } +void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) { + TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? 
assert_msg : "Assertion is failed"); +} + +Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) { + _assert_scalar(scalar, assert_msg); + return dep_token.clone(); +} + Tensor _functional_assert_async_msg_cpu( const Tensor& self, c10::string_view assert_msg, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 76d21edf9a5..4960417abdb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -175,6 +175,14 @@ CPU: _assert_async_msg_cpu CUDA: _assert_async_msg_cuda +- func: _assert_scalar(Scalar self, str assert_msg) -> () + dispatch: + CompositeExplicitAutograd: _assert_scalar + +- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor + dispatch: + CompositeExplicitAutograd: _functional_assert_scalar + - func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor dispatch: CPU: _functional_assert_async_msg_cpu diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 9ae10bfa0b8..4445ba174b2 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -3784,7 +3784,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_deferred_runtime_asserts(self): - @torch.compile(backend="aot_eager", fullgraph=True) + @torch.compile(fullgraph=True) def f(x): y = x.item() torch._check_is_size(y) diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 7432b8c52a0..863d9cf005d 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -334,6 +334,7 @@ aten::_foreach_zero aten::_foreach_zero.out aten::_foreach_zero_ aten::_functional_assert_async.msg +aten::_functional_assert_scalar 
aten::_functional_sym_constrain_range aten::_functional_sym_constrain_range_for_size aten::_fused_adam diff --git a/test/export/test_export.py b/test/export/test_export.py index 46da2135c6a..b03a1b2401f 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1254,7 +1254,10 @@ def forward(self, arg_0): ep = export(M(), (torch.tensor(1), torch.ones(4, 5))) - with self.assertRaisesRegex(RuntimeError, r"Deferred runtime assertion failed -i0 <= 0"): + with self.assertRaisesRegex( + RuntimeError, + r"(Deferred runtime assertion failed -i0 <= 0|_local_scalar_dense is outside of inline constraint \[0, inf\])" + ): _ = ep(torch.tensor(-1), torch.randn(4, 5)) self.assertTrue( diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index f73770dc5e4..24f6e712238 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -1336,6 +1336,28 @@ if HAS_CUDA and not TEST_WITH_ASAN: out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) self.assertFalse(self.get_manager().new_graph_id().id == 0) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_incompatible_cudagraph_ops_item(self): + @torch.compile(mode="reduce-overhead") + def foo(x): + return x.item() + + self.assertEqual(foo(torch.tensor(3.0, device="cuda")), 3.0) + self.assertEqual(foo(torch.tensor(6.0, device="cuda")), 6.0) + + @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) + def test_incompatible_cudagraph_ops_nonzero(self): + @torch.compile(mode="reduce-overhead") + def foo(x): + return x.nonzero() + + self.assertEqual( + foo(torch.tensor([1, 0, 2], device="cuda")), torch.tensor([[0], [2]]) + ) + self.assertEqual( + foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]]) + ) + def test_storage_access_error(self): x = torch.rand([4], device="cuda") torch._C._set_storage_access_error_msg(x, "custom error msg") diff --git a/torch/_dynamo/output_graph.py 
b/torch/_dynamo/output_graph.py index ef2045845ab..bb272903388 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1345,15 +1345,12 @@ class OutputGraph(Checkpointable[OutputGraphState]): res = sympy_interp( PythonReferenceAnalysis, symbol_to_proxy, ra.expr ).node - res2 = self.graph.call_function( - torch.ops.aten.scalar_tensor.default, (res,) - ) self.graph.call_function( - torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, # TODO: use ra.msg here, but it's pretty # useless right now ( - res2, + res, f"Deferred runtime assertion failed {ra.expr}", ), ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ca3d2663f85..5db8a0bec1d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -127,6 +127,7 @@ def validate_ir(node_or_nodes): ( torch._inductor.ir.ExpandView, DynamicScalar, + AssertScalar, TensorBox, sympy.Symbol, sympy.logic.boolalg.Boolean, @@ -4391,6 +4392,45 @@ class DynamicScalar(ExternKernel): wrapper.codegen_dynamic_scalar(self) +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, scalar, msg): + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + # InputsKernel(inputs) + [], + ) # type: ignore[arg-type] + self.scalar = scalar + self.msg = msg + + def has_side_effects(self): + return True + + def get_unbacked_symbol_uses(self): + return free_unbacked_symbols(self.scalar) + + def codegen(self, wrapper): + assert not V.graph.cpp_wrapper, "NYI" + wrapper.writeline( + f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:" + ) + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + @dataclasses.dataclass 
class ExternKernelNode: name: str diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 35c57df5c81..36341720bd9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2563,6 +2563,14 @@ def _local_scalar_dense(data): return sym +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + buffer = ir.AssertScalar(data, msg) + # This buffer isn't used by anyone (it returns None), so we must explicitly register it + buffer.name = V.graph.register_buffer(buffer) + return buffer + + def _full(fill_value, device, dtype, size): value = fill_value if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 6b71181001a..973bca92e04 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -586,6 +586,11 @@ def has_incompatible_cudagraph_ops(gm): "run_and_save_rng_state", "run_with_rng_state", "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", } if torch.are_deterministic_algorithms_enabled(): forbidden_set.update( @@ -606,6 +611,11 @@ def has_incompatible_cudagraph_ops(gm): for node in gm.graph.nodes: if str(node.target) in forbidden_set: return True + if hasattr(node.target, "tags"): + if torch.Tag.dynamic_output_shape in node.target.tags: + return True + if torch.Tag.data_dependent_output in node.target.tags: + return True return False diff --git a/torch/fx/node.py b/torch/fx/node.py index 0c7b8fe9865..616b3888f89 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -41,6 +41,7 @@ _side_effectful_functions: Set[Callable] = { torch._assert, torch._assert_async, _ops.aten._assert_async.msg, + _ops.aten._assert_scalar.default, _ops.aten.copy_.default, _ops.aten.sym_constrain_range.default, 
_ops.aten.sym_constrain_range_for_size.default, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 2a276115b79..79e20fa08d2 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -51,6 +51,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ "_assert_async", # no return "_assert_async.msg", # no return "_cslt_sparse_mm_search", # returns an int + "_assert_scalar", # no return "_dimI", # returns an int "_dimV", # returns an int "_has_same_storage_numel", # returns a boolean