From 0d7d44615447d336feebeda22f03be40418b9f84 Mon Sep 17 00:00:00 2001 From: andrewor Date: Mon, 25 Oct 2021 16:19:00 -0700 Subject: [PATCH] Disallow annotations on instance attributes outside __init__ (#67051) Summary: **Summary**: This commit solves the first part of https://github.com/pytorch/pytorch/issues/52306, which disallows type annotations on instance attributes inside any method other than the constructor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/67051 Test Plan: Added test to test_types.py. **Reviewers**: Zhengxu Chen **Subscribers**: Zhengxu Chen, Yanan Cao, Peng Wu, Yining Lu **Tasks**: T103941984 **Tags**: pytorch **Fixes** https://github.com/pytorch/pytorch/issues/52306 Reviewed By: zhxchen17 Differential Revision: D31843527 Pulled By: andrewor14 fbshipit-source-id: 624879ae801621e367c59228be8b0581ecd30ef4 --- test/jit/test_types.py | 34 ++++++++++++++++++++++++++++++++++ torch/_sources.py | 5 +++-- torch/jit/frontend.py | 13 +++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/test/jit/test_types.py b/test/jit/test_types.py index be952a6d006..9fadbedb272 100644 --- a/test/jit/test_types.py +++ b/test/jit/test_types.py @@ -272,3 +272,37 @@ class TestTypesAndAnnotation(JitTestCase): x = 5 if 1 == 1: x : Optional[int] = 7 + + def test_annotate_outside_init(self): + msg = "annotations on instance attributes must be declared in __init__" + highlight = "self.x: int" + + # Simple case + with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight): + @torch.jit.script + class BadModule(object): + def __init__(self, x: int): + self.x = x + + def set(self, val: int): + self.x: int = val + + # Type annotation in a loop + with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight): + @torch.jit.script + class BadModuleLoop(object): + def __init__(self, x: int): + self.x = x + + def set(self, val: int): + for i in range(3): + self.x: int = val + + # Type annotation in __init__, should not fail + @torch.jit.script + class GoodModule(object): + def __init__(self, x: int): + self.x: int = x + + def set(self, val: int): + self.x = val diff --git a/torch/_sources.py b/torch/_sources.py index 24649490e40..e3d9064b38b 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -69,10 +69,11 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]: # Thin wrapper around SourceRangeFactory to store extra metadata # about the function-to-be-compiled. class SourceContext(SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True): + def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None): super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len) self.uses_true_division = uses_true_division self.filename = filename + self.funcname = funcname @functools.lru_cache(maxsize=None) @@ -100,5 +101,5 @@ def parse_def(fn): if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) - ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True) + ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__) return ParsedDef(py_ast, ctx, source, filename, file_lineno) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 09be22d6618..fbbe962d40b 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -533,6 +533,19 @@ class StmtBuilder(Builder): def build_AnnAssign(ctx, stmt): if stmt.value is None: raise UnsupportedNodeError(ctx, stmt, reason='without assigned value') + + # Disallow type annotations on instance attributes outside of __init__ + if type(stmt.target) == ast.Attribute and\ + stmt.target.value.id == "self" and\ + ctx.funcname != "__init__": + start = stmt.col_offset + end = start + len(f"self.{stmt.target.attr}") + if hasattr(stmt.annotation, 'id'): + end += len(f": {stmt.annotation.id}") + sr = ctx.make_range(stmt.lineno, start, end) + raise ValueError("Type annotations on instance attributes must be declared in " + f"__init__, not '{ctx.funcname}': {sr}") + rhs = build_expr(ctx, stmt.value) lhs = build_expr(ctx, stmt.target) the_type = build_expr(ctx, stmt.annotation)