diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 0e19fadd59c..a71d3468c2b 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -405,6 +405,10 @@ void _assert_async_cpu(const Tensor& self) {
   TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
 }
 
+void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
+  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
+}
+
 // Sorting-based algorithm for isin(); used when the number of test elements is large.
 static void isin_sorting(
     const Tensor& elements,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 14a9576cfbf..b00918b18c4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -170,6 +170,9 @@
     CPU: _assert_async_cpu
     CUDA: _assert_async_cuda
 
+- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
+  dispatch:
+    CPU: _assert_async_msg_cpu
 
 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index e97f8fb1885..c3ac9b4e962 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -2583,6 +2583,31 @@ def forward(self, x):
         ):
             gm, _ = torch._dynamo.export(f, torch.randn(5, 6), aten_graph=True)
 
+    @config.patch(assume_static_by_default=False)
+    def test_export_persist_assert(self):
+        def f(x):
+            assert x.shape[0] > 4, "Shape must be more than 4"
+            return x.cos() + x.sin()
+
+        gm, guard = torch._dynamo.export(
+            f, torch.randn(5, 4, 6), aten_graph=True, tracing_mode="symbolic"
+        )
+
+        def has_aten_op(gm, op):
+            for node in gm.graph.nodes:
+                if node.target == op:
+                    return True
+            return False
+
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
+            gm(torch.randn(3, 4, 5))
+
     def test_access_class_method_from_user_class(self):
         class A:
             @classmethod
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 7222a38a6ed..855cba49ed3 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -2496,7 +2496,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
             exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_not_rewrite_assert_for_other_errors(self):
@@ -2521,7 +2521,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "assertion error"):
             exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_rewrite_assert_with_non_string_msg(self):
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index ba286ac7496..c8c8ac2103a 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -35,6 +35,7 @@ aten::_amp_update_scale
 aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
+aten::_assert_async.msg
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 622a8989858..37b6e0fb025 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -1901,6 +1901,17 @@ class CommonTemplate:
         with self.assertRaisesRegex(RuntimeError, ""):
             fn(torch.randn(1, 5))
 
+    def test_inductor_assert(self):
+        @torch._dynamo.optimize("inductor", dynamic=True)
+        def fn(a):
+            assert a.shape[0] >= 2 and a.shape[1] >= 4
+            return a.cos()
+
+        inp = torch.randn(2, 4, 6)
+        torch._dynamo.mark_dynamic(inp, 0)
+        torch._dynamo.mark_dynamic(inp, 1)
+        self.assertEqual(fn(inp), inp.cos())
+
     def test_split(self):
         def fn(a):
             t = torch.split(a, 3, -1)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index f332cf012e4..7a5ec9b8652 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -55,7 +55,13 @@ from .source import (
     GlobalWeakRefSource,
     LocalSource,
 )
-from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
+from .utils import (
+    counters,
+    get_fake_value,
+    graph_break_dup_warning_checker,
+    istype,
+    proxy_args_kwargs,
+)
 from .variables.base import MutableLocal, typestr, VariableTracker
 from .variables.builder import VariableBuilder, wrap_fx_proxy
 from .variables.builtin import BuiltinVariable
@@ -249,12 +255,35 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
                 self.jump(inst)
                 return
 
-            # Manually insert torch._assert instead of python assert and jump over
+            # TODO maybe should respect DtoH sync intention of users later??
+            # Manually insert torch._assert_async instead of python assert and jump over
             # assert related instructions as we don't need them anymore.
+
+            # if we see Tensor as assert statement, no need to call scalar_tensor
+            if isinstance(value, TensorVariable):
+                self.output.create_proxy(
+                    "call_function",
+                    torch._assert_async,
+                    *proxy_args_kwargs((value, error_msg), {}),
+                )
+                self.jump(inst)
+                return
+
+            scalar_to_tensor_proxy = self.output.create_proxy(
+                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
+            )
+
+            scalar_to_tensor = wrap_fx_proxy(
+                self,
+                scalar_to_tensor_proxy,
+                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
+                **VariableTracker.propagate([value]),
+            )
+
             self.output.create_proxy(
                 "call_function",
-                torch._assert,
-                *proxy_args_kwargs((value, error_msg), {}),
+                torch._assert_async,
+                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
             )
             self.jump(inst)
             return
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 7a7a1a91419..282c49414a5 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -54,6 +54,13 @@ def _unsafe_view(self, size):
     return self.view(size)
 
 
+# TODO: for now, inductor doesn't handle asserts
+# because the condition is symbool -> tensor in the graph.
+@register_decomposition([aten._assert_async.msg])
+def assert_async_msg_decomp(tensor, msg):
+    return
+
+
 @register_decomposition([aten.clamp])
 @pw_cast_for_opmath
 def clamp(x, min=None, max=None):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e8dde26d80d..1708e031181 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -295,6 +295,16 @@ def meta_angle_out(self, out):
     return out.copy_(torch.angle(self))
 
 
+@register_meta(aten._assert_async.default)
+def assert_async(val):
+    return
+
+
+@register_meta(aten._assert_async.msg)
+def assert_async_meta(val, assert_msg):
+    return
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def squareCheckInputs(self: Tensor, f_name: str):
     assert (
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 5db3e3d3753..c11923caf57 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -32,6 +32,8 @@ Argument = Optional[Union[
 
 _side_effectful_functions: Set[Callable] = {
     torch._assert,
+    torch._assert_async,
+    _ops.aten._assert_async.msg,
     _ops.aten.copy_.default,
     _ops.profiler._record_function_enter,
     _ops.profiler._record_function_enter_new,
diff --git a/torch/overrides.py b/torch/overrides.py
index 0313c7da42a..9939ac1da00 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -387,7 +387,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.argmin: lambda input: -1,
         torch.argsort: lambda input, dim=None: -1,
         torch.asin: lambda input, out=None: -1,
-        torch._assert_async: lambda input: -1,
+        torch._assert_async: lambda input, msg: -1,
         torch.arcsin: lambda input, out=None: -1,
         torch.asinh: lambda input, out=None: -1,
         torch.arcsinh: lambda input, out=None: -1,
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index 11211da68b4..8b7429e9826 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -49,6 +49,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
 # All of these operators don't have any tensor like returns
 FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "_assert_async",  # no return
+    "_assert_async.msg",  # no return
     "_dimI",  # returns an int
     "_dimV",  # returns an int
     "_has_same_storage_numel",  # returns a boolean
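
For anyone who wants to poke at the new operator by hand, the snippet below is a usage sketch, not part of the diff. It assumes a PyTorch build that includes the changes above and keeps the condition tensor on CPU, since only a CPU kernel is registered here. A non-zero condition passes silently; a zero condition fails the TORCH_CHECK in _assert_async_msg_cpu, so the custom message surfaces as a RuntimeError, which is what the updated tests assert on.

import torch

# Non-tensor conditions are bridged through torch.scalar_tensor, mirroring the
# rewrite that symbolic_convert.py now performs for plain Python asserts.
cond = torch.scalar_tensor(True)
torch.ops.aten._assert_async.msg(cond, "Shape must be more than 4")  # non-zero: passes

try:
    # zero condition: TORCH_CHECK fails and carries the message
    torch.ops.aten._assert_async.msg(torch.scalar_tensor(False), "Shape must be more than 4")
except RuntimeError as err:
    print(err)  # "Shape must be more than 4"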