Disallow annotations on instance attributes outside __init__ (#67051)

Summary:
**Summary**: This commit solves the first part of https://github.com/pytorch/pytorch/issues/52306, which disallows type annotations on instance attributes inside any method other than the constructor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67051

Test Plan:
Added test to test_types.py.

**Reviewers**: Zhengxu Chen

**Subscribers**: Zhengxu Chen, Yanan Cao, Peng Wu, Yining Lu

**Tasks**: T103941984

**Tags**: pytorch

**Fixes** https://github.com/pytorch/pytorch/issues/52306

Reviewed By: zhxchen17

Differential Revision: D31843527

Pulled By: andrewor14

fbshipit-source-id: 624879ae801621e367c59228be8b0581ecd30ef4
This commit is contained in:
andrewor 2021-10-25 16:19:00 -07:00 committed by Facebook GitHub Bot
parent 1f55dd83ac
commit 0d7d446154
3 changed files with 50 additions and 2 deletions

View file

@ -272,3 +272,37 @@ class TestTypesAndAnnotation(JitTestCase):
x = 5
if 1 == 1:
x : Optional[int] = 7
def test_annotate_outside_init(self):
msg = "annotations on instance attributes must be declared in __init__"
highlight = "self.x: int"
# Simple case
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script
class BadModule(object):
def __init__(self, x: int):
self.x = x
def set(self, val: int):
self.x: int = val
# Type annotation in a loop
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script
class BadModuleLoop(object):
def __init__(self, x: int):
self.x = x
def set(self, val: int):
for i in range(3):
self.x: int = val
# Type annotation in __init__, should not fail
@torch.jit.script
class GoodModule(object):
def __init__(self, x: int):
self.x: int = x
def set(self, val: int):
self.x = val

View file

@ -69,10 +69,11 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None):
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
self.uses_true_division = uses_true_division
self.filename = filename
self.funcname = funcname
@functools.lru_cache(maxsize=None)
@ -100,5 +101,5 @@ def parse_def(fn):
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__)
return ParsedDef(py_ast, ctx, source, filename, file_lineno)

View file

@ -533,6 +533,19 @@ class StmtBuilder(Builder):
def build_AnnAssign(ctx, stmt):
if stmt.value is None:
raise UnsupportedNodeError(ctx, stmt, reason='without assigned value')
# Disallow type annotations on instance attributes outside of __init__
if type(stmt.target) == ast.Attribute and\
stmt.target.value.id == "self" and\
ctx.funcname != "__init__":
start = stmt.col_offset
end = start + len(f"self.{stmt.target.attr}")
if hasattr(stmt.annotation, 'id'):
end += len(f": {stmt.annotation.id}")
sr = ctx.make_range(stmt.lineno, start, end)
raise ValueError("Type annotations on instance attributes must be declared in "
f"__init__, not '{ctx.funcname}': {sr}")
rhs = build_expr(ctx, stmt.value)
lhs = build_expr(ctx, stmt.target)
the_type = build_expr(ctx, stmt.annotation)