Update (base update)

[ghstack-poisoned]
This commit is contained in:
Guilherme Leobas 2025-02-07 16:46:24 -03:00
parent 1e6a7cfd54
commit 0ddc653d67
12 changed files with 262 additions and 952 deletions

View file

@ -9,18 +9,12 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import InternalTorchDynamoError
from torch._dynamo.testing import (
EagerAndRecordGraphs,
normalize_gm,
same,
skipIfNotPy311,
)
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
from torch._dynamo.utils import counters
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
make_dynamo_test,
parametrize,
TEST_WITH_ROCM,
)
@ -1750,11 +1744,10 @@ class GraphModule(torch.nn.Module):
class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
def setUp(self):
self._old = torch._dynamo.config.enable_trace_contextlib
torch._dynamo.config.enable_trace_contextlib = True
def tearDown(self):
torch._dynamo.config.enable_trace_contextlib = self._old
torch._dynamo.config.enable_trace_contextlib = False
def test_ctx_basic0(self):
@contextlib.contextmanager
@ -2700,9 +2693,9 @@ class GraphModule(torch.nn.Module):
class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
# https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70
@make_dynamo_test
@unittest.expectedFailure
def test_contextmanager_plain(self):
state = []
@ -2712,14 +2705,24 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
yield 42
state.append(999)
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
y = t.sum()
with woohoo() as x:
assert state == [1]
assert x == 42
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
y += x
return y
t = torch.randn(2, 3)
y = fn(t)
self.assertEqual(state, [1, 42, 999])
self.assertEqual(y, t.sum() + 42)
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_finally(self):
state = []
@ -2731,56 +2734,121 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
finally:
state.append(999)
with self.assertRaises(ZeroDivisionError):
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
_y = t.sum()
with self.assertRaises(ZeroDivisionError):
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError
fn(torch.randn(2, 3))
self.assertEqual(state, [1, 42, 999])
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_traceback(self):
@contextmanager
def f():
yield
try:
with f():
1 / 0
except ZeroDivisionError as e:
frames = traceback.extract_tb(e.__traceback__)
frames = []
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
nonlocal frames
_y = t.sum()
try:
with f():
1 / 0
except ZeroDivisionError as e:
frames = traceback.extract_tb(e.__traceback__)
fn(torch.randn(2, 3))
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "1/0")
# Repeat with RuntimeError (which goes through a different code path)
try:
with f():
raise NotImplementedError(42)
except NotImplementedError as e:
frames = traceback.extract_tb(e.__traceback__)
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
self.assertEqual(frames[0].line, "1 / 0")
@unittest.expectedFailure
def test_contextmanager_traceback2(self):
@contextmanager
def f():
yield
# Repeat with RuntimeError (which goes through a different code path)
class RuntimeErrorSubclass(RuntimeError):
pass
frames = []
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
nonlocal frames
_y = t.sum()
try:
with f():
raise RuntimeErrorSubclass(42)
except RuntimeErrorSubclass as e:
frames = traceback.extract_tb(e.__traceback__)
fn(torch.randn(2, 3))
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "raise RuntimeErrorSubclass(42)")
@unittest.expectedFailure
def test_contextmanager_traceback3(self):
@contextmanager
def f():
yield
frames = []
class StopIterationSubclass(StopIteration):
pass
for stop_exc in (
StopIteration("spam"),
StopIterationSubclass("spam"),
):
with self.subTest(type=type(stop_exc)):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
nonlocal frames
_y = t.sum()
try:
with f():
raise stop_exc
except type(stop_exc) as e:
self.assertIs(e, stop_exc)
frames = traceback.extract_tb(e.__traceback__)
else:
self.fail(f"{stop_exc} was suppressed")
fn(torch.randn(2, 3))
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "raise stop_exc")
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_no_reraise(self):
@contextmanager
def whee():
yield
ctx = whee()
ctx.__enter__()
# Calling __exit__ should not result in an exception
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
ctx = whee()
ctx.__enter__()
# Calling __exit__ should not result in an exception
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
return t.sum()
fn(torch.randn(2, 3))
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_trap_yield_after_throw(self):
@contextmanager
def whoo():
@ -2789,12 +2857,49 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
except Exception:
yield
ctx = whoo()
ctx.__enter__()
self.assertRaises(RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None)
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
ctx = whoo()
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(TypeError, TypeError("foo"), None)
return t.sum()
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_contextmanager_trap_no_yield(self):
@contextmanager
def whoo():
if False:
yield
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
ctx = whoo()
with self.assertRaises(RuntimeError):
ctx.__enter__()
return t.sum()
fn(torch.randn(2, 3))
@unittest.expectedFailure
def test_contextmanager_trap_second_yield(self):
@contextmanager
def whoo():
yield
yield
@torch.compile(backend="eager", fullgraph=True)
def f(t):
ctx = whoo()
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(None, None, None)
f(torch.randn(2))
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_except(self):
state = []
@ -2807,58 +2912,18 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
state.append(e.args[0])
self.assertEqual(state, [1, 42, 999])
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError(999)
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError(999)
fn(torch.randn(2, 3))
self.assertEqual(state, [1, 42, 999])
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_except_stopiter(self):
@contextmanager
def woohoo():
yield
class StopIterationSubclass(StopIteration):
pass
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
with self.subTest(type=type(stop_exc)):
try:
with woohoo():
raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail(f"{stop_exc} was suppressed")
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_except_pep479(self):
code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
yield
"""
locals = {}
exec(code, locals, locals)
woohoo = locals["woohoo"]
stop_exc = StopIteration("spam")
try:
with woohoo():
raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail("StopIteration was suppressed")
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
@contextmanager
def test_issue29692():
@ -2867,78 +2932,73 @@ def woohoo():
except Exception as exc:
raise RuntimeError("issue29692:Chained") from exc
try:
with test_issue29692():
raise ZeroDivisionError
except Exception as ex:
self.assertIs(type(ex), RuntimeError)
self.assertEqual(ex.args[0], "issue29692:Chained")
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
@torch.compile(backend="eager", fullgraph=True)
def f(t):
try:
with test_issue29692():
raise ZeroDivisionError
except Exception as ex:
self.assertIs(type(ex), RuntimeError)
self.assertEqual(ex.args[0], "issue29692:Chained")
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
try:
with test_issue29692():
raise StopIteration("issue29692:Unchained")
except Exception as ex:
self.assertIs(type(ex), StopIteration)
self.assertEqual(ex.args[0], "issue29692:Unchained")
self.assertIsNone(ex.__cause__)
try:
with test_issue29692():
raise StopIteration("issue29692:Unchained")
except Exception as ex:
self.assertIs(type(ex), StopIteration)
self.assertEqual(ex.args[0], "issue29692:Unchained")
self.assertIsNone(ex.__cause__)
f(torch.randn(2))
@unittest.expectedFailure
@make_dynamo_test
def _create_contextmanager_attribs(self):
def attribs(**kw):
def decorate(func):
for k, v in kw.items():
setattr(func, k, v)
return func
return decorate
def test_contextmanager_wrap_runtimeerror(self):
@contextmanager
@attribs(foo="bar")
def baz(spam):
"""Whee!"""
def woohoo():
try:
yield
except Exception as exc:
raise RuntimeError(f"caught {exc}") from exc
return baz
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
with self.assertRaises(RuntimeError):
with woohoo():
1 / 0
fn(torch.randn(2, 3))
# If the context manager wrapped StopIteration in a RuntimeError,
# we also unwrap it, because we can't tell whether the wrapping was
# done by the generator machinery or by the generator itself.
with self.assertRaises(StopIteration):
with woohoo():
raise StopIteration
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_attribs(self):
baz = self._create_contextmanager_attribs()
self.assertEqual(baz.__name__, "baz")
self.assertEqual(baz.foo, "bar")
@make_dynamo_test
def test_keywords(self):
# Ensure no keyword arguments are inhibited
@contextmanager
def woohoo(self, func, args, kwds):
yield (self, func, args, kwds)
with woohoo(self=11, func=22, args=33, kwds=44) as target:
self.assertEqual(target, (11, 22, 33, 44))
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
with woohoo(self=11, func=22, args=33, kwds=44) as target:
self.assertEqual(target, (11, 22, 33, 44))
fn(torch.randn(2, 3))
@unittest.expectedFailure
@make_dynamo_test
def test_param_errors(self):
@contextmanager
def woohoo(a, *, b):
yield
with self.assertRaises(TypeError):
woohoo()
with self.assertRaises(TypeError):
woohoo(3, 5)
with self.assertRaises(TypeError):
woohoo(b=3)
@unittest.expectedFailure
@make_dynamo_test
def test_recursive(self):
depth = 0
ncols = 0
@contextmanager
def woohoo():
nonlocal ncols
ncols += 1
nonlocal depth
before = depth
depth += 1
@ -2951,70 +3011,15 @@ def woohoo():
if depth < 10:
recursive()
recursive()
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
recursive()
fn(torch.randn(2, 3))
self.assertEqual(ncols, 10)
self.assertEqual(depth, 0)
@skipIfNotPy311
@make_dynamo_test
def test_contextmanager_trap_no_yield(self):
@contextmanager
def whoo():
if False:
yield
ctx = whoo()
with self.assertRaises(RuntimeError):
ctx.__enter__()
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_trap_second_yield(self):
@contextmanager
def whoo():
yield
yield
ctx = whoo()
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(None, None, None)
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_wrap_runtimeerror(self):
@contextmanager
def woohoo():
try:
yield
except Exception as exc:
raise RuntimeError(f"caught {exc}") from exc
with self.assertRaises(RuntimeError):
with woohoo():
1 / 0
# If the context manager wrapped StopIteration in a RuntimeError,
# we also unwrap it, because we can't tell whether the wrapping was
# done by the generator machinery or by the generator itself.
with self.assertRaises(StopIteration):
with woohoo():
raise StopIteration
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_non_normalised(self):
@contextmanager
def whoo():
try:
yield
except RuntimeError:
raise SyntaxError # noqa: B904
ctx = whoo()
ctx.__enter__()
with self.assertRaises(SyntaxError):
ctx.__exit__(RuntimeError, None, None)
instantiate_parametrized_tests(CtxManagerTests)
instantiate_parametrized_tests(ContextlibContextManagerTests)

View file

@ -1,7 +1,5 @@
# Owner(s): ["module: dynamo"]
import contextlib
import sys
import unittest
import torch
@ -456,6 +454,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
@unittest.expectedFailure
@make_dynamo_test
def test_raise_set___context__(self):
try:
@ -472,73 +471,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIsNone(exc2.__context__)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_raise_match(self):
a = AttributeError
b = BytesWarning
c = ConnectionError
d = DeprecationWarning
e = Exception
def fn(a, b):
try:
raise a
finally:
raise b
def fix_exc_context(frame_exc, new_exc, old_exc):
# slightly change from ExitStack.fix_exc_context function
while 1:
exc_context = new_exc.__context__
if exc_context is None or exc_context is old_exc:
return
if exc_context is frame_exc:
break
new_exc = exc_context
new_exc.__context__ = old_exc
@contextlib.contextmanager
def ctx():
try:
yield
finally:
frame_exc = prev_exc = sys.exc_info()
args = [(d, c), (b, a)]
for x, y in args:
try:
fn(x, y)
except BaseException:
new_exc = sys.exc_info()
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
prev_exc = new_exc
try:
fixed_ctx = prev_exc[1].__context__
raise prev_exc[1]
except BaseException:
prev_exc[1].__context__ = fixed_ctx
raise
try:
with ctx():
raise e
except Exception as exc:
self.assertIsInstance(exc, a)
self.assertIsInstance(exc.__context__, b)
self.assertIsInstance(exc.__context__.__context__, c)
self.assertIsInstance(exc.__context__.__context__.__context__, d)
self.assertIsInstance(
exc.__context__.__context__.__context__.__context__, e
)
@make_dynamo_test
def test_raise_ZeroDivisionError(self):
try:
1 / 0
except Exception:
pass
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
@ -562,6 +494,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIsNone(e.__context__)
self.assertIsNone(e.__cause__)
@unittest.expectedFailure
@make_dynamo_test
def testChainingDescriptors(self):
try:
@ -581,6 +514,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
e.__suppress_context__ = False
self.assertFalse(e.__suppress_context__)
@unittest.expectedFailure
@make_dynamo_test
def test_context_of_exception_in_try_and_finally(self):
try:
@ -596,6 +530,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc, ve)
self.assertIs(exc.__context__, te)
@unittest.expectedFailure
@make_dynamo_test
def test_context_of_exception_in_except_and_finally(self):
try:
@ -615,6 +550,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc.__context__, ve)
self.assertIs(exc.__context__.__context__, te)
@unittest.expectedFailure
@make_dynamo_test
def test_context_of_exception_in_else_and_finally(self):
try:
@ -634,7 +570,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc, oe)
self.assertIs(exc.__context__, ve)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@unittest.expectedFailure
@make_dynamo_test
def test_raise_does_not_create_context_chain_cycle(self):
A = AssertionError
@ -673,7 +609,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(c.__context__, b)
self.assertIsNone(b.__context__)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@unittest.expectedFailure
@make_dynamo_test
def test_no_hang_on_context_chain_cycle1(self):
# See issue 25782. Cycle in context chain.
@ -729,7 +665,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(b.__context__, a)
self.assertIs(a.__context__, c)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@unittest.expectedFailure
@make_dynamo_test
def test_no_hang_on_context_chain_cycle3(self):
# See issue 25782. Longer context chain with cycle.

View file

@ -2619,6 +2619,23 @@ class GraphModule(torch.nn.Module):
else:
return x.cos()
@unittest.expectedFailure
def test_getattr_metaclass(self):
class Meta(type):
def __getattr__(cls, name):
return len(name)
class C(metaclass=Meta):
attr = 123
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
return t + C.attr + C.dynamic_attr
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t + 123 + 12)
def test_two_point_iter(self):
def fn(x, y):
it = map(lambda n: n + 1, range(6))

View file

@ -37,6 +37,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
def test_sys_exception_no_exception(self):
self.assertEqual(sys.exception(), None)
@unittest.expectedFailure
@make_dynamo_test
def test_exc_info_with_exception_instance(self):
def f():
@ -53,6 +54,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@unittest.expectedFailure
@make_dynamo_test
def test_exc_info_with_exception_type(self):
def f():
@ -69,6 +71,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_instance(self):
@ -84,6 +87,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
self.assertIsInstance(e, ValueError)
self.assertIs(exc, e)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_type(self):

View file

@ -1,608 +0,0 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import warnings
from itertools import product
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class Test_Assertions(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
def setUp(self):
if sys.version_info < (3, 11):
self.skipTest(
"Tracing the unittest module needs exception table (Python 3.11+) to work"
)
super().setUp()
@make_dynamo_test
def test_AlmostEqual(self):
self.assertAlmostEqual(1.00000001, 1.0)
self.assertNotAlmostEqual(1.0000001, 1.0)
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
)
self.assertAlmostEqual(1.1, 1.0, places=0)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
)
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
)
self.assertAlmostEqual(float("inf"), float("inf"))
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
)
@make_dynamo_test
def test_AmostEqualWithDelta(self):
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
)
self.assertRaises(
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
)
self.assertRaises(
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
)
@make_dynamo_test
def test_assertRaises(self):
def _raise(e):
raise e
self.assertRaises(KeyError, _raise, KeyError)
self.assertRaises(KeyError, _raise, KeyError("key"))
try:
self.assertRaises(KeyError, lambda: None)
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
self.assertRaises(KeyError, _raise, ValueError)
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
with self.assertRaises(KeyError) as cm:
try:
raise KeyError
except Exception as e:
exc = e
raise
self.assertIs(cm.exception, exc)
with self.assertRaises(KeyError):
raise KeyError("key")
try:
with self.assertRaises(KeyError):
pass
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
with self.assertRaises(KeyError):
raise ValueError
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
@unittest.expectedFailure
@make_dynamo_test
def testAssertNotRegex(self):
self.assertNotRegex("Ala ma kota", r"r+")
try:
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
except self.failureException as e:
self.assertIn("Message", e.args[0])
else:
self.fail("assertNotRegex should have failed.")
class TestLongMessage(torch._dynamo.test_case.TestCase):
"""Test that the individual asserts honour longMessage.
This actually tests all the message behaviour for
asserts that use longMessage."""
def setUp(self):
if sys.version_info < (3, 11):
return self.skipTest(
"Tracing the unittest module needs exception table (Python 3.11+) to work"
)
super().setUp()
class TestableTestFalse(unittest.TestCase):
longMessage = False
failureException = self.failureException
def testTest(self):
pass
class TestableTestTrue(unittest.TestCase):
longMessage = True
failureException = self.failureException
def testTest(self):
pass
self.testableTrue = TestableTestTrue("testTest")
self.testableFalse = TestableTestFalse("testTest")
def testDefault(self):
self.assertTrue(unittest.TestCase.longMessage)
def test_formatMsg(self):
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
# This blows up if _formatMessage uses string concatenation
self.testableTrue._formatMessage(object(), "foo")
def test_formatMessage_unicode_error(self):
one = "".join(chr(i) for i in range(255))
# this used to cause a UnicodeDecodeError constructing msg
self.testableTrue._formatMessage(one, "\uFFFD")
def assertMessages(self, methodName, args, errors):
"""
Check that methodName(*args) raises the correct error messages.
errors should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
def getMethod(i):
useTestableFalse = i < 2
if useTestableFalse:
test = self.testableFalse
else:
test = self.testableTrue
return getattr(test, methodName)
for i, expected_regex in enumerate(errors):
testMethod = getMethod(i)
kwargs = {}
withMsg = i % 2
if withMsg:
kwargs = {"msg": "oops"}
with self.assertRaisesRegex(
self.failureException, expected_regex=expected_regex
):
testMethod(*args, **kwargs)
# with self.assertRaises(self.failureException) as cm:
# testMethod(*args, **kwargs)
# self.assertIn(expected_regex, str(cm.exception))
@make_dynamo_test
def testAssertTrue(self):
self.assertMessages(
"assertTrue",
(False,),
[
"False is not true",
"oops",
"False is not true",
"False is not true : oops",
],
)
@make_dynamo_test
def testAssertFalse(self):
self.assertMessages(
"assertFalse",
(True,),
[
"True is not false",
"oops",
"True is not false",
"True is not false : oops",
],
)
@make_dynamo_test
def testNotEqual(self):
self.assertMessages(
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
)
@make_dynamo_test
def testAlmostEqual(self):
self.assertMessages(
"assertAlmostEqual",
(1, 2),
[
r"^1 != 2 within 7 places \(1 difference\)$",
"^oops$",
r"^1 != 2 within 7 places \(1 difference\)$",
r"^1 != 2 within 7 places \(1 difference\) : oops$",
],
)
@make_dynamo_test
def testNotAlmostEqual(self):
self.assertMessages(
"assertNotAlmostEqual",
(1, 1),
[
"^1 == 1 within 7 places$",
"^oops$",
"^1 == 1 within 7 places$",
"^1 == 1 within 7 places : oops$",
],
)
@make_dynamo_test
def test_baseAssertEqual(self):
self.assertMessages(
"_baseAssertEqual",
(1, 2),
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertSequenceEqual(self):
# Error messages are multiline so not testing on full message
# assertTupleEqual and assertListEqual delegate to this method
self.assertMessages(
"assertSequenceEqual",
([], [None]),
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
)
@make_dynamo_test
def testAssertSetEqual(self):
self.assertMessages(
"assertSetEqual",
(set(), set([None])), # noqa: C405
["None$", "^oops$", "None$", "None : oops$"],
)
@make_dynamo_test
def testAssertIn(self):
self.assertMessages(
"assertIn",
(None, []),
[
r"^None not found in \[\]$",
"^oops$",
r"^None not found in \[\]$",
r"^None not found in \[\] : oops$",
],
)
@make_dynamo_test
def testAssertNotIn(self):
self.assertMessages(
"assertNotIn",
(None, [None]),
[
r"^None unexpectedly found in \[None\]$",
"^oops$",
r"^None unexpectedly found in \[None\]$",
r"^None unexpectedly found in \[None\] : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertDictEqual(self):
self.assertMessages(
"assertDictEqual",
({}, {"key": "value"}),
[
r"\+ \{'key': 'value'\}$",
"^oops$",
r"\+ \{'key': 'value'\}$",
r"\+ \{'key': 'value'\} : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertMultiLineEqual(self):
self.assertMessages(
"assertMultiLineEqual",
("", "foo"),
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
)
@make_dynamo_test
def testAssertLess(self):
self.assertMessages(
"assertLess",
(2, 1),
[
"^2 not less than 1$",
"^oops$",
"^2 not less than 1$",
"^2 not less than 1 : oops$",
],
)
@make_dynamo_test
def testAssertLessEqual(self):
self.assertMessages(
"assertLessEqual",
(2, 1),
[
"^2 not less than or equal to 1$",
"^oops$",
"^2 not less than or equal to 1$",
"^2 not less than or equal to 1 : oops$",
],
)
@make_dynamo_test
def testAssertGreater(self):
self.assertMessages(
"assertGreater",
(1, 2),
[
"^1 not greater than 2$",
"^oops$",
"^1 not greater than 2$",
"^1 not greater than 2 : oops$",
],
)
@make_dynamo_test
def testAssertGreaterEqual(self):
self.assertMessages(
"assertGreaterEqual",
(1, 2),
[
"^1 not greater than or equal to 2$",
"^oops$",
"^1 not greater than or equal to 2$",
"^1 not greater than or equal to 2 : oops$",
],
)
@make_dynamo_test
def testAssertIsNone(self):
self.assertMessages(
"assertIsNone",
("not None",),
[
"^'not None' is not None$",
"^oops$",
"^'not None' is not None$",
"^'not None' is not None : oops$",
],
)
@make_dynamo_test
def testAssertIsNotNone(self):
self.assertMessages(
"assertIsNotNone",
(None,),
[
"^unexpectedly None$",
"^oops$",
"^unexpectedly None$",
"^unexpectedly None : oops$",
],
)
@make_dynamo_test
def testAssertIs(self):
self.assertMessages(
"assertIs",
(None, "foo"),
[
"^None is not 'foo'$",
"^oops$",
"^None is not 'foo'$",
"^None is not 'foo' : oops$",
],
)
@make_dynamo_test
def testAssertIsNot(self):
self.assertMessages(
"assertIsNot",
(None, None),
[
"^unexpectedly identical: None$",
"^oops$",
"^unexpectedly identical: None$",
"^unexpectedly identical: None : oops$",
],
)
@make_dynamo_test
def testAssertRegex(self):
self.assertMessages(
"assertRegex",
("foo", "bar"),
[
"^Regex didn't match:",
"^oops$",
"^Regex didn't match:",
"^Regex didn't match: (.*) : oops$",
],
)
@make_dynamo_test
def testAssertNotRegex(self):
self.assertMessages(
"assertNotRegex",
("foo", "foo"),
[
"^Regex matched:",
"^oops$",
"^Regex matched:",
"^Regex matched: (.*) : oops$",
],
)
def assertMessagesCM(self, methodName, args, func, errors):
"""
Check that the correct error messages are raised while executing:
with method(*args):
func()
*errors* should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
for (cls, kwargs), err in zip(p, errors):
method = getattr(cls, methodName)
with self.assertRaisesRegex(cls.failureException, err):
with method(*args, **kwargs) as cm: # noqa: F841
func()
@make_dynamo_test
def testAssertRaises(self):
self.assertMessagesCM(
"assertRaises",
(TypeError,),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
@make_dynamo_test
def testAssertRaisesRegex(self):
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "unused regex"),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
# test error raised but with wrong message
def raise_wrong_message():
raise TypeError("foo")
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarns(self):
self.assertMessagesCM(
"assertWarns",
(UserWarning,),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
@make_dynamo_test
def test_assertNotWarns(self):
def warn_future():
warnings.warn("xyz", FutureWarning, stacklevel=2)
self.assertMessagesCM(
"_assertNotWarns",
(FutureWarning,),
warn_future,
[
"^FutureWarning triggered$",
"^oops$",
"^FutureWarning triggered$",
"^FutureWarning triggered : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarnsRegex(self):
# test error not raised
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "unused regex"),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
# test warning raised but with wrong message
def raise_wrong_message():
warnings.warn("foo")
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View file

@ -291,13 +291,6 @@ observed_exception_map = {
}
def create_dynamo_observed_exception(e: type[Exception]) -> None:
if e not in observed_exception_map:
name = getattr(e, "__name__", str(e))
internal_exc = type(f"Observed{name}Error", (ObservedException,), {})
observed_exception_map[e] = internal_exc
def raise_observed_exception(
exc_type: type[Exception],
tx: InstructionTranslatorBase,
@ -311,8 +304,6 @@ def raise_observed_exception(
# stack and raise the exception.
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
tx.exn_vt_stack.append(exception_vt)
if exc_type not in observed_exception_map:
create_dynamo_observed_exception(exc_type)
raise observed_exception_map[exc_type]

View file

@ -3242,34 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
else:
return result
def update_parent_exn_vt_stack(self):
# TODO(anijain2305) - This works but we should probably have a
# global/central data structure for the exception stack.
parent = self.parent
p_exn = parent.exn_vt_stack
c_exn = self.exn_vt_stack
# Prior to this, the parent stack would be extended with all exceptions from
# the child frame. We don't need to append everything as the parent interpreter
# can only match the topmost exception from the child. i.e.
#
# def foo():
# try:
# raise ValueError
# except ValueError as e:
# raise TypeError from e
#
# def bar():
# try:
# foo()
# except TypeError:
# # it is impossible for `bar` to match on the first `ValueError`
# # raised by `foo`
# pass
#
# For a test case, check test_exceptions::test_raise_match
if len(c_exn) > len(p_exn):
parent.exn_vt_stack.append(c_exn[-1])
@staticmethod
def build_inline_tracer(
parent,
@ -3376,7 +3348,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
self.run()
except exc.ObservedException as e:
msg = f"Observed exception DURING INLING {code} : {e}"
self.update_parent_exn_vt_stack()
# TODO(anijain2305) - This works but we should probably have a
# global/central data structure for the exception stack.
parent.exn_vt_stack.extend(self.exn_vt_stack)
log.debug(msg)
# bubble up the exception to the parent frame.
raise
@ -3449,8 +3423,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
)
self.funcvar = funcvar
self.parent = parent
# Propagate any exception on the parent stack
self.exn_vt_stack = list(parent.exn_vt_stack)
self.num_calls = parent.num_calls
self.symbolic_result = None
self.nn_module_stack = parent.nn_module_stack.copy()

View file

@ -16,6 +16,7 @@ import sys
import traceback
import types
import typing
import unittest
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, cast, Optional, Union
@ -3147,6 +3148,7 @@ BUILTIN_SKIPLIST = (
random,
traceback,
linecache,
unittest,
)
# third party libraries skiplist is defined by str, because users may not use these libraries.

View file

@ -1012,12 +1012,10 @@ class VariableBuilder:
elif is_lru_cache_wrapped_function(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
elif value in (traceback.clear_frames,):
elif value is traceback.clear_frames:
return TracebackVariable(source=self.source)
elif (
value in (sys.exc_info,)
or sys.version_info >= (3, 11)
and value is sys.exception
elif value is sys.exc_info or (
sys.version_info >= (3, 11) and value is sys.exception
):
return SysFunctionVariable(value, source=self.source)
elif is_function_or_wrapper(value) and inspect.getattr_static(

View file

@ -21,7 +21,6 @@ from .. import config, polyfills, variables
from ..exc import (
AttributeMutationError,
ObservedAttributeError,
raise_observed_exception,
unimplemented,
Unsupported,
UserError,
@ -874,7 +873,7 @@ class BuiltinVariable(VariableTracker):
*[x.as_python_constant() for x in args],
)
except Exception as exc:
raise_observed_exception(type(exc), tx)
unimplemented(f"constant fold exception: {repr(exc)}")
return VariableTracker.build(tx, res)
else:
@ -1219,12 +1218,6 @@ class BuiltinVariable(VariableTracker):
# Inline the user function
return tx.inline_user_function_return(user_func_variable, [arg], {})
elif isinstance(arg, (variables.ExceptionVariable,)):
if len(arg.args) == 0:
value = f"{arg.exc_type}"
else:
value = ", ".join(a.as_python_constant() for a in arg.args)
return variables.ConstantVariable.create(value=value)
def _call_min_max(self, tx: "InstructionTranslator", *args):
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):

View file

@ -633,11 +633,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
def has_closure(self):
return self.closure is not None
def const_getattr(self, tx, name):
if name == "__name__":
return self.fn_name.as_python_constant()
return super().const_getattr(tx, name)
def has_self(self):
return False

View file

@ -185,7 +185,12 @@ class UserDefinedClassVariable(UserDefinedVariable):
try:
obj = inspect.getattr_static(self.value, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
if type(self.value) is type:
raise_observed_exception(AttributeError, tx)
else:
# Cannot reason about classes with a custom metaclass
# See: test_functions::test_getattr_metaclass
obj = None
if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType):
return variables.GetAttrVariable(self, name, source=source)