From 53ab82d8f5c1f3f12b956cdd8745bf17262ef97e Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:21 -0300 Subject: [PATCH] Implement `generator.throw(exception)` (#144424) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144424 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423 --- test/dynamo/test_generator.py | 394 ++++++++++++++++++++++++++- torch/_dynamo/exc.py | 8 + torch/_dynamo/variables/functions.py | 100 +++++++ 3 files changed, 501 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 8aa86eb4bb3..f0e38d5f4a6 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -7,7 +7,7 @@ from collections import OrderedDict import torch import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.exc import Unsupported +from torch._dynamo.exc import InternalTorchDynamoError, Unsupported from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -977,6 +977,255 @@ class TestGeneratorClose(GeneratorTestsBase): self.assertEqual(z, 2) +class TestGeneratorThrow(GeneratorTestsBase): + def test_throw(self): + def whoo(t): + try: + yield t.sin() + except RuntimeError: + yield t.cos() + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_with_finally(self): + z = 0 + + def whoo(): + nonlocal z + z = 0 + try: + try: + yield 1 + except ValueError: + yield 2 + finally: + z += 2 + except ValueError: + z += 33 + yield 4 + finally: + z += 1 + z += 10 + + def f(x): + gen = whoo() + next(gen) + gen.throw(ValueError) + return x.sin() + + self._compile_check(f) + self.assertEqual(z, 3) + + def test_throw_without_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + z += 10 + except RuntimeError: + z += 100 + yield t.cos() + z += 1_000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + self.assertEqual(z, 101) + + def test_throw_three_arguments(self): + def whoo(t): + try: + yield t.sin() + except ValueError: + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError, "Error", None) + return a + b + + t = torch.randn(2) + with self.assertRaises(InternalTorchDynamoError): + fn(t) + + def test_throw_no_yield_after_throw(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + finally: + z += 100 + + def fn(t): + gen = whoo(t) + a = next(gen) + try: + gen.throw(ValueError) + except StopIteration: + return a + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(z, 111) + self.assertEqual(y, t.sin()) + + def test_throw_not_catch(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_raise_difference_exc(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError as e: + z += 10 + raise RuntimeError from e + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError) + return a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_yield_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except RuntimeError: + z += 10 + yield t.cos() + finally: + z += 100 + yield t.tan() # RuntimeError: generator ignored GeneratorExit + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(Unsupported): + fn(t) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_try_except_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + except RuntimeError: + z += 100 + yield t.tan() + finally: + z += 1000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.tan()) + self.assertEqual(z, 1 + 100 + 1000) + + def test_exception_context_with_yield(self): + def f(): + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError: + z = 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + return t.sin() + + self._compile_check(fn) + + class GeneratorCloseCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 @@ -1118,6 +1367,149 @@ class GeneratorCloseCPythonTests(GeneratorTestsBase): fn(t) +class GeneratorThrowCpythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + @unittest.expectedFailure + def test_exception_context_with_yield(self): + def f(): + try: + raise KeyError("a") + except Exception: + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_inside_generator(self): + # Check that the context is also available from inside the generator + # with yield, as opposed to outside. + def f(): + z = 0 + try: + raise KeyError("a") + except Exception: + try: + yield + except Exception as exc: + z = 1 + assert type(exc) == ValueError + context = exc.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + yield "b" + finally: + assert z == 1 + + def fn(t): + gen = f() + gen.send(None) + actual = gen.throw(ValueError) + # This ensures that the assertions inside were executed. + assert actual == "b" + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_from(self): + def f(): + yield + + def g(): + try: + raise KeyError("a") + except Exception: + yield from f() + + def fn(t): + gen = g() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") + @unittest.expectedFailure + def test_exception_context_with_yield_from_with_context_cycle(self): + # Check trying to create an exception context cycle: + # https://bugs.python.org/issue40696 + has_cycle = None + + def f(): + yield + + def g(exc): + nonlocal has_cycle + try: + raise exc + except Exception: + try: + yield from f() + except Exception as exc: + has_cycle = exc is exc.__context__ + yield + + def fn(t): + exc = KeyError("a") + gen = g(exc) + gen.send(None) + gen.throw(exc) + # This also distinguishes from the initial has_cycle=None. + assert has_cycle is False + return t.sin() + + self._compile_check(fn) + + def test_throw_after_none_exc_type(self): + def g(): + try: + raise KeyError + except KeyError: + pass + + try: + yield + except Exception: + raise RuntimeError # noqa: B904 + + def fn(t): + gen = g() + gen.send(None) + z = 0 + try: + gen.throw(ValueError) + except RuntimeError: + z += 1 + except Exception: + raise AssertionError # noqa: B904 + assert z == 1 + return t.sin() + + self._compile_check(fn) + + class GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index a2f5f938bdc..aaa51f0a53c 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -309,6 +309,14 @@ observed_exception_map = { } +def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]: + if exc_type not in observed_exception_map: + observed_exception_map[exc_type] = type( + f"Observed{exc_type.__name__}Error", (ObservedException,), {} + ) + return observed_exception_map[exc_type] + + def raise_observed_exception( exc_type: type[Exception], tx: InstructionTranslatorBase, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 293ba87a516..d51be98f69f 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -16,7 +16,9 @@ import torch from .. import polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( + get_dynamo_observed_exception, handle_observed_exception, + IncorrectUsage, InfiniteGeneratorError, ObservedException, ObservedGeneratorExit, @@ -604,6 +606,104 @@ class LocalGeneratorObjectVariable(VariableTracker): # https://github.com/python/cpython/pull/104771 assert tracer.symbolic_result is not None return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + if len(args) > 1: + raise IncorrectUsage( + "the (type, exc, tb) signature of throw() is deprecated, " + "use the single-arg signature instead." + ) + + # Setup the exception table and jump target in case of try...finally + tracer = self._get_inline_tracer(tx) + try: + self._setup_exception(tx, args[0]) + except ObservedException: + # propagate the exception back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let’s walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + except ObservedException: + # Propagate anything else back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + else: + raise_observed_exception(RuntimeError, tracer) + return retval super().call_method(tx, name, args, kwargs)