[jit] Better checking for overload function declarations. (#59956)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59956

Issue #50175. Basically two things need to be checked and are lacking currently:
1. Overload declarations should always have a single `pass` statement as the body.
2. There should be always an implementation provided for decls which doesn't
   have the torch.jit._overload decorator. So in this case we need to check
   whether we are actually compiling a function body with decorator ahead.

Test Plan:
python test/test_jit.py TestScript.test_function_overloads

Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D29106555

fbshipit-source-id: 2d9d7df2fb51ab6db0e1b726f9644e4cfbf733d6
This commit is contained in:
Zhengxu Chen 2021-08-05 14:19:56 -07:00 committed by Facebook GitHub Bot
parent 63fa53d37a
commit e62189ad69
10 changed files with 213 additions and 102 deletions

View file

@ -14485,6 +14485,47 @@ dedent """
with self.assertRaisesRegex(Exception, "Parameters not specified"):
torch.jit.script(test)
def test_function_overload_misuse(self):
with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
@torch.jit._overload
def wrong_decl_body(x: str) -> str:
return x + "0"
with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
class MyClass:
@torch.jit._overload_method
def method(self):
return 0
@torch.jit._overload
def null_overload(x: int) -> int: ... # noqa: E704
@torch.jit._overload
def null_overload(x: str) -> str: # noqa: F811
pass
def null_overload_driver():
return null_overload(0)
with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
torch.jit.script(null_overload_driver)
class OverloadMisuse(torch.nn.Module):
def __init__(self):
super().__init__()
@torch.jit._overload_method
def forward(self, x: int):
pass
@torch.jit._overload_method
def forward(self, x: Tensor): # noqa: F811
pass
with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
m = torch.jit.script(OverloadMisuse())
def test_script_method_torch_function_overload(self):
class MyCustomTensor(torch.Tensor):
pass

View file

@ -18,13 +18,12 @@ import builtins
import typing
import io
import pickle
import functools
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
# Explicitly ask to import `torch.distributed.__init__` first.
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
import torch.distributed.rpc
from torch._utils_internal import get_source_lines_and_file
from torch._C import Future as CFuture
from torch._sources import get_source_lines_and_file, parse_def, fake_range
from torch.futures import Future
import torch.package._mangling as package_mangling
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union # noqa: F401
@ -716,7 +715,50 @@ def copy_torchscript_modifier(orig, new) -> None:
# qualified_name => list[overload_functions]
_overloaded_fns : Dict[str, List[Callable]] = {} # noqa: T484
_OVERLOAD_EXAMPLE = '''
Example usage of overload function:
@torch.jit._overload
def my_function(x: type0) -> type0: # decl 1
pass
@torch.jit._overload
def my_function(x: type1) -> type1: # decl 2
pass
def my_function(x): # implementation
if isinstance(x, type0):
return x
elif isinstance(x, type1):
return x
'''
def get_overload_no_implementation_error_message(kind, obj):
sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
return (
f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
f'sure a definition is provided and defined after all overload declarations.\n'
f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE
)
def _check_overload_body(func):
parsed_def = parse_def(func)
body = parsed_def.ast.body[0].body
def is_pass(x):
return isinstance(x, ast.Pass)
def is_ellipsis(x):
return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
msg += '\n'.join(parsed_def.source.split("\n")[:3])
msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
raise RuntimeError(msg)
def _overload(func):
_check_overload_body(func)
qual_name = _qualified_name(func)
global _overloaded_fns
fn_overload_list = _overloaded_fns.get(qual_name)
@ -762,6 +804,7 @@ _overloaded_methods : Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
_overloaded_method_class_fileno = {}
def _overload_method(func):
_check_overload_body(func)
qual_name = _qualified_name(func)
global _overloaded_methods
class_name_map = _overloaded_methods.get(qual_name, None)
@ -994,22 +1037,6 @@ def _qualified_name(obj) -> str:
return module_name + "." + name
# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
self.uses_true_division = uses_true_division
self.filename = filename
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
return SourceContext(*args)
def fake_range():
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
def _try_get_dispatched_fn(fn):
if not callable(fn):
return None

104
torch/_sources.py Normal file
View file

@ -0,0 +1,104 @@
import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, Optional, Tuple, List, NamedTuple
from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory
def get_source_lines_and_file(
obj: Any,
error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
"""
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
Returns: (sourcelines, file_lino, filename)
"""
filename = None # in case getsourcefile throws
try:
filename = inspect.getsourcefile(obj)
sourcelines, file_lineno = inspect.getsourcelines(obj)
except OSError as e:
msg = (f"Can't get source for {obj}. TorchScript requires source access in "
"order to carry out compilation, make sure original .py files are "
"available.")
if error_msg:
msg += '\n' + error_msg
raise OSError(msg) from e
return sourcelines, file_lineno, filename
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
"""
This helper function accepts a list of source lines. It finds the
indentation level of the function definition (`def`), then it indents
all lines in the function body to a point at or greater than that
level. This allows for comments and continued string literals that
are at a lower indentation than the rest of the code.
Args:
sourcelines: function source code, separated into lines by
the '\n' character
Returns:
A list of source lines that have been correctly aligned
"""
def remove_prefix(text, prefix):
return text[text.startswith(prefix) and len(prefix):]
# Find the line and line number containing the function definition
for i, l in enumerate(sourcelines):
if l.lstrip().startswith("def"):
idx = i
break
fn_def = sourcelines[idx]
# Get a string representing the amount of leading whitespace
whitespace = fn_def.split("def")[0]
# Add this leading whitespace to all lines before and after the `def`
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
# Put it together again
aligned_prefix.append(fn_def)
return aligned_prefix + aligned_suffix
# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
self.uses_true_division = uses_true_division
self.filename = filename
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
return SourceContext(*args)
def fake_range():
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
class ParsedDef(NamedTuple):
ast: ast.Module
ctx: SourceContext
source: str
filename: Optional[str]
file_lineno: int
def parse_def(fn):
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
dedent_src = dedent(source)
py_ast = ast.parse(dedent_src)
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
return ParsedDef(py_ast, ctx, source, filename, file_lineno)

View file

@ -1,9 +1,7 @@
import os
import inspect
import sys
import tempfile
from typing import Any, List, Optional, Tuple
# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overrideable behavior. The motivating
@ -44,30 +42,6 @@ def resolve_library_path(path: str) -> str:
return os.path.realpath(path)
def get_source_lines_and_file(
obj: Any,
error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
"""
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
Returns: (sourcelines, file_lino, filename)
"""
filename = None # in case getsourcefile throws
try:
filename = inspect.getsourcefile(obj)
sourcelines, file_lineno = inspect.getsourcelines(obj)
except OSError as e:
msg = (f"Can't get source for {obj}. TorchScript requires source access in "
"order to carry out compilation, make sure original .py files are "
"available.")
if error_msg:
msg += '\n' + error_msg
raise OSError(msg) from e
return sourcelines, file_lineno, filename
TEST_MASTER_ADDR = '127.0.0.1'
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load

View file

@ -6,7 +6,7 @@ from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch.jit.frontend import normalize_source_lines
from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):

View file

@ -8,6 +8,7 @@ import warnings
from typing import Dict, List, Set, Type
import torch._jit_internal as _jit_internal
from torch._sources import fake_range
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
from torch.jit._builtins import _find_builtin
from torch.jit._check import AttributeTypeIsSupportedChecker
@ -148,10 +149,10 @@ def infer_concrete_type_builder(nn_module, share_types=True):
inferred = False
try:
if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
attr_type = torch._C.InferredType(ann_to_type)
elif isinstance(item, torch.jit.Attribute):
ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
attr_type = torch._C.InferredType(ann_to_type)
else:
attr_type = torch._C._jit_try_infer_type(item)
@ -620,6 +621,10 @@ def get_overload_annotations(mod, jit_ignored_properties):
if method_overloads is None:
continue
if item.__func__ in method_overloads:
raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
'method', item.__func__))
names = [name + "__" + str(i) for i in range(len(method_overloads))]
overloads[item] = list(zip(names, method_overloads))
@ -639,7 +644,7 @@ def get_overload_name_mapping(overload_info):
return overload_name_mappings
def _check_no_signature(func):
signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func))
signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
if signature is None:
qual_name = _jit_internal._qualified_name(func)
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))

View file

@ -1337,6 +1337,10 @@ def _get_overloads(obj):
if uncompiled_overloads is None:
return existing_compiled_fns
if obj in uncompiled_overloads:
raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
'function', obj))
compiled_fns = []
for overload_fn in uncompiled_overloads:
compiled_fns.append(

View file

@ -16,7 +16,7 @@ from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
from textwrap import dedent
from torch._utils_internal import get_source_lines_and_file
from torch._sources import get_source_lines_and_file
from typing import Type
if torch.distributed.rpc.is_available():

View file

@ -17,9 +17,9 @@ from torch._C._jit_tree_views import (
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
)
from torch._utils_internal import get_source_lines_and_file
from torch._sources import get_source_lines_and_file, parse_def, make_source_context
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers # noqa: F401
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
import torch.jit.annotations
_IS_ASTUNPARSE_INSTALLED = False
@ -215,42 +215,6 @@ def get_jit_class_def(cls, self_name):
return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
"""
This helper function accepts a list of source lines. It finds the
indentation level of the function definition (`def`), then it indents
all lines in the function body to a point at or greater than that
level. This allows for comments and continued string literals that
are at a lower indentation than the rest of the code.
Args:
sourcelines: function source code, separated into lines by
the '\n' character
Returns:
A list of source lines that have been correctly aligned
"""
def remove_prefix(text, prefix):
return text[text.startswith(prefix) and len(prefix):]
# Find the line and line number containing the function definition
for i, l in enumerate(sourcelines):
if l.lstrip().startswith("def"):
idx = i
break
fn_def = sourcelines[idx]
# Get a string representing the amount of leading whitespace
whitespace = fn_def.split("def")[0]
# Add this leading whitespace to all lines before and after the `def`
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
# Put it together again
aligned_prefix.append(fn_def)
return aligned_prefix + aligned_suffix
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
"""
Build a JIT AST (TreeView) from the given function.
@ -266,17 +230,9 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
but we want the result AST to have the name "forward".
self_name: If this function is a method, what the type name of `self` is.
"""
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
dedent_src = dedent(source)
py_ast = ast.parse(dedent_src)
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
type_line = torch.jit.annotations.get_type_line(source)
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
fn_def = py_ast.body[0]
parsed_def = parse_def(fn)
type_line = torch.jit.annotations.get_type_line(parsed_def.source)
fn_def = parsed_def.ast.body[0]
if is_classmethod:
arg_name = fn_def.args.args[0].arg
@ -288,7 +244,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
if should_drop(fn):
unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")
if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef):
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
raise RuntimeError(f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}")
unused_def = unused_fn_def.body[0]
fn_def.body = unused_def.body
# kwarg/vararg not supported by `build_def`
@ -305,7 +261,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
qualname = get_qualified_name(fn)
pdt_arg_types = type_trace_db.get_args_types(qualname)
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):

View file

@ -11,7 +11,7 @@ import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._utils_internal import get_source_lines_and_file
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg