mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
parent
1e6a7cfd54
commit
0ddc653d67
12 changed files with 262 additions and 952 deletions
|
|
@ -9,18 +9,12 @@ import torch
|
|||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.exc import InternalTorchDynamoError
|
||||
from torch._dynamo.testing import (
|
||||
EagerAndRecordGraphs,
|
||||
normalize_gm,
|
||||
same,
|
||||
skipIfNotPy311,
|
||||
)
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
make_dynamo_test,
|
||||
parametrize,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
|
|
@ -1750,11 +1744,10 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
|
||||
def setUp(self):
|
||||
self._old = torch._dynamo.config.enable_trace_contextlib
|
||||
torch._dynamo.config.enable_trace_contextlib = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_contextlib = self._old
|
||||
torch._dynamo.config.enable_trace_contextlib = False
|
||||
|
||||
def test_ctx_basic0(self):
|
||||
@contextlib.contextmanager
|
||||
|
|
@ -2700,9 +2693,9 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
|
||||
# https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70
|
||||
|
||||
@make_dynamo_test
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_plain(self):
|
||||
state = []
|
||||
|
||||
|
|
@ -2712,14 +2705,24 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
|||
yield 42
|
||||
state.append(999)
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
y = t.sum()
|
||||
with woohoo() as x:
|
||||
assert state == [1]
|
||||
assert x == 42
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
y += x
|
||||
return y
|
||||
|
||||
t = torch.randn(2, 3)
|
||||
y = fn(t)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
self.assertEqual(y, t.sum() + 42)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_finally(self):
|
||||
state = []
|
||||
|
||||
|
|
@ -2731,56 +2734,121 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
|||
finally:
|
||||
state.append(999)
|
||||
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
_y = t.sum()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_traceback(self):
|
||||
@contextmanager
|
||||
def f():
|
||||
yield
|
||||
|
||||
try:
|
||||
with f():
|
||||
1 / 0
|
||||
except ZeroDivisionError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
frames = []
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
nonlocal frames
|
||||
_y = t.sum()
|
||||
try:
|
||||
with f():
|
||||
1 / 0
|
||||
except ZeroDivisionError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "1/0")
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
try:
|
||||
with f():
|
||||
raise NotImplementedError(42)
|
||||
except NotImplementedError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
|
||||
self.assertEqual(frames[0].line, "1 / 0")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_traceback2(self):
|
||||
@contextmanager
|
||||
def f():
|
||||
yield
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
class RuntimeErrorSubclass(RuntimeError):
|
||||
pass
|
||||
|
||||
frames = []
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
nonlocal frames
|
||||
_y = t.sum()
|
||||
try:
|
||||
with f():
|
||||
raise RuntimeErrorSubclass(42)
|
||||
except RuntimeErrorSubclass as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "raise RuntimeErrorSubclass(42)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_traceback3(self):
|
||||
@contextmanager
|
||||
def f():
|
||||
yield
|
||||
|
||||
frames = []
|
||||
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
for stop_exc in (
|
||||
StopIteration("spam"),
|
||||
StopIterationSubclass("spam"),
|
||||
):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
nonlocal frames
|
||||
_y = t.sum()
|
||||
try:
|
||||
with f():
|
||||
raise stop_exc
|
||||
except type(stop_exc) as e:
|
||||
self.assertIs(e, stop_exc)
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
else:
|
||||
self.fail(f"{stop_exc} was suppressed")
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "raise stop_exc")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_no_reraise(self):
|
||||
@contextmanager
|
||||
def whee():
|
||||
yield
|
||||
|
||||
ctx = whee()
|
||||
ctx.__enter__()
|
||||
# Calling __exit__ should not result in an exception
|
||||
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
ctx = whee()
|
||||
ctx.__enter__()
|
||||
# Calling __exit__ should not result in an exception
|
||||
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
|
||||
return t.sum()
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_yield_after_throw(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
|
|
@ -2789,12 +2857,49 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
|||
except Exception:
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
self.assertRaises(RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None)
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(TypeError, TypeError("foo"), None)
|
||||
return t.sum()
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_trap_no_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
if False:
|
||||
yield
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
ctx = whoo()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__enter__()
|
||||
return t.sum()
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contextmanager_trap_second_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
yield
|
||||
yield
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(t):
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(None, None, None)
|
||||
|
||||
f(torch.randn(2))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except(self):
|
||||
state = []
|
||||
|
||||
|
|
@ -2807,58 +2912,18 @@ class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase):
|
|||
state.append(e.args[0])
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError(999)
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError(999)
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_stopiter(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail(f"{stop_exc} was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_pep479(self):
|
||||
code = """\
|
||||
from __future__ import generator_stop
|
||||
from contextlib import contextmanager
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
"""
|
||||
locals = {}
|
||||
exec(code, locals, locals)
|
||||
woohoo = locals["woohoo"]
|
||||
|
||||
stop_exc = StopIteration("spam")
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail("StopIteration was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
|
||||
@contextmanager
|
||||
def test_issue29692():
|
||||
|
|
@ -2867,78 +2932,73 @@ def woohoo():
|
|||
except Exception as exc:
|
||||
raise RuntimeError("issue29692:Chained") from exc
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise ZeroDivisionError
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), RuntimeError)
|
||||
self.assertEqual(ex.args[0], "issue29692:Chained")
|
||||
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(t):
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise ZeroDivisionError
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), RuntimeError)
|
||||
self.assertEqual(ex.args[0], "issue29692:Chained")
|
||||
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise StopIteration("issue29692:Unchained")
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), StopIteration)
|
||||
self.assertEqual(ex.args[0], "issue29692:Unchained")
|
||||
self.assertIsNone(ex.__cause__)
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise StopIteration("issue29692:Unchained")
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), StopIteration)
|
||||
self.assertEqual(ex.args[0], "issue29692:Unchained")
|
||||
self.assertIsNone(ex.__cause__)
|
||||
|
||||
f(torch.randn(2))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def _create_contextmanager_attribs(self):
|
||||
def attribs(**kw):
|
||||
def decorate(func):
|
||||
for k, v in kw.items():
|
||||
setattr(func, k, v)
|
||||
return func
|
||||
|
||||
return decorate
|
||||
|
||||
def test_contextmanager_wrap_runtimeerror(self):
|
||||
@contextmanager
|
||||
@attribs(foo="bar")
|
||||
def baz(spam):
|
||||
"""Whee!"""
|
||||
def woohoo():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"caught {exc}") from exc
|
||||
|
||||
return baz
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
with self.assertRaises(RuntimeError):
|
||||
with woohoo():
|
||||
1 / 0
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
# If the context manager wrapped StopIteration in a RuntimeError,
|
||||
# we also unwrap it, because we can't tell whether the wrapping was
|
||||
# done by the generator machinery or by the generator itself.
|
||||
with self.assertRaises(StopIteration):
|
||||
with woohoo():
|
||||
raise StopIteration
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_attribs(self):
|
||||
baz = self._create_contextmanager_attribs()
|
||||
self.assertEqual(baz.__name__, "baz")
|
||||
self.assertEqual(baz.foo, "bar")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_keywords(self):
|
||||
# Ensure no keyword arguments are inhibited
|
||||
@contextmanager
|
||||
def woohoo(self, func, args, kwds):
|
||||
yield (self, func, args, kwds)
|
||||
|
||||
with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_param_errors(self):
|
||||
@contextmanager
|
||||
def woohoo(a, *, b):
|
||||
yield
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo()
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(3, 5)
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(b=3)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_recursive(self):
|
||||
depth = 0
|
||||
ncols = 0
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
nonlocal ncols
|
||||
ncols += 1
|
||||
nonlocal depth
|
||||
before = depth
|
||||
depth += 1
|
||||
|
|
@ -2951,70 +3011,15 @@ def woohoo():
|
|||
if depth < 10:
|
||||
recursive()
|
||||
|
||||
recursive()
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
recursive()
|
||||
|
||||
fn(torch.randn(2, 3))
|
||||
|
||||
self.assertEqual(ncols, 10)
|
||||
self.assertEqual(depth, 0)
|
||||
|
||||
@skipIfNotPy311
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_no_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
if False:
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__enter__()
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_second_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
yield
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(None, None, None)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_wrap_runtimeerror(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"caught {exc}") from exc
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
with woohoo():
|
||||
1 / 0
|
||||
|
||||
# If the context manager wrapped StopIteration in a RuntimeError,
|
||||
# we also unwrap it, because we can't tell whether the wrapping was
|
||||
# done by the generator machinery or by the generator itself.
|
||||
with self.assertRaises(StopIteration):
|
||||
with woohoo():
|
||||
raise StopIteration
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_non_normalised(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
try:
|
||||
yield
|
||||
except RuntimeError:
|
||||
raise SyntaxError # noqa: B904
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(SyntaxError):
|
||||
ctx.__exit__(RuntimeError, None, None)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(CtxManagerTests)
|
||||
instantiate_parametrized_tests(ContextlibContextManagerTests)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
|
@ -456,6 +454,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_raise_set___context__(self):
|
||||
try:
|
||||
|
|
@ -472,73 +471,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertIsNone(exc2.__context__)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_raise_match(self):
|
||||
a = AttributeError
|
||||
b = BytesWarning
|
||||
c = ConnectionError
|
||||
d = DeprecationWarning
|
||||
e = Exception
|
||||
|
||||
def fn(a, b):
|
||||
try:
|
||||
raise a
|
||||
finally:
|
||||
raise b
|
||||
|
||||
def fix_exc_context(frame_exc, new_exc, old_exc):
|
||||
# slightly change from ExitStack.fix_exc_context function
|
||||
while 1:
|
||||
exc_context = new_exc.__context__
|
||||
if exc_context is None or exc_context is old_exc:
|
||||
return
|
||||
if exc_context is frame_exc:
|
||||
break
|
||||
new_exc = exc_context
|
||||
new_exc.__context__ = old_exc
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
frame_exc = prev_exc = sys.exc_info()
|
||||
args = [(d, c), (b, a)]
|
||||
for x, y in args:
|
||||
try:
|
||||
fn(x, y)
|
||||
except BaseException:
|
||||
new_exc = sys.exc_info()
|
||||
fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
|
||||
prev_exc = new_exc
|
||||
|
||||
try:
|
||||
fixed_ctx = prev_exc[1].__context__
|
||||
raise prev_exc[1]
|
||||
except BaseException:
|
||||
prev_exc[1].__context__ = fixed_ctx
|
||||
raise
|
||||
|
||||
try:
|
||||
with ctx():
|
||||
raise e
|
||||
except Exception as exc:
|
||||
self.assertIsInstance(exc, a)
|
||||
self.assertIsInstance(exc.__context__, b)
|
||||
self.assertIsInstance(exc.__context__.__context__, c)
|
||||
self.assertIsInstance(exc.__context__.__context__.__context__, d)
|
||||
self.assertIsInstance(
|
||||
exc.__context__.__context__.__context__.__context__, e
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_ZeroDivisionError(self):
|
||||
try:
|
||||
1 / 0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
|
||||
|
|
@ -562,6 +494,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIsNone(e.__context__)
|
||||
self.assertIsNone(e.__cause__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testChainingDescriptors(self):
|
||||
try:
|
||||
|
|
@ -581,6 +514,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
e.__suppress_context__ = False
|
||||
self.assertFalse(e.__suppress_context__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_try_and_finally(self):
|
||||
try:
|
||||
|
|
@ -596,6 +530,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(exc, ve)
|
||||
self.assertIs(exc.__context__, te)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_except_and_finally(self):
|
||||
try:
|
||||
|
|
@ -615,6 +550,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(exc.__context__, ve)
|
||||
self.assertIs(exc.__context__.__context__, te)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_else_and_finally(self):
|
||||
try:
|
||||
|
|
@ -634,7 +570,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(exc, oe)
|
||||
self.assertIs(exc.__context__, ve)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_raise_does_not_create_context_chain_cycle(self):
|
||||
A = AssertionError
|
||||
|
|
@ -673,7 +609,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(c.__context__, b)
|
||||
self.assertIsNone(b.__context__)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle1(self):
|
||||
# See issue 25782. Cycle in context chain.
|
||||
|
|
@ -729,7 +665,7 @@ class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle3(self):
|
||||
# See issue 25782. Longer context chain with cycle.
|
||||
|
|
|
|||
|
|
@ -2619,6 +2619,23 @@ class GraphModule(torch.nn.Module):
|
|||
else:
|
||||
return x.cos()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_getattr_metaclass(self):
|
||||
class Meta(type):
|
||||
def __getattr__(cls, name):
|
||||
return len(name)
|
||||
|
||||
class C(metaclass=Meta):
|
||||
attr = 123
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
return t + C.attr + C.dynamic_attr
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t + 123 + 12)
|
||||
|
||||
def test_two_point_iter(self):
|
||||
def fn(x, y):
|
||||
it = map(lambda n: n + 1, range(6))
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
def test_sys_exception_no_exception(self):
|
||||
self.assertEqual(sys.exception(), None)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_instance(self):
|
||||
def f():
|
||||
|
|
@ -53,6 +54,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_type(self):
|
||||
def f():
|
||||
|
|
@ -69,6 +71,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_instance(self):
|
||||
|
|
@ -84,6 +87,7 @@ class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc, e)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_type(self):
|
||||
|
|
|
|||
|
|
@ -1,608 +0,0 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class Test_Assertions(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
|
||||
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
|
||||
def setUp(self):
|
||||
if sys.version_info < (3, 11):
|
||||
self.skipTest(
|
||||
"Tracing the unittest module needs exception table (Python 3.11+) to work"
|
||||
)
|
||||
super().setUp()
|
||||
|
||||
@make_dynamo_test
|
||||
def test_AlmostEqual(self):
|
||||
self.assertAlmostEqual(1.00000001, 1.0)
|
||||
self.assertNotAlmostEqual(1.0000001, 1.0)
|
||||
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(1.1, 1.0, places=0)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
|
||||
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
|
||||
)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(float("inf"), float("inf"))
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_AmostEqualWithDelta(self):
|
||||
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
|
||||
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
|
||||
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
|
||||
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
|
||||
|
||||
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
|
||||
)
|
||||
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
|
||||
)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
|
||||
)
|
||||
|
||||
self.assertRaises(
|
||||
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
|
||||
)
|
||||
self.assertRaises(
|
||||
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_assertRaises(self):
|
||||
def _raise(e):
|
||||
raise e
|
||||
|
||||
self.assertRaises(KeyError, _raise, KeyError)
|
||||
self.assertRaises(KeyError, _raise, KeyError("key"))
|
||||
try:
|
||||
self.assertRaises(KeyError, lambda: None)
|
||||
except self.failureException as e:
|
||||
self.assertIn("KeyError not raised", str(e))
|
||||
else:
|
||||
self.fail("assertRaises() didn't fail")
|
||||
try:
|
||||
self.assertRaises(KeyError, _raise, ValueError)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
self.fail("assertRaises() didn't let exception pass through")
|
||||
with self.assertRaises(KeyError) as cm:
|
||||
try:
|
||||
raise KeyError
|
||||
except Exception as e:
|
||||
exc = e
|
||||
raise
|
||||
self.assertIs(cm.exception, exc)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
raise KeyError("key")
|
||||
try:
|
||||
with self.assertRaises(KeyError):
|
||||
pass
|
||||
except self.failureException as e:
|
||||
self.assertIn("KeyError not raised", str(e))
|
||||
else:
|
||||
self.fail("assertRaises() didn't fail")
|
||||
try:
|
||||
with self.assertRaises(KeyError):
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
self.fail("assertRaises() didn't let exception pass through")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertNotRegex(self):
|
||||
self.assertNotRegex("Ala ma kota", r"r+")
|
||||
try:
|
||||
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
|
||||
except self.failureException as e:
|
||||
self.assertIn("Message", e.args[0])
|
||||
else:
|
||||
self.fail("assertNotRegex should have failed.")
|
||||
|
||||
|
||||
class TestLongMessage(torch._dynamo.test_case.TestCase):
|
||||
|
||||
"""Test that the individual asserts honour longMessage.
|
||||
This actually tests all the message behaviour for
|
||||
asserts that use longMessage."""
|
||||
|
||||
def setUp(self):
|
||||
if sys.version_info < (3, 11):
|
||||
return self.skipTest(
|
||||
"Tracing the unittest module needs exception table (Python 3.11+) to work"
|
||||
)
|
||||
super().setUp()
|
||||
|
||||
class TestableTestFalse(unittest.TestCase):
|
||||
longMessage = False
|
||||
failureException = self.failureException
|
||||
|
||||
def testTest(self):
|
||||
pass
|
||||
|
||||
class TestableTestTrue(unittest.TestCase):
|
||||
longMessage = True
|
||||
failureException = self.failureException
|
||||
|
||||
def testTest(self):
|
||||
pass
|
||||
|
||||
self.testableTrue = TestableTestTrue("testTest")
|
||||
self.testableFalse = TestableTestFalse("testTest")
|
||||
|
||||
def testDefault(self):
|
||||
self.assertTrue(unittest.TestCase.longMessage)
|
||||
|
||||
def test_formatMsg(self):
|
||||
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
|
||||
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
|
||||
|
||||
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
|
||||
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
|
||||
|
||||
# This blows up if _formatMessage uses string concatenation
|
||||
self.testableTrue._formatMessage(object(), "foo")
|
||||
|
||||
def test_formatMessage_unicode_error(self):
|
||||
one = "".join(chr(i) for i in range(255))
|
||||
# this used to cause a UnicodeDecodeError constructing msg
|
||||
self.testableTrue._formatMessage(one, "\uFFFD")
|
||||
|
||||
def assertMessages(self, methodName, args, errors):
|
||||
"""
|
||||
Check that methodName(*args) raises the correct error messages.
|
||||
errors should be a list of 4 regex that match the error when:
|
||||
1) longMessage = False and no msg passed;
|
||||
2) longMessage = False and msg passed;
|
||||
3) longMessage = True and no msg passed;
|
||||
4) longMessage = True and msg passed;
|
||||
"""
|
||||
|
||||
def getMethod(i):
|
||||
useTestableFalse = i < 2
|
||||
if useTestableFalse:
|
||||
test = self.testableFalse
|
||||
else:
|
||||
test = self.testableTrue
|
||||
return getattr(test, methodName)
|
||||
|
||||
for i, expected_regex in enumerate(errors):
|
||||
testMethod = getMethod(i)
|
||||
kwargs = {}
|
||||
withMsg = i % 2
|
||||
if withMsg:
|
||||
kwargs = {"msg": "oops"}
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
self.failureException, expected_regex=expected_regex
|
||||
):
|
||||
testMethod(*args, **kwargs)
|
||||
# with self.assertRaises(self.failureException) as cm:
|
||||
# testMethod(*args, **kwargs)
|
||||
# self.assertIn(expected_regex, str(cm.exception))
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertTrue(self):
|
||||
self.assertMessages(
|
||||
"assertTrue",
|
||||
(False,),
|
||||
[
|
||||
"False is not true",
|
||||
"oops",
|
||||
"False is not true",
|
||||
"False is not true : oops",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertFalse(self):
|
||||
self.assertMessages(
|
||||
"assertFalse",
|
||||
(True,),
|
||||
[
|
||||
"True is not false",
|
||||
"oops",
|
||||
"True is not false",
|
||||
"True is not false : oops",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testNotEqual(self):
|
||||
self.assertMessages(
|
||||
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAlmostEqual(self):
|
||||
self.assertMessages(
|
||||
"assertAlmostEqual",
|
||||
(1, 2),
|
||||
[
|
||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
||||
"^oops$",
|
||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
||||
r"^1 != 2 within 7 places \(1 difference\) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testNotAlmostEqual(self):
|
||||
self.assertMessages(
|
||||
"assertNotAlmostEqual",
|
||||
(1, 1),
|
||||
[
|
||||
"^1 == 1 within 7 places$",
|
||||
"^oops$",
|
||||
"^1 == 1 within 7 places$",
|
||||
"^1 == 1 within 7 places : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_baseAssertEqual(self):
|
||||
self.assertMessages(
|
||||
"_baseAssertEqual",
|
||||
(1, 2),
|
||||
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertSequenceEqual(self):
|
||||
# Error messages are multiline so not testing on full message
|
||||
# assertTupleEqual and assertListEqual delegate to this method
|
||||
self.assertMessages(
|
||||
"assertSequenceEqual",
|
||||
([], [None]),
|
||||
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertSetEqual(self):
|
||||
self.assertMessages(
|
||||
"assertSetEqual",
|
||||
(set(), set([None])), # noqa: C405
|
||||
["None$", "^oops$", "None$", "None : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIn(self):
|
||||
self.assertMessages(
|
||||
"assertIn",
|
||||
(None, []),
|
||||
[
|
||||
r"^None not found in \[\]$",
|
||||
"^oops$",
|
||||
r"^None not found in \[\]$",
|
||||
r"^None not found in \[\] : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertNotIn(self):
|
||||
self.assertMessages(
|
||||
"assertNotIn",
|
||||
(None, [None]),
|
||||
[
|
||||
r"^None unexpectedly found in \[None\]$",
|
||||
"^oops$",
|
||||
r"^None unexpectedly found in \[None\]$",
|
||||
r"^None unexpectedly found in \[None\] : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertDictEqual(self):
|
||||
self.assertMessages(
|
||||
"assertDictEqual",
|
||||
({}, {"key": "value"}),
|
||||
[
|
||||
r"\+ \{'key': 'value'\}$",
|
||||
"^oops$",
|
||||
r"\+ \{'key': 'value'\}$",
|
||||
r"\+ \{'key': 'value'\} : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertMultiLineEqual(self):
|
||||
self.assertMessages(
|
||||
"assertMultiLineEqual",
|
||||
("", "foo"),
|
||||
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertLess(self):
|
||||
self.assertMessages(
|
||||
"assertLess",
|
||||
(2, 1),
|
||||
[
|
||||
"^2 not less than 1$",
|
||||
"^oops$",
|
||||
"^2 not less than 1$",
|
||||
"^2 not less than 1 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertLessEqual(self):
|
||||
self.assertMessages(
|
||||
"assertLessEqual",
|
||||
(2, 1),
|
||||
[
|
||||
"^2 not less than or equal to 1$",
|
||||
"^oops$",
|
||||
"^2 not less than or equal to 1$",
|
||||
"^2 not less than or equal to 1 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertGreater(self):
|
||||
self.assertMessages(
|
||||
"assertGreater",
|
||||
(1, 2),
|
||||
[
|
||||
"^1 not greater than 2$",
|
||||
"^oops$",
|
||||
"^1 not greater than 2$",
|
||||
"^1 not greater than 2 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertGreaterEqual(self):
|
||||
self.assertMessages(
|
||||
"assertGreaterEqual",
|
||||
(1, 2),
|
||||
[
|
||||
"^1 not greater than or equal to 2$",
|
||||
"^oops$",
|
||||
"^1 not greater than or equal to 2$",
|
||||
"^1 not greater than or equal to 2 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNone(self):
|
||||
self.assertMessages(
|
||||
"assertIsNone",
|
||||
("not None",),
|
||||
[
|
||||
"^'not None' is not None$",
|
||||
"^oops$",
|
||||
"^'not None' is not None$",
|
||||
"^'not None' is not None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNotNone(self):
|
||||
self.assertMessages(
|
||||
"assertIsNotNone",
|
||||
(None,),
|
||||
[
|
||||
"^unexpectedly None$",
|
||||
"^oops$",
|
||||
"^unexpectedly None$",
|
||||
"^unexpectedly None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIs(self):
|
||||
self.assertMessages(
|
||||
"assertIs",
|
||||
(None, "foo"),
|
||||
[
|
||||
"^None is not 'foo'$",
|
||||
"^oops$",
|
||||
"^None is not 'foo'$",
|
||||
"^None is not 'foo' : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNot(self):
|
||||
self.assertMessages(
|
||||
"assertIsNot",
|
||||
(None, None),
|
||||
[
|
||||
"^unexpectedly identical: None$",
|
||||
"^oops$",
|
||||
"^unexpectedly identical: None$",
|
||||
"^unexpectedly identical: None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertRegex(self):
|
||||
self.assertMessages(
|
||||
"assertRegex",
|
||||
("foo", "bar"),
|
||||
[
|
||||
"^Regex didn't match:",
|
||||
"^oops$",
|
||||
"^Regex didn't match:",
|
||||
"^Regex didn't match: (.*) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertNotRegex(self):
|
||||
self.assertMessages(
|
||||
"assertNotRegex",
|
||||
("foo", "foo"),
|
||||
[
|
||||
"^Regex matched:",
|
||||
"^oops$",
|
||||
"^Regex matched:",
|
||||
"^Regex matched: (.*) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
def assertMessagesCM(self, methodName, args, func, errors):
|
||||
"""
|
||||
Check that the correct error messages are raised while executing:
|
||||
with method(*args):
|
||||
func()
|
||||
*errors* should be a list of 4 regex that match the error when:
|
||||
1) longMessage = False and no msg passed;
|
||||
2) longMessage = False and msg passed;
|
||||
3) longMessage = True and no msg passed;
|
||||
4) longMessage = True and msg passed;
|
||||
"""
|
||||
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
|
||||
for (cls, kwargs), err in zip(p, errors):
|
||||
method = getattr(cls, methodName)
|
||||
with self.assertRaisesRegex(cls.failureException, err):
|
||||
with method(*args, **kwargs) as cm: # noqa: F841
|
||||
func()
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertRaises(self):
|
||||
self.assertMessagesCM(
|
||||
"assertRaises",
|
||||
(TypeError,),
|
||||
lambda: None,
|
||||
[
|
||||
"^TypeError not raised$",
|
||||
"^oops$",
|
||||
"^TypeError not raised$",
|
||||
"^TypeError not raised : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertRaisesRegex(self):
|
||||
self.assertMessagesCM(
|
||||
"assertRaisesRegex",
|
||||
(TypeError, "unused regex"),
|
||||
lambda: None,
|
||||
[
|
||||
"^TypeError not raised$",
|
||||
"^oops$",
|
||||
"^TypeError not raised$",
|
||||
"^TypeError not raised : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
# test error raised but with wrong message
|
||||
def raise_wrong_message():
|
||||
raise TypeError("foo")
|
||||
|
||||
self.assertMessagesCM(
|
||||
"assertRaisesRegex",
|
||||
(TypeError, "regex"),
|
||||
raise_wrong_message,
|
||||
[
|
||||
'^"regex" does not match "foo"$',
|
||||
"^oops$",
|
||||
'^"regex" does not match "foo"$',
|
||||
'^"regex" does not match "foo" : oops$',
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertWarns(self):
|
||||
self.assertMessagesCM(
|
||||
"assertWarns",
|
||||
(UserWarning,),
|
||||
lambda: None,
|
||||
[
|
||||
"^UserWarning not triggered$",
|
||||
"^oops$",
|
||||
"^UserWarning not triggered$",
|
||||
"^UserWarning not triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
|
||||
@make_dynamo_test
|
||||
def test_assertNotWarns(self):
|
||||
def warn_future():
|
||||
warnings.warn("xyz", FutureWarning, stacklevel=2)
|
||||
|
||||
self.assertMessagesCM(
|
||||
"_assertNotWarns",
|
||||
(FutureWarning,),
|
||||
warn_future,
|
||||
[
|
||||
"^FutureWarning triggered$",
|
||||
"^oops$",
|
||||
"^FutureWarning triggered$",
|
||||
"^FutureWarning triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertWarnsRegex(self):
|
||||
# test error not raised
|
||||
self.assertMessagesCM(
|
||||
"assertWarnsRegex",
|
||||
(UserWarning, "unused regex"),
|
||||
lambda: None,
|
||||
[
|
||||
"^UserWarning not triggered$",
|
||||
"^oops$",
|
||||
"^UserWarning not triggered$",
|
||||
"^UserWarning not triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
# test warning raised but with wrong message
|
||||
def raise_wrong_message():
|
||||
warnings.warn("foo")
|
||||
|
||||
self.assertMessagesCM(
|
||||
"assertWarnsRegex",
|
||||
(UserWarning, "regex"),
|
||||
raise_wrong_message,
|
||||
[
|
||||
'^"regex" does not match "foo"$',
|
||||
"^oops$",
|
||||
'^"regex" does not match "foo"$',
|
||||
'^"regex" does not match "foo" : oops$',
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
|
|
@ -291,13 +291,6 @@ observed_exception_map = {
|
|||
}
|
||||
|
||||
|
||||
def create_dynamo_observed_exception(e: type[Exception]) -> None:
|
||||
if e not in observed_exception_map:
|
||||
name = getattr(e, "__name__", str(e))
|
||||
internal_exc = type(f"Observed{name}Error", (ObservedException,), {})
|
||||
observed_exception_map[e] = internal_exc
|
||||
|
||||
|
||||
def raise_observed_exception(
|
||||
exc_type: type[Exception],
|
||||
tx: InstructionTranslatorBase,
|
||||
|
|
@ -311,8 +304,6 @@ def raise_observed_exception(
|
|||
# stack and raise the exception.
|
||||
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
|
||||
tx.exn_vt_stack.append(exception_vt)
|
||||
if exc_type not in observed_exception_map:
|
||||
create_dynamo_observed_exception(exc_type)
|
||||
raise observed_exception_map[exc_type]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3242,34 +3242,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
else:
|
||||
return result
|
||||
|
||||
def update_parent_exn_vt_stack(self):
|
||||
# TODO(anijain2305) - This works but we should probably have a
|
||||
# global/central data structure for the exception stack.
|
||||
parent = self.parent
|
||||
p_exn = parent.exn_vt_stack
|
||||
c_exn = self.exn_vt_stack
|
||||
# Prior to this, the parent stack would be extended with all exceptions from
|
||||
# the child frame. We don't need to append everything as the parent interpreter
|
||||
# can only match the topmost exception from the child. i.e.
|
||||
#
|
||||
# def foo():
|
||||
# try:
|
||||
# raise ValueError
|
||||
# except ValueError as e:
|
||||
# raise TypeError from e
|
||||
#
|
||||
# def bar():
|
||||
# try:
|
||||
# foo()
|
||||
# except TypeError:
|
||||
# # it is impossible for `bar` to match on the first `ValueError`
|
||||
# # raised by `foo`
|
||||
# pass
|
||||
#
|
||||
# For a test case, check test_exceptions::test_raise_match
|
||||
if len(c_exn) > len(p_exn):
|
||||
parent.exn_vt_stack.append(c_exn[-1])
|
||||
|
||||
@staticmethod
|
||||
def build_inline_tracer(
|
||||
parent,
|
||||
|
|
@ -3376,7 +3348,9 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
self.run()
|
||||
except exc.ObservedException as e:
|
||||
msg = f"Observed exception DURING INLING {code} : {e}"
|
||||
self.update_parent_exn_vt_stack()
|
||||
# TODO(anijain2305) - This works but we should probably have a
|
||||
# global/central data structure for the exception stack.
|
||||
parent.exn_vt_stack.extend(self.exn_vt_stack)
|
||||
log.debug(msg)
|
||||
# bubble up the exception to the parent frame.
|
||||
raise
|
||||
|
|
@ -3449,8 +3423,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
)
|
||||
self.funcvar = funcvar
|
||||
self.parent = parent
|
||||
# Propagate any exception on the parent stack
|
||||
self.exn_vt_stack = list(parent.exn_vt_stack)
|
||||
self.num_calls = parent.num_calls
|
||||
self.symbolic_result = None
|
||||
self.nn_module_stack = parent.nn_module_stack.copy()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import sys
|
|||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
|
@ -3147,6 +3148,7 @@ BUILTIN_SKIPLIST = (
|
|||
random,
|
||||
traceback,
|
||||
linecache,
|
||||
unittest,
|
||||
)
|
||||
|
||||
# third party libraries skiplist is defined by str, because users may not use these libraries.
|
||||
|
|
|
|||
|
|
@ -1012,12 +1012,10 @@ class VariableBuilder:
|
|||
elif is_lru_cache_wrapped_function(value):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
|
||||
elif value in (traceback.clear_frames,):
|
||||
elif value is traceback.clear_frames:
|
||||
return TracebackVariable(source=self.source)
|
||||
elif (
|
||||
value in (sys.exc_info,)
|
||||
or sys.version_info >= (3, 11)
|
||||
and value is sys.exception
|
||||
elif value is sys.exc_info or (
|
||||
sys.version_info >= (3, 11) and value is sys.exception
|
||||
):
|
||||
return SysFunctionVariable(value, source=self.source)
|
||||
elif is_function_or_wrapper(value) and inspect.getattr_static(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from .. import config, polyfills, variables
|
|||
from ..exc import (
|
||||
AttributeMutationError,
|
||||
ObservedAttributeError,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
Unsupported,
|
||||
UserError,
|
||||
|
|
@ -874,7 +873,7 @@ class BuiltinVariable(VariableTracker):
|
|||
*[x.as_python_constant() for x in args],
|
||||
)
|
||||
except Exception as exc:
|
||||
raise_observed_exception(type(exc), tx)
|
||||
unimplemented(f"constant fold exception: {repr(exc)}")
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
else:
|
||||
|
|
@ -1219,12 +1218,6 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
# Inline the user function
|
||||
return tx.inline_user_function_return(user_func_variable, [arg], {})
|
||||
elif isinstance(arg, (variables.ExceptionVariable,)):
|
||||
if len(arg.args) == 0:
|
||||
value = f"{arg.exc_type}"
|
||||
else:
|
||||
value = ", ".join(a.as_python_constant() for a in arg.args)
|
||||
return variables.ConstantVariable.create(value=value)
|
||||
|
||||
def _call_min_max(self, tx: "InstructionTranslator", *args):
|
||||
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
|
|
|
|||
|
|
@ -633,11 +633,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
def has_closure(self):
|
||||
return self.closure is not None
|
||||
|
||||
def const_getattr(self, tx, name):
|
||||
if name == "__name__":
|
||||
return self.fn_name.as_python_constant()
|
||||
return super().const_getattr(tx, name)
|
||||
|
||||
def has_self(self):
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -185,7 +185,12 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
try:
|
||||
obj = inspect.getattr_static(self.value, name)
|
||||
except AttributeError:
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
if type(self.value) is type:
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
else:
|
||||
# Cannot reason about classes with a custom metaclass
|
||||
# See: test_functions::test_getattr_metaclass
|
||||
obj = None
|
||||
|
||||
if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType):
|
||||
return variables.GetAttrVariable(self, name, source=source)
|
||||
|
|
|
|||
Loading…
Reference in a new issue