mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Disallow annotations on instance attributes outside __init__ (#67051)
Summary: **Summary**: This commit solves the first part of https://github.com/pytorch/pytorch/issues/52306, which disallows type annotations on instance attributes inside any method other than the constructor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/67051 Test Plan: Added test to test_types.py. **Reviewers**: Zhengxu Chen **Subscribers**: Zhengxu Chen, Yanan Cao, Peng Wu, Yining Lu **Tasks**: T103941984 **Tags**: pytorch **Fixes** https://github.com/pytorch/pytorch/issues/52306 Reviewed By: zhxchen17 Differential Revision: D31843527 Pulled By: andrewor14 fbshipit-source-id: 624879ae801621e367c59228be8b0581ecd30ef4
This commit is contained in:
parent
1f55dd83ac
commit
0d7d446154
3 changed files with 50 additions and 2 deletions
|
|
@ -272,3 +272,37 @@ class TestTypesAndAnnotation(JitTestCase):
|
|||
x = 5
|
||||
if 1 == 1:
|
||||
x : Optional[int] = 7
|
||||
|
||||
def test_annotate_outside_init(self):
|
||||
msg = "annotations on instance attributes must be declared in __init__"
|
||||
highlight = "self.x: int"
|
||||
|
||||
# Simple case
|
||||
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
|
||||
@torch.jit.script
|
||||
class BadModule(object):
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
def set(self, val: int):
|
||||
self.x: int = val
|
||||
|
||||
# Type annotation in a loop
|
||||
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
|
||||
@torch.jit.script
|
||||
class BadModuleLoop(object):
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
def set(self, val: int):
|
||||
for i in range(3):
|
||||
self.x: int = val
|
||||
|
||||
# Type annotation in __init__, should not fail
|
||||
@torch.jit.script
|
||||
class GoodModule(object):
|
||||
def __init__(self, x: int):
|
||||
self.x: int = x
|
||||
|
||||
def set(self, val: int):
|
||||
self.x = val
|
||||
|
|
|
|||
|
|
@ -69,10 +69,11 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
|
|||
# Thin wrapper around SourceRangeFactory to store extra metadata
|
||||
# about the function-to-be-compiled.
|
||||
class SourceContext(SourceRangeFactory):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None):
|
||||
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
|
||||
self.uses_true_division = uses_true_division
|
||||
self.filename = filename
|
||||
self.funcname = funcname
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
|
@ -100,5 +101,5 @@ def parse_def(fn):
|
|||
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
|
||||
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
|
||||
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
|
||||
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
|
||||
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__)
|
||||
return ParsedDef(py_ast, ctx, source, filename, file_lineno)
|
||||
|
|
|
|||
|
|
@ -533,6 +533,19 @@ class StmtBuilder(Builder):
|
|||
def build_AnnAssign(ctx, stmt):
|
||||
if stmt.value is None:
|
||||
raise UnsupportedNodeError(ctx, stmt, reason='without assigned value')
|
||||
|
||||
# Disallow type annotations on instance attributes outside of __init__
|
||||
if type(stmt.target) == ast.Attribute and\
|
||||
stmt.target.value.id == "self" and\
|
||||
ctx.funcname != "__init__":
|
||||
start = stmt.col_offset
|
||||
end = start + len(f"self.{stmt.target.attr}")
|
||||
if hasattr(stmt.annotation, 'id'):
|
||||
end += len(f": {stmt.annotation.id}")
|
||||
sr = ctx.make_range(stmt.lineno, start, end)
|
||||
raise ValueError("Type annotations on instance attributes must be declared in "
|
||||
f"__init__, not '{ctx.funcname}': {sr}")
|
||||
|
||||
rhs = build_expr(ctx, stmt.value)
|
||||
lhs = build_expr(ctx, stmt.target)
|
||||
the_type = build_expr(ctx, stmt.annotation)
|
||||
|
|
|
|||
Loading…
Reference in a new issue