From ffe301846b0b2c7a3192bfb7d5673d47293113ae Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Wed, 7 Apr 2021 23:47:32 -0700 Subject: [PATCH] [Hackathon] Add error source range highlighting check in test_hash and test_list_dict (#55490) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55490 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D27628697 Pulled By: tugsbayasgalan fbshipit-source-id: 694226f0b083606f665569e6a84d547026c7f19f --- test/jit/test_hash.py | 2 +- test/jit/test_list_dict.py | 38 ++++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/test/jit/test_hash.py b/test/jit/test_hash.py index 13c4761fc2b..3ef56d879d4 100644 --- a/test/jit/test_hash.py +++ b/test/jit/test_hash.py @@ -31,7 +31,7 @@ class TestHash(JitTestCase): def fn_unhashable(t1: Tuple[int, List[int]]): return hash(t1) - with self.assertRaisesRegex(RuntimeError, "unhashable"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"): fn_unhashable((1, [1])) def test_hash_tensor(self): diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index aec41c451cf..61b293da050 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -92,7 +92,7 @@ class TestList(JitTestCase): if 1 == 1: x = [1, 2, 3] return - with self.assertRaisesRegex(RuntimeError, r"previously has type List\[Tensor\]"): + with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously has type List\[Tensor\]", "x"): self.checkScript(reassign_from_empty_literal, (), optimize=False) def reassign_from_empty_builtin(): @@ -113,7 +113,7 @@ class TestList(JitTestCase): if 1 == 1: x = [1.0] return - with self.assertRaisesRegex(RuntimeError, "previously has type"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"): self.checkScript(reassign_bad_type, (), optimize=False) def reassign_nested(): @@ -123,7 +123,7 @@ class TestList(JitTestCase): if 1 == 1: x = [1.0] return - with self.assertRaisesRegex(RuntimeError, "previously has type"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"): self.checkScript(reassign_nested, (), optimize=False) def test_del(self): @@ -147,10 +147,10 @@ class TestList(JitTestCase): del x[100] return x - with self.assertRaisesRegex(RuntimeError, "out of range"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "out of range", "x[100]"): fn2([]) - with self.assertRaisesRegex(RuntimeError, "deletion at a single index"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "deletion at a single index", "x[1:3]"): @torch.jit.script def fn(x: List[int]) -> List[int]: del x[1:3] @@ -247,8 +247,9 @@ class TestList(JitTestCase): # TODO: This fails during function schema matching, so the error # message is not very informative to the user. Change logic so # that the error is thrown at a different time? - with self.assertRaisesRegex(RuntimeError, "Arguments for call " - "are not valid"): + err_msg = "Arguments for call are not valid" + highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3" + with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg): @torch.jit.script def fn(): x: Dict[int, str] = dict([("foo", 1), ("bar", 2), ("baz", 3)]) # noqa: C406 @@ -577,7 +578,8 @@ class TestList(JitTestCase): def test_index_slice_out_of_bounds_index(x): x = x[[4], :, :] return x - with self.assertRaisesRegex(RuntimeError, "index 4 is out of bounds for dimension 0 with size 3"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "index 4 is out of bounds for dimension 0 with size 3", + "x[[4], :, :]"): self.checkScript(test_index_slice_out_of_bounds_index, (a,)) def test_mutable_list_append(self): @@ -753,7 +755,7 @@ class TestList(JitTestCase): a = torch.jit.annotate(List[int], []) return a.pop() - with self.assertRaisesRegex(RuntimeError, "pop from empty list"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "pop from empty list", "a.pop"): test_pop_empty() def test_mutable_list_pop(self): @@ -878,7 +880,7 @@ class TestList(JitTestCase): return a - with self.assertRaisesRegex(RuntimeError, "x not in list"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "x not in list", "a.remove"): test_list_remove_not_existing() def test_mutable_list_remove(self): @@ -904,7 +906,7 @@ class TestList(JitTestCase): return i - with self.assertRaisesRegex(RuntimeError, "'5' is not in list"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "'5' is not in list", "a.index"): list_index_not_existing() def test_list_index(self): @@ -938,7 +940,7 @@ class TestList(JitTestCase): return i - with self.assertRaisesRegex(RuntimeError, "is not in list"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "is not in list", "a.index"): tensor_list_index_not_existing() def test_list_count(self): @@ -1386,7 +1388,7 @@ class TestDict(JitTestCase): cu.define(dedent(inspect.getsource(fn))) self.assertEqual(cu.fn(inputs()), python_out) self.assertEqual(torch.jit.script(fn)(inputs()), python_out) - with self.assertRaisesRegex(RuntimeError, "KeyError"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x['hi']"): self.checkScript(fn, [{}]) def test_keys(self): @@ -1452,7 +1454,7 @@ class TestDict(JitTestCase): tester(pop, 'a') - with self.assertRaisesRegex(RuntimeError, "KeyError"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x.pop"): torch.jit.script(pop)(self.dict(), 'x') @@ -1579,7 +1581,7 @@ class TestDict(JitTestCase): def missing_index(x: Dict[str, int]) -> int: return x['dne'] - with self.assertRaisesRegex(RuntimeError, "KeyError"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x['d"): missing_index({'item': 20, 'other_item': 120}) code = dedent(''' @@ -1613,7 +1615,7 @@ class TestDict(JitTestCase): self.assertEqual(fn(), {'ok': 10}) def test_key_type(self): - with self.assertRaisesRegex(RuntimeError, "but instead found type"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "but instead found type", "a[None]"): @torch.jit.script def fn(a: Dict[str, int]) -> int: return a[None] @@ -1653,7 +1655,7 @@ class TestDict(JitTestCase): self.checkScript(fn, (d, 3)) self.checkScript(fn, (d, 2)) - with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "is actually of type Optional", "return x.get(y"): @torch.jit.script def bad_types(x: Dict[int, int], y: int) -> int: return x.get(y) # noqa: T484 @@ -1706,7 +1708,7 @@ class TestDict(JitTestCase): a[1] = 2 return a - with self.assertRaisesRegex(Exception, "Arguments for call are not"): + with self.assertRaisesRegexWithHighlight(Exception, "Arguments for call are not", "a[1] = 2"): torch.jit.script(test_dict_error) def test_type_annotation_missing_contained_type(self):