From 0ddc653d6772bb8d8d6f422a226309378b61b9a5 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 16:46:24 -0300 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- test/dynamo/test_ctx_manager.py | 431 ++++++++--------- test/dynamo/test_exceptions.py | 80 +--- test/dynamo/test_functions.py | 17 + test/dynamo/test_sys.py | 4 + test/dynamo/test_unittest.py | 608 ------------------------ torch/_dynamo/exc.py | 9 - torch/_dynamo/symbolic_convert.py | 34 +- torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/builder.py | 8 +- torch/_dynamo/variables/builtin.py | 9 +- torch/_dynamo/variables/functions.py | 5 - torch/_dynamo/variables/user_defined.py | 7 +- 12 files changed, 262 insertions(+), 952 deletions(-) delete mode 100644 test/dynamo/test_unittest.py diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 874b3a1338d..a3e1162a996 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -9,18 +9,12 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.exc import InternalTorchDynamoError -from torch._dynamo.testing import ( - EagerAndRecordGraphs, - normalize_gm, - same, - skipIfNotPy311, -) +from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same from torch._dynamo.utils import counters from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - make_dynamo_test, parametrize, TEST_WITH_ROCM, ) @@ -1750,11 +1744,10 @@ class GraphModule(torch.nn.Module): class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): - self._old = torch._dynamo.config.enable_trace_contextlib torch._dynamo.config.enable_trace_contextlib = True def tearDown(self): - torch._dynamo.config.enable_trace_contextlib = self._old + torch._dynamo.config.enable_trace_contextlib = False def test_ctx_basic0(self): @contextlib.contextmanager @@ -2700,9 +2693,9 @@ class GraphModule(torch.nn.Module): class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py - # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py + # https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70 - @make_dynamo_test + @unittest.expectedFailure def test_contextmanager_plain(self): state = [] @@ -2712,14 +2705,24 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): yield 42 state.append(999) - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + y = t.sum() + with woohoo() as x: + assert state == [1] + assert x == 42 + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + y += x + return y + + t = torch.randn(2, 3) + y = fn(t) self.assertEqual(state, [1, 42, 999]) + self.assertEqual(y, t.sum() + 42) @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_finally(self): state = [] @@ -2731,56 +2734,121 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): finally: state.append(999) - with self.assertRaises(ZeroDivisionError): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + _y = t.sum() + with self.assertRaises(ZeroDivisionError): + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError + + fn(torch.randn(2, 3)) self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_traceback(self): @contextmanager def f(): yield - try: - with f(): - 1 / 0 - except ZeroDivisionError as e: - frames = traceback.extract_tb(e.__traceback__) + frames = [] + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal frames + _y = t.sum() + try: + with f(): + 1 / 0 + except ZeroDivisionError as e: + frames = traceback.extract_tb(e.__traceback__) + + fn(torch.randn(2, 3)) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "1/0") - - # Repeat with RuntimeError (which goes through a different code path) - try: - with f(): - raise NotImplementedError(42) - except NotImplementedError as e: - frames = traceback.extract_tb(e.__traceback__) - - self.assertEqual(len(frames), 1) - self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "raise NotImplementedError(42)") + self.assertEqual(frames[0].line, "1 / 0") + + @unittest.expectedFailure + def test_contextmanager_traceback2(self): + @contextmanager + def f(): + yield + + # Repeat with RuntimeError (which goes through a different code path) + class RuntimeErrorSubclass(RuntimeError): + pass + + frames = [] + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal frames + _y = t.sum() + try: + with f(): + raise RuntimeErrorSubclass(42) + except RuntimeErrorSubclass as e: + frames = traceback.extract_tb(e.__traceback__) + + fn(torch.randn(2, 3)) + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].name, "test_contextmanager_traceback") + self.assertEqual(frames[0].line, "raise RuntimeErrorSubclass(42)") + + @unittest.expectedFailure + def test_contextmanager_traceback3(self): + @contextmanager + def f(): + yield + + frames = [] + + class StopIterationSubclass(StopIteration): + pass + + for stop_exc in ( + StopIteration("spam"), + StopIterationSubclass("spam"), + ): + with self.subTest(type=type(stop_exc)): + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal frames + _y = t.sum() + try: + with f(): + raise stop_exc + except type(stop_exc) as e: + self.assertIs(e, stop_exc) + frames = traceback.extract_tb(e.__traceback__) + else: + self.fail(f"{stop_exc} was suppressed") + + fn(torch.randn(2, 3)) + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].name, "test_contextmanager_traceback") + self.assertEqual(frames[0].line, "raise stop_exc") @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_no_reraise(self): @contextmanager def whee(): yield - ctx = whee() - ctx.__enter__() - # Calling __exit__ should not result in an exception - self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + ctx = whee() + ctx.__enter__() + # Calling __exit__ should not result in an exception + self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) + return t.sum() + + fn(torch.randn(2, 3)) @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_trap_yield_after_throw(self): @contextmanager def whoo(): @@ -2789,12 +2857,49 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): except Exception: yield - ctx = whoo() - ctx.__enter__() - self.assertRaises(RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(TypeError, TypeError("foo"), None) + return t.sum() + + fn(torch.randn(2, 3)) + + @unittest.expectedFailure + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + return t.sum() + + fn(torch.randn(2, 3)) + + @unittest.expectedFailure + def test_contextmanager_trap_second_yield(self): + @contextmanager + def whoo(): + yield + yield + + @torch.compile(backend="eager", fullgraph=True) + def f(t): + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(None, None, None) + + f(torch.randn(2)) @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_except(self): state = [] @@ -2807,58 +2912,18 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): state.append(e.args[0]) self.assertEqual(state, [1, 42, 999]) - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError(999) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + + fn(torch.randn(2, 3)) self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_except_stopiter(self): - @contextmanager - def woohoo(): - yield - - class StopIterationSubclass(StopIteration): - pass - - for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")): - with self.subTest(type=type(stop_exc)): - try: - with woohoo(): - raise stop_exc - except Exception as ex: - self.assertIs(ex, stop_exc) - else: - self.fail(f"{stop_exc} was suppressed") - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_except_pep479(self): - code = """\ -from __future__ import generator_stop -from contextlib import contextmanager -@contextmanager -def woohoo(): - yield -""" - locals = {} - exec(code, locals, locals) - woohoo = locals["woohoo"] - - stop_exc = StopIteration("spam") - try: - with woohoo(): - raise stop_exc - except Exception as ex: - self.assertIs(ex, stop_exc) - else: - self.fail("StopIteration was suppressed") - - @unittest.expectedFailure - @make_dynamo_test def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self): @contextmanager def test_issue29692(): @@ -2867,78 +2932,73 @@ def woohoo(): except Exception as exc: raise RuntimeError("issue29692:Chained") from exc - try: - with test_issue29692(): - raise ZeroDivisionError - except Exception as ex: - self.assertIs(type(ex), RuntimeError) - self.assertEqual(ex.args[0], "issue29692:Chained") - self.assertIsInstance(ex.__cause__, ZeroDivisionError) + @torch.compile(backend="eager", fullgraph=True) + def f(t): + try: + with test_issue29692(): + raise ZeroDivisionError + except Exception as ex: + self.assertIs(type(ex), RuntimeError) + self.assertEqual(ex.args[0], "issue29692:Chained") + self.assertIsInstance(ex.__cause__, ZeroDivisionError) - try: - with test_issue29692(): - raise StopIteration("issue29692:Unchained") - except Exception as ex: - self.assertIs(type(ex), StopIteration) - self.assertEqual(ex.args[0], "issue29692:Unchained") - self.assertIsNone(ex.__cause__) + try: + with test_issue29692(): + raise StopIteration("issue29692:Unchained") + except Exception as ex: + self.assertIs(type(ex), StopIteration) + self.assertEqual(ex.args[0], "issue29692:Unchained") + self.assertIsNone(ex.__cause__) + + f(torch.randn(2)) @unittest.expectedFailure - @make_dynamo_test - def _create_contextmanager_attribs(self): - def attribs(**kw): - def decorate(func): - for k, v in kw.items(): - setattr(func, k, v) - return func - - return decorate - + def test_contextmanager_wrap_runtimeerror(self): @contextmanager - @attribs(foo="bar") - def baz(spam): - """Whee!""" + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f"caught {exc}") from exc - return baz + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + fn(torch.randn(2, 3)) + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_attribs(self): - baz = self._create_contextmanager_attribs() - self.assertEqual(baz.__name__, "baz") - self.assertEqual(baz.foo, "bar") - - @make_dynamo_test def test_keywords(self): # Ensure no keyword arguments are inhibited @contextmanager def woohoo(self, func, args, kwds): yield (self, func, args, kwds) - with woohoo(self=11, func=22, args=33, kwds=44) as target: - self.assertEqual(target, (11, 22, 33, 44)) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) + + fn(torch.randn(2, 3)) @unittest.expectedFailure - @make_dynamo_test - def test_param_errors(self): - @contextmanager - def woohoo(a, *, b): - yield - - with self.assertRaises(TypeError): - woohoo() - with self.assertRaises(TypeError): - woohoo(3, 5) - with self.assertRaises(TypeError): - woohoo(b=3) - - @unittest.expectedFailure - @make_dynamo_test def test_recursive(self): depth = 0 + ncols = 0 @contextmanager def woohoo(): + nonlocal ncols + ncols += 1 nonlocal depth before = depth depth += 1 @@ -2951,70 +3011,15 @@ def woohoo(): if depth < 10: recursive() - recursive() + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + recursive() + + fn(torch.randn(2, 3)) + + self.assertEqual(ncols, 10) self.assertEqual(depth, 0) - @skipIfNotPy311 - @make_dynamo_test - def test_contextmanager_trap_no_yield(self): - @contextmanager - def whoo(): - if False: - yield - - ctx = whoo() - with self.assertRaises(RuntimeError): - ctx.__enter__() - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_trap_second_yield(self): - @contextmanager - def whoo(): - yield - yield - - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(None, None, None) - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_wrap_runtimeerror(self): - @contextmanager - def woohoo(): - try: - yield - except Exception as exc: - raise RuntimeError(f"caught {exc}") from exc - - with self.assertRaises(RuntimeError): - with woohoo(): - 1 / 0 - - # If the context manager wrapped StopIteration in a RuntimeError, - # we also unwrap it, because we can't tell whether the wrapping was - # done by the generator machinery or by the generator itself. - with self.assertRaises(StopIteration): - with woohoo(): - raise StopIteration - - @unittest.expectedFailure - @make_dynamo_test - def test_contextmanager_non_normalised(self): - @contextmanager - def whoo(): - try: - yield - except RuntimeError: - raise SyntaxError # noqa: B904 - - ctx = whoo() - ctx.__enter__() - with self.assertRaises(SyntaxError): - ctx.__exit__(RuntimeError, None, None) - instantiate_parametrized_tests(CtxManagerTests) instantiate_parametrized_tests(ContextlibContextManagerTests) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index c296f57a7e5..8ef8fd31683 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -1,7 +1,5 @@ # Owner(s): ["module: dynamo"] -import contextlib -import sys import unittest import torch @@ -456,6 +454,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): x = torch.randn(4) self.assertEqual(fn(x), opt_fn(x)) + @unittest.expectedFailure @make_dynamo_test def test_raise_set___context__(self): try: @@ -472,73 +471,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): self.assertIsNone(exc2.__context__) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") - @make_dynamo_test - def test_raise_match(self): - a = AttributeError - b = BytesWarning - c = ConnectionError - d = DeprecationWarning - e = Exception - - def fn(a, b): - try: - raise a - finally: - raise b - - def fix_exc_context(frame_exc, new_exc, old_exc): - # slightly change from ExitStack.fix_exc_context function - while 1: - exc_context = new_exc.__context__ - if exc_context is None or exc_context is old_exc: - return - if exc_context is frame_exc: - break - new_exc = exc_context - new_exc.__context__ = old_exc - - @contextlib.contextmanager - def ctx(): - try: - yield - finally: - frame_exc = prev_exc = sys.exc_info() - args = [(d, c), (b, a)] - for x, y in args: - try: - fn(x, y) - except BaseException: - new_exc = sys.exc_info() - fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1]) - prev_exc = new_exc - - try: - fixed_ctx = prev_exc[1].__context__ - raise prev_exc[1] - except BaseException: - prev_exc[1].__context__ = fixed_ctx - raise - - try: - with ctx(): - raise e - except Exception as exc: - self.assertIsInstance(exc, a) - self.assertIsInstance(exc.__context__, b) - self.assertIsInstance(exc.__context__.__context__, c) - self.assertIsInstance(exc.__context__.__context__.__context__, d) - self.assertIsInstance( - exc.__context__.__context__.__context__.__context__, e - ) - - @make_dynamo_test - def test_raise_ZeroDivisionError(self): - try: - 1 / 0 - except Exception: - pass - class CPythonExceptionTests(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py @@ -562,6 +494,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIsNone(e.__context__) self.assertIsNone(e.__cause__) + @unittest.expectedFailure @make_dynamo_test def testChainingDescriptors(self): try: @@ -581,6 +514,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): e.__suppress_context__ = False self.assertFalse(e.__suppress_context__) + @unittest.expectedFailure @make_dynamo_test def test_context_of_exception_in_try_and_finally(self): try: @@ -596,6 +530,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(exc, ve) self.assertIs(exc.__context__, te) + @unittest.expectedFailure @make_dynamo_test def test_context_of_exception_in_except_and_finally(self): try: @@ -615,6 +550,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(exc.__context__, ve) self.assertIs(exc.__context__.__context__, te) + @unittest.expectedFailure @make_dynamo_test def test_context_of_exception_in_else_and_finally(self): try: @@ -634,7 +570,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(exc, oe) self.assertIs(exc.__context__, ve) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") + @unittest.expectedFailure @make_dynamo_test def test_raise_does_not_create_context_chain_cycle(self): A = AssertionError @@ -673,7 +609,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(c.__context__, b) self.assertIsNone(b.__context__) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") + @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle1(self): # See issue 25782. Cycle in context chain. @@ -729,7 +665,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(b.__context__, a) self.assertIs(a.__context__, c) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") + @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle3(self): # See issue 25782. Longer context chain with cycle. diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 479c2576e43..bf37ae33acf 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2619,6 +2619,23 @@ class GraphModule(torch.nn.Module): else: return x.cos() + @unittest.expectedFailure + def test_getattr_metaclass(self): + class Meta(type): + def __getattr__(cls, name): + return len(name) + + class C(metaclass=Meta): + attr = 123 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return t + C.attr + C.dynamic_attr + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 123 + 12) + def test_two_point_iter(self): def fn(x, y): it = map(lambda n: n + 1, range(6)) diff --git a/test/dynamo/test_sys.py b/test/dynamo/test_sys.py index e3cf5b439a6..2f7bd717869 100644 --- a/test/dynamo/test_sys.py +++ b/test/dynamo/test_sys.py @@ -37,6 +37,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): def test_sys_exception_no_exception(self): self.assertEqual(sys.exception(), None) + @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_instance(self): def f(): @@ -53,6 +54,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) + @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_type(self): def f(): @@ -69,6 +71,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) + @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_instance(self): @@ -84,6 +87,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): self.assertIsInstance(e, ValueError) self.assertIs(exc, e) + @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_type(self): diff --git a/test/dynamo/test_unittest.py b/test/dynamo/test_unittest.py deleted file mode 100644 index 2dc1a766d76..00000000000 --- a/test/dynamo/test_unittest.py +++ /dev/null @@ -1,608 +0,0 @@ -# Owner(s): ["module: dynamo"] -import sys -import unittest -import warnings -from itertools import product - -import torch -import torch._dynamo.test_case -from torch.testing._internal.common_utils import make_dynamo_test - - -class Test_Assertions(torch._dynamo.test_case.TestCase): - # Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py - # https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py - def setUp(self): - if sys.version_info < (3, 11): - self.skipTest( - "Tracing the unittest module needs exception table (Python 3.11+) to work" - ) - super().setUp() - - @make_dynamo_test - def test_AlmostEqual(self): - self.assertAlmostEqual(1.00000001, 1.0) - self.assertNotAlmostEqual(1.0000001, 1.0) - self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0 - ) - - self.assertAlmostEqual(1.1, 1.0, places=0) - self.assertRaises( - self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1 - ) - - self.assertAlmostEqual(0, 0.1 + 0.1j, places=0) - self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1) - self.assertRaises( - self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1 - ) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0 - ) - - self.assertAlmostEqual(float("inf"), float("inf")) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf") - ) - - @make_dynamo_test - def test_AmostEqualWithDelta(self): - self.assertAlmostEqual(1.1, 1.0, delta=0.5) - self.assertAlmostEqual(1.0, 1.1, delta=0.5) - self.assertNotAlmostEqual(1.1, 1.0, delta=0.05) - self.assertNotAlmostEqual(1.0, 1.1, delta=0.05) - - self.assertAlmostEqual(1.0, 1.0, delta=0.5) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5 - ) - - self.assertRaises( - self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05 - ) - self.assertRaises( - self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5 - ) - - self.assertRaises( - TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2 - ) - self.assertRaises( - TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2 - ) - - @make_dynamo_test - def test_assertRaises(self): - def _raise(e): - raise e - - self.assertRaises(KeyError, _raise, KeyError) - self.assertRaises(KeyError, _raise, KeyError("key")) - try: - self.assertRaises(KeyError, lambda: None) - except self.failureException as e: - self.assertIn("KeyError not raised", str(e)) - else: - self.fail("assertRaises() didn't fail") - try: - self.assertRaises(KeyError, _raise, ValueError) - except ValueError: - pass - else: - self.fail("assertRaises() didn't let exception pass through") - with self.assertRaises(KeyError) as cm: - try: - raise KeyError - except Exception as e: - exc = e - raise - self.assertIs(cm.exception, exc) - - with self.assertRaises(KeyError): - raise KeyError("key") - try: - with self.assertRaises(KeyError): - pass - except self.failureException as e: - self.assertIn("KeyError not raised", str(e)) - else: - self.fail("assertRaises() didn't fail") - try: - with self.assertRaises(KeyError): - raise ValueError - except ValueError: - pass - else: - self.fail("assertRaises() didn't let exception pass through") - - @unittest.expectedFailure - @make_dynamo_test - def testAssertNotRegex(self): - self.assertNotRegex("Ala ma kota", r"r+") - try: - self.assertNotRegex("Ala ma kota", r"k.t", "Message") - except self.failureException as e: - self.assertIn("Message", e.args[0]) - else: - self.fail("assertNotRegex should have failed.") - - -class TestLongMessage(torch._dynamo.test_case.TestCase): - - """Test that the individual asserts honour longMessage. - This actually tests all the message behaviour for - asserts that use longMessage.""" - - def setUp(self): - if sys.version_info < (3, 11): - return self.skipTest( - "Tracing the unittest module needs exception table (Python 3.11+) to work" - ) - super().setUp() - - class TestableTestFalse(unittest.TestCase): - longMessage = False - failureException = self.failureException - - def testTest(self): - pass - - class TestableTestTrue(unittest.TestCase): - longMessage = True - failureException = self.failureException - - def testTest(self): - pass - - self.testableTrue = TestableTestTrue("testTest") - self.testableFalse = TestableTestFalse("testTest") - - def testDefault(self): - self.assertTrue(unittest.TestCase.longMessage) - - def test_formatMsg(self): - self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo") - self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo") - - self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo") - self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo") - - # This blows up if _formatMessage uses string concatenation - self.testableTrue._formatMessage(object(), "foo") - - def test_formatMessage_unicode_error(self): - one = "".join(chr(i) for i in range(255)) - # this used to cause a UnicodeDecodeError constructing msg - self.testableTrue._formatMessage(one, "\uFFFD") - - def assertMessages(self, methodName, args, errors): - """ - Check that methodName(*args) raises the correct error messages. - errors should be a list of 4 regex that match the error when: - 1) longMessage = False and no msg passed; - 2) longMessage = False and msg passed; - 3) longMessage = True and no msg passed; - 4) longMessage = True and msg passed; - """ - - def getMethod(i): - useTestableFalse = i < 2 - if useTestableFalse: - test = self.testableFalse - else: - test = self.testableTrue - return getattr(test, methodName) - - for i, expected_regex in enumerate(errors): - testMethod = getMethod(i) - kwargs = {} - withMsg = i % 2 - if withMsg: - kwargs = {"msg": "oops"} - - with self.assertRaisesRegex( - self.failureException, expected_regex=expected_regex - ): - testMethod(*args, **kwargs) - # with self.assertRaises(self.failureException) as cm: - # testMethod(*args, **kwargs) - # self.assertIn(expected_regex, str(cm.exception)) - - @make_dynamo_test - def testAssertTrue(self): - self.assertMessages( - "assertTrue", - (False,), - [ - "False is not true", - "oops", - "False is not true", - "False is not true : oops", - ], - ) - - @make_dynamo_test - def testAssertFalse(self): - self.assertMessages( - "assertFalse", - (True,), - [ - "True is not false", - "oops", - "True is not false", - "True is not false : oops", - ], - ) - - @make_dynamo_test - def testNotEqual(self): - self.assertMessages( - "assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"] - ) - - @make_dynamo_test - def testAlmostEqual(self): - self.assertMessages( - "assertAlmostEqual", - (1, 2), - [ - r"^1 != 2 within 7 places \(1 difference\)$", - "^oops$", - r"^1 != 2 within 7 places \(1 difference\)$", - r"^1 != 2 within 7 places \(1 difference\) : oops$", - ], - ) - - @make_dynamo_test - def testNotAlmostEqual(self): - self.assertMessages( - "assertNotAlmostEqual", - (1, 1), - [ - "^1 == 1 within 7 places$", - "^oops$", - "^1 == 1 within 7 places$", - "^1 == 1 within 7 places : oops$", - ], - ) - - @make_dynamo_test - def test_baseAssertEqual(self): - self.assertMessages( - "_baseAssertEqual", - (1, 2), - ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertSequenceEqual(self): - # Error messages are multiline so not testing on full message - # assertTupleEqual and assertListEqual delegate to this method - self.assertMessages( - "assertSequenceEqual", - ([], [None]), - [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"], - ) - - @make_dynamo_test - def testAssertSetEqual(self): - self.assertMessages( - "assertSetEqual", - (set(), set([None])), # noqa: C405 - ["None$", "^oops$", "None$", "None : oops$"], - ) - - @make_dynamo_test - def testAssertIn(self): - self.assertMessages( - "assertIn", - (None, []), - [ - r"^None not found in \[\]$", - "^oops$", - r"^None not found in \[\]$", - r"^None not found in \[\] : oops$", - ], - ) - - @make_dynamo_test - def testAssertNotIn(self): - self.assertMessages( - "assertNotIn", - (None, [None]), - [ - r"^None unexpectedly found in \[None\]$", - "^oops$", - r"^None unexpectedly found in \[None\]$", - r"^None unexpectedly found in \[None\] : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertDictEqual(self): - self.assertMessages( - "assertDictEqual", - ({}, {"key": "value"}), - [ - r"\+ \{'key': 'value'\}$", - "^oops$", - r"\+ \{'key': 'value'\}$", - r"\+ \{'key': 'value'\} : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertMultiLineEqual(self): - self.assertMessages( - "assertMultiLineEqual", - ("", "foo"), - [r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"], - ) - - @make_dynamo_test - def testAssertLess(self): - self.assertMessages( - "assertLess", - (2, 1), - [ - "^2 not less than 1$", - "^oops$", - "^2 not less than 1$", - "^2 not less than 1 : oops$", - ], - ) - - @make_dynamo_test - def testAssertLessEqual(self): - self.assertMessages( - "assertLessEqual", - (2, 1), - [ - "^2 not less than or equal to 1$", - "^oops$", - "^2 not less than or equal to 1$", - "^2 not less than or equal to 1 : oops$", - ], - ) - - @make_dynamo_test - def testAssertGreater(self): - self.assertMessages( - "assertGreater", - (1, 2), - [ - "^1 not greater than 2$", - "^oops$", - "^1 not greater than 2$", - "^1 not greater than 2 : oops$", - ], - ) - - @make_dynamo_test - def testAssertGreaterEqual(self): - self.assertMessages( - "assertGreaterEqual", - (1, 2), - [ - "^1 not greater than or equal to 2$", - "^oops$", - "^1 not greater than or equal to 2$", - "^1 not greater than or equal to 2 : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNone(self): - self.assertMessages( - "assertIsNone", - ("not None",), - [ - "^'not None' is not None$", - "^oops$", - "^'not None' is not None$", - "^'not None' is not None : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNotNone(self): - self.assertMessages( - "assertIsNotNone", - (None,), - [ - "^unexpectedly None$", - "^oops$", - "^unexpectedly None$", - "^unexpectedly None : oops$", - ], - ) - - @make_dynamo_test - def testAssertIs(self): - self.assertMessages( - "assertIs", - (None, "foo"), - [ - "^None is not 'foo'$", - "^oops$", - "^None is not 'foo'$", - "^None is not 'foo' : oops$", - ], - ) - - @make_dynamo_test - def testAssertIsNot(self): - self.assertMessages( - "assertIsNot", - (None, None), - [ - "^unexpectedly identical: None$", - "^oops$", - "^unexpectedly identical: None$", - "^unexpectedly identical: None : oops$", - ], - ) - - @make_dynamo_test - def testAssertRegex(self): - self.assertMessages( - "assertRegex", - ("foo", "bar"), - [ - "^Regex didn't match:", - "^oops$", - "^Regex didn't match:", - "^Regex didn't match: (.*) : oops$", - ], - ) - - @make_dynamo_test - def testAssertNotRegex(self): - self.assertMessages( - "assertNotRegex", - ("foo", "foo"), - [ - "^Regex matched:", - "^oops$", - "^Regex matched:", - "^Regex matched: (.*) : oops$", - ], - ) - - def assertMessagesCM(self, methodName, args, func, errors): - """ - Check that the correct error messages are raised while executing: - with method(*args): - func() - *errors* should be a list of 4 regex that match the error when: - 1) longMessage = False and no msg passed; - 2) longMessage = False and msg passed; - 3) longMessage = True and no msg passed; - 4) longMessage = True and msg passed; - """ - p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"})) - for (cls, kwargs), err in zip(p, errors): - method = getattr(cls, methodName) - with self.assertRaisesRegex(cls.failureException, err): - with method(*args, **kwargs) as cm: # noqa: F841 - func() - - @make_dynamo_test - def testAssertRaises(self): - self.assertMessagesCM( - "assertRaises", - (TypeError,), - lambda: None, - [ - "^TypeError not raised$", - "^oops$", - "^TypeError not raised$", - "^TypeError not raised : oops$", - ], - ) - - @make_dynamo_test - def testAssertRaisesRegex(self): - self.assertMessagesCM( - "assertRaisesRegex", - (TypeError, "unused regex"), - lambda: None, - [ - "^TypeError not raised$", - "^oops$", - "^TypeError not raised$", - "^TypeError not raised : oops$", - ], - ) - - # test error raised but with wrong message - def raise_wrong_message(): - raise TypeError("foo") - - self.assertMessagesCM( - "assertRaisesRegex", - (TypeError, "regex"), - raise_wrong_message, - [ - '^"regex" does not match "foo"$', - "^oops$", - '^"regex" does not match "foo"$', - '^"regex" does not match "foo" : oops$', - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertWarns(self): - self.assertMessagesCM( - "assertWarns", - (UserWarning,), - lambda: None, - [ - "^UserWarning not triggered$", - "^oops$", - "^UserWarning not triggered$", - "^UserWarning not triggered : oops$", - ], - ) - - @unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13") - @make_dynamo_test - def test_assertNotWarns(self): - def warn_future(): - warnings.warn("xyz", FutureWarning, stacklevel=2) - - self.assertMessagesCM( - "_assertNotWarns", - (FutureWarning,), - warn_future, - [ - "^FutureWarning triggered$", - "^oops$", - "^FutureWarning triggered$", - "^FutureWarning triggered : oops$", - ], - ) - - @unittest.expectedFailure - @make_dynamo_test - def testAssertWarnsRegex(self): - # test error not raised - self.assertMessagesCM( - "assertWarnsRegex", - (UserWarning, "unused regex"), - lambda: None, - [ - "^UserWarning not triggered$", - "^oops$", - "^UserWarning not triggered$", - "^UserWarning not triggered : oops$", - ], - ) - - # test warning raised but with wrong message - def raise_wrong_message(): - warnings.warn("foo") - - self.assertMessagesCM( - "assertWarnsRegex", - (UserWarning, "regex"), - raise_wrong_message, - [ - '^"regex" does not match "foo"$', - "^oops$", - '^"regex" does not match "foo"$', - '^"regex" does not match "foo" : oops$', - ], - ) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 16cc13cea51..5168543013a 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -291,13 +291,6 @@ observed_exception_map = { } -def create_dynamo_observed_exception(e: type[Exception]) -> None: - if e not in observed_exception_map: - name = getattr(e, "__name__", str(e)) - internal_exc = type(f"Observed{name}Error", (ObservedException,), {}) - observed_exception_map[e] = internal_exc - - def raise_observed_exception( exc_type: type[Exception], tx: InstructionTranslatorBase, @@ -311,8 +304,6 @@ def raise_observed_exception( # stack and raise the exception. exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] tx.exn_vt_stack.append(exception_vt) - if exc_type not in observed_exception_map: - create_dynamo_observed_exception(exc_type) raise observed_exception_map[exc_type] diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4dfe5b4ffcb..47ab26a1df4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3242,34 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase): else: return result - def update_parent_exn_vt_stack(self): - # TODO(anijain2305) - This works but we should probably have a - # global/central data structure for the exception stack. - parent = self.parent - p_exn = parent.exn_vt_stack - c_exn = self.exn_vt_stack - # Prior to this, the parent stack would be extended with all exceptions from - # the child frame. We don't need to append everything as the parent interpreter - # can only match the topmost exception from the child. i.e. - # - # def foo(): - # try: - # raise ValueError - # except ValueError as e: - # raise TypeError from e - # - # def bar(): - # try: - # foo() - # except TypeError: - # # it is impossible for `bar` to match on the first `ValueError` - # # raised by `foo` - # pass - # - # For a test case, check test_exceptions::test_raise_match - if len(c_exn) > len(p_exn): - parent.exn_vt_stack.append(c_exn[-1]) - @staticmethod def build_inline_tracer( parent, @@ -3376,7 +3348,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase): self.run() except exc.ObservedException as e: msg = f"Observed exception DURING INLING {code} : {e}" - self.update_parent_exn_vt_stack() + # TODO(anijain2305) - This works but we should probably have a + # global/central data structure for the exception stack. + parent.exn_vt_stack.extend(self.exn_vt_stack) log.debug(msg) # bubble up the exception to the parent frame. raise @@ -3449,8 +3423,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase): ) self.funcvar = funcvar self.parent = parent - # Propagate any exception on the parent stack - self.exn_vt_stack = list(parent.exn_vt_stack) self.num_calls = parent.num_calls self.symbolic_result = None self.nn_module_stack = parent.nn_module_stack.copy() diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 16c6da007ce..7b459ffcbb9 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -16,6 +16,7 @@ import sys import traceback import types import typing +import unittest from collections import defaultdict from pathlib import Path from typing import Any, Callable, cast, Optional, Union @@ -3147,6 +3148,7 @@ BUILTIN_SKIPLIST = ( random, traceback, linecache, + unittest, ) # third party libraries skiplist is defined by str, because users may not use these libraries. diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e0d92248dc4..890dcc5dd66 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1012,12 +1012,10 @@ class VariableBuilder: elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) - elif value in (traceback.clear_frames,): + elif value is traceback.clear_frames: return TracebackVariable(source=self.source) - elif ( - value in (sys.exc_info,) - or sys.version_info >= (3, 11) - and value is sys.exception + elif value is sys.exc_info or ( + sys.version_info >= (3, 11) and value is sys.exception ): return SysFunctionVariable(value, source=self.source) elif is_function_or_wrapper(value) and inspect.getattr_static( diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 8768f40bc49..2b1a3f1c02c 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -21,7 +21,6 @@ from .. import config, polyfills, variables from ..exc import ( AttributeMutationError, ObservedAttributeError, - raise_observed_exception, unimplemented, Unsupported, UserError, @@ -874,7 +873,7 @@ class BuiltinVariable(VariableTracker): *[x.as_python_constant() for x in args], ) except Exception as exc: - raise_observed_exception(type(exc), tx) + unimplemented(f"constant fold exception: {repr(exc)}") return VariableTracker.build(tx, res) else: @@ -1219,12 +1218,6 @@ class BuiltinVariable(VariableTracker): # Inline the user function return tx.inline_user_function_return(user_func_variable, [arg], {}) - elif isinstance(arg, (variables.ExceptionVariable,)): - if len(arg.args) == 0: - value = f"{arg.exc_type}" - else: - value = ", ".join(a.as_python_constant() for a in arg.args) - return variables.ConstantVariable.create(value=value) def _call_min_max(self, tx: "InstructionTranslator", *args): if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 8ec45898ead..0594706d95c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -633,11 +633,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): def has_closure(self): return self.closure is not None - def const_getattr(self, tx, name): - if name == "__name__": - return self.fn_name.as_python_constant() - return super().const_getattr(tx, name) - def has_self(self): return False diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ae889255652..77cf4917b7a 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -185,7 +185,12 @@ class UserDefinedClassVariable(UserDefinedVariable): try: obj = inspect.getattr_static(self.value, name) except AttributeError: - raise_observed_exception(AttributeError, tx) + if type(self.value) is type: + raise_observed_exception(AttributeError, tx) + else: + # Cannot reason about classes with a custom metaclass + # See: test_functions::test_getattr_metaclass + obj = None if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType): return variables.GetAttrVariable(self, name, source=source)