mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[jit] Better checking for overload function declarations. (#59956)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59956 Issue #50175. Basically two things need to be checked and are lacking currently: 1. Overload declarations should always have a single `pass` statement as the body. 2. There should be always an implementation provided for decls which doesn't have the torch.jit._overload decorator. So in this case we need to check whether we are actually compiling a function body with decorator ahead. Test Plan: python test/test_jit.py TestScript.test_function_overloads Imported from OSS Reviewed By: gmagogsfm Differential Revision: D29106555 fbshipit-source-id: 2d9d7df2fb51ab6db0e1b726f9644e4cfbf733d6
This commit is contained in:
parent
63fa53d37a
commit
e62189ad69
10 changed files with 213 additions and 102 deletions
|
|
@ -14485,6 +14485,47 @@ dedent """
|
|||
with self.assertRaisesRegex(Exception, "Parameters not specified"):
|
||||
torch.jit.script(test)
|
||||
|
||||
def test_function_overload_misuse(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
|
||||
@torch.jit._overload
|
||||
def wrong_decl_body(x: str) -> str:
|
||||
return x + "0"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
|
||||
class MyClass:
|
||||
@torch.jit._overload_method
|
||||
def method(self):
|
||||
return 0
|
||||
|
||||
@torch.jit._overload
|
||||
def null_overload(x: int) -> int: ... # noqa: E704
|
||||
|
||||
@torch.jit._overload
|
||||
def null_overload(x: str) -> str: # noqa: F811
|
||||
pass
|
||||
|
||||
def null_overload_driver():
|
||||
return null_overload(0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
|
||||
torch.jit.script(null_overload_driver)
|
||||
|
||||
class OverloadMisuse(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.jit._overload_method
|
||||
def forward(self, x: int):
|
||||
pass
|
||||
|
||||
@torch.jit._overload_method
|
||||
def forward(self, x: Tensor): # noqa: F811
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
|
||||
m = torch.jit.script(OverloadMisuse())
|
||||
|
||||
|
||||
def test_script_method_torch_function_overload(self):
|
||||
class MyCustomTensor(torch.Tensor):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -18,13 +18,12 @@ import builtins
|
|||
import typing
|
||||
import io
|
||||
import pickle
|
||||
import functools
|
||||
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
|
||||
# Explicitly ask to import `torch.distributed.__init__` first.
|
||||
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
|
||||
import torch.distributed.rpc
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch._C import Future as CFuture
|
||||
from torch._sources import get_source_lines_and_file, parse_def, fake_range
|
||||
from torch.futures import Future
|
||||
import torch.package._mangling as package_mangling
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union # noqa: F401
|
||||
|
|
@ -716,7 +715,50 @@ def copy_torchscript_modifier(orig, new) -> None:
|
|||
# qualified_name => list[overload_functions]
|
||||
_overloaded_fns : Dict[str, List[Callable]] = {} # noqa: T484
|
||||
|
||||
|
||||
_OVERLOAD_EXAMPLE = '''
|
||||
Example usage of overload function:
|
||||
@torch.jit._overload
|
||||
def my_function(x: type0) -> type0: # decl 1
|
||||
pass
|
||||
|
||||
@torch.jit._overload
|
||||
def my_function(x: type1) -> type1: # decl 2
|
||||
pass
|
||||
|
||||
def my_function(x): # implementation
|
||||
if isinstance(x, type0):
|
||||
return x
|
||||
elif isinstance(x, type1):
|
||||
return x
|
||||
'''
|
||||
|
||||
def get_overload_no_implementation_error_message(kind, obj):
|
||||
sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
|
||||
return (
|
||||
f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
|
||||
f'sure a definition is provided and defined after all overload declarations.\n'
|
||||
f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE
|
||||
)
|
||||
|
||||
def _check_overload_body(func):
|
||||
parsed_def = parse_def(func)
|
||||
body = parsed_def.ast.body[0].body
|
||||
|
||||
def is_pass(x):
|
||||
return isinstance(x, ast.Pass)
|
||||
|
||||
def is_ellipsis(x):
|
||||
return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
|
||||
|
||||
if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
|
||||
msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
|
||||
msg += '\n'.join(parsed_def.source.split("\n")[:3])
|
||||
msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def _overload(func):
|
||||
_check_overload_body(func)
|
||||
qual_name = _qualified_name(func)
|
||||
global _overloaded_fns
|
||||
fn_overload_list = _overloaded_fns.get(qual_name)
|
||||
|
|
@ -762,6 +804,7 @@ _overloaded_methods : Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
|
|||
_overloaded_method_class_fileno = {}
|
||||
|
||||
def _overload_method(func):
|
||||
_check_overload_body(func)
|
||||
qual_name = _qualified_name(func)
|
||||
global _overloaded_methods
|
||||
class_name_map = _overloaded_methods.get(qual_name, None)
|
||||
|
|
@ -994,22 +1037,6 @@ def _qualified_name(obj) -> str:
|
|||
return module_name + "." + name
|
||||
|
||||
|
||||
# Thin wrapper around SourceRangeFactory to store extra metadata
|
||||
# about the function-to-be-compiled.
|
||||
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
|
||||
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
|
||||
self.uses_true_division = uses_true_division
|
||||
self.filename = filename
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def make_source_context(*args):
|
||||
return SourceContext(*args)
|
||||
|
||||
def fake_range():
|
||||
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
|
||||
|
||||
|
||||
def _try_get_dispatched_fn(fn):
|
||||
if not callable(fn):
|
||||
return None
|
||||
|
|
|
|||
104
torch/_sources.py
Normal file
104
torch/_sources.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import ast
|
||||
import functools
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any, Optional, Tuple, List, NamedTuple
|
||||
from torch._C import ErrorReport
|
||||
from torch._C._jit_tree_views import SourceRangeFactory
|
||||
|
||||
def get_source_lines_and_file(
|
||||
obj: Any,
|
||||
error_msg: Optional[str] = None,
|
||||
) -> Tuple[List[str], int, Optional[str]]:
|
||||
"""
|
||||
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
|
||||
|
||||
Returns: (sourcelines, file_lino, filename)
|
||||
"""
|
||||
filename = None # in case getsourcefile throws
|
||||
try:
|
||||
filename = inspect.getsourcefile(obj)
|
||||
sourcelines, file_lineno = inspect.getsourcelines(obj)
|
||||
except OSError as e:
|
||||
msg = (f"Can't get source for {obj}. TorchScript requires source access in "
|
||||
"order to carry out compilation, make sure original .py files are "
|
||||
"available.")
|
||||
if error_msg:
|
||||
msg += '\n' + error_msg
|
||||
raise OSError(msg) from e
|
||||
|
||||
return sourcelines, file_lineno, filename
|
||||
|
||||
|
||||
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
|
||||
"""
|
||||
This helper function accepts a list of source lines. It finds the
|
||||
indentation level of the function definition (`def`), then it indents
|
||||
all lines in the function body to a point at or greater than that
|
||||
level. This allows for comments and continued string literals that
|
||||
are at a lower indentation than the rest of the code.
|
||||
Args:
|
||||
sourcelines: function source code, separated into lines by
|
||||
the '\n' character
|
||||
Returns:
|
||||
A list of source lines that have been correctly aligned
|
||||
"""
|
||||
|
||||
def remove_prefix(text, prefix):
|
||||
return text[text.startswith(prefix) and len(prefix):]
|
||||
|
||||
# Find the line and line number containing the function definition
|
||||
for i, l in enumerate(sourcelines):
|
||||
if l.lstrip().startswith("def"):
|
||||
idx = i
|
||||
break
|
||||
fn_def = sourcelines[idx]
|
||||
|
||||
# Get a string representing the amount of leading whitespace
|
||||
whitespace = fn_def.split("def")[0]
|
||||
|
||||
# Add this leading whitespace to all lines before and after the `def`
|
||||
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
|
||||
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
|
||||
|
||||
# Put it together again
|
||||
aligned_prefix.append(fn_def)
|
||||
return aligned_prefix + aligned_suffix
|
||||
|
||||
|
||||
# Thin wrapper around SourceRangeFactory to store extra metadata
|
||||
# about the function-to-be-compiled.
|
||||
class SourceContext(SourceRangeFactory):
|
||||
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
|
||||
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
|
||||
self.uses_true_division = uses_true_division
|
||||
self.filename = filename
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def make_source_context(*args):
|
||||
return SourceContext(*args)
|
||||
|
||||
|
||||
def fake_range():
|
||||
return SourceContext('', None, 0, 0).make_raw_range(0, 1)
|
||||
|
||||
|
||||
class ParsedDef(NamedTuple):
|
||||
ast: ast.Module
|
||||
ctx: SourceContext
|
||||
source: str
|
||||
filename: Optional[str]
|
||||
file_lineno: int
|
||||
|
||||
def parse_def(fn):
|
||||
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
|
||||
sourcelines = normalize_source_lines(sourcelines)
|
||||
source = ''.join(sourcelines)
|
||||
dedent_src = dedent(source)
|
||||
py_ast = ast.parse(dedent_src)
|
||||
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
|
||||
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
|
||||
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
|
||||
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
|
||||
return ParsedDef(py_ast, ctx, source, filename, file_lineno)
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
# this arbitrary-looking assortment of functionality is provided here
|
||||
# to have a central place for overrideable behavior. The motivating
|
||||
|
|
@ -44,30 +42,6 @@ def resolve_library_path(path: str) -> str:
|
|||
return os.path.realpath(path)
|
||||
|
||||
|
||||
def get_source_lines_and_file(
|
||||
obj: Any,
|
||||
error_msg: Optional[str] = None,
|
||||
) -> Tuple[List[str], int, Optional[str]]:
|
||||
"""
|
||||
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
|
||||
|
||||
Returns: (sourcelines, file_lino, filename)
|
||||
"""
|
||||
filename = None # in case getsourcefile throws
|
||||
try:
|
||||
filename = inspect.getsourcefile(obj)
|
||||
sourcelines, file_lineno = inspect.getsourcelines(obj)
|
||||
except OSError as e:
|
||||
msg = (f"Can't get source for {obj}. TorchScript requires source access in "
|
||||
"order to carry out compilation, make sure original .py files are "
|
||||
"available.")
|
||||
if error_msg:
|
||||
msg += '\n' + error_msg
|
||||
raise OSError(msg) from e
|
||||
|
||||
return sourcelines, file_lineno, filename
|
||||
|
||||
|
||||
TEST_MASTER_ADDR = '127.0.0.1'
|
||||
TEST_MASTER_PORT = 29500
|
||||
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from types import FunctionType
|
|||
from typing import cast, Union, Callable, Dict, Optional, Any
|
||||
from torch.fx._symbolic_trace import Tracer
|
||||
from torch.fx.graph import Graph
|
||||
from torch.jit.frontend import normalize_source_lines
|
||||
from torch._sources import normalize_source_lines
|
||||
import torch
|
||||
|
||||
class AST_Rewriter(ast.NodeTransformer):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import warnings
|
|||
from typing import Dict, List, Set, Type
|
||||
|
||||
import torch._jit_internal as _jit_internal
|
||||
from torch._sources import fake_range
|
||||
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
|
||||
from torch.jit._builtins import _find_builtin
|
||||
from torch.jit._check import AttributeTypeIsSupportedChecker
|
||||
|
|
@ -148,10 +149,10 @@ def infer_concrete_type_builder(nn_module, share_types=True):
|
|||
inferred = False
|
||||
try:
|
||||
if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
|
||||
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
|
||||
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
|
||||
attr_type = torch._C.InferredType(ann_to_type)
|
||||
elif isinstance(item, torch.jit.Attribute):
|
||||
ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
|
||||
ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
|
||||
attr_type = torch._C.InferredType(ann_to_type)
|
||||
else:
|
||||
attr_type = torch._C._jit_try_infer_type(item)
|
||||
|
|
@ -620,6 +621,10 @@ def get_overload_annotations(mod, jit_ignored_properties):
|
|||
if method_overloads is None:
|
||||
continue
|
||||
|
||||
if item.__func__ in method_overloads:
|
||||
raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
|
||||
'method', item.__func__))
|
||||
|
||||
names = [name + "__" + str(i) for i in range(len(method_overloads))]
|
||||
overloads[item] = list(zip(names, method_overloads))
|
||||
|
||||
|
|
@ -639,7 +644,7 @@ def get_overload_name_mapping(overload_info):
|
|||
return overload_name_mappings
|
||||
|
||||
def _check_no_signature(func):
|
||||
signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func))
|
||||
signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
|
||||
if signature is None:
|
||||
qual_name = _jit_internal._qualified_name(func)
|
||||
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
|
||||
|
|
|
|||
|
|
@ -1337,6 +1337,10 @@ def _get_overloads(obj):
|
|||
if uncompiled_overloads is None:
|
||||
return existing_compiled_fns
|
||||
|
||||
if obj in uncompiled_overloads:
|
||||
raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
|
||||
'function', obj))
|
||||
|
||||
compiled_fns = []
|
||||
for overload_fn in uncompiled_overloads:
|
||||
compiled_fns.append(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \
|
|||
|
||||
|
||||
from textwrap import dedent
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch._sources import get_source_lines_and_file
|
||||
from typing import Type
|
||||
|
||||
if torch.distributed.rpc.is_available():
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ from torch._C._jit_tree_views import (
|
|||
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
|
||||
DictComp,
|
||||
)
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch._sources import get_source_lines_and_file, parse_def, make_source_context
|
||||
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
|
||||
from torch._jit_internal import make_source_context, should_drop, is_static_fn, FunctionModifiers # noqa: F401
|
||||
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
|
||||
import torch.jit.annotations
|
||||
|
||||
_IS_ASTUNPARSE_INSTALLED = False
|
||||
|
|
@ -215,42 +215,6 @@ def get_jit_class_def(cls, self_name):
|
|||
return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
|
||||
|
||||
|
||||
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
|
||||
"""
|
||||
This helper function accepts a list of source lines. It finds the
|
||||
indentation level of the function definition (`def`), then it indents
|
||||
all lines in the function body to a point at or greater than that
|
||||
level. This allows for comments and continued string literals that
|
||||
are at a lower indentation than the rest of the code.
|
||||
Args:
|
||||
sourcelines: function source code, separated into lines by
|
||||
the '\n' character
|
||||
Returns:
|
||||
A list of source lines that have been correctly aligned
|
||||
"""
|
||||
|
||||
def remove_prefix(text, prefix):
|
||||
return text[text.startswith(prefix) and len(prefix):]
|
||||
|
||||
# Find the line and line number containing the function definition
|
||||
for i, l in enumerate(sourcelines):
|
||||
if l.lstrip().startswith("def"):
|
||||
idx = i
|
||||
break
|
||||
fn_def = sourcelines[idx]
|
||||
|
||||
# Get a string representing the amount of leading whitespace
|
||||
whitespace = fn_def.split("def")[0]
|
||||
|
||||
# Add this leading whitespace to all lines before and after the `def`
|
||||
aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
|
||||
aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]
|
||||
|
||||
# Put it together again
|
||||
aligned_prefix.append(fn_def)
|
||||
return aligned_prefix + aligned_suffix
|
||||
|
||||
|
||||
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
"""
|
||||
Build a JIT AST (TreeView) from the given function.
|
||||
|
|
@ -266,17 +230,9 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
|||
but we want the result AST to have the name "forward".
|
||||
self_name: If this function is a method, what the type name of `self` is.
|
||||
"""
|
||||
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
|
||||
sourcelines = normalize_source_lines(sourcelines)
|
||||
source = ''.join(sourcelines)
|
||||
dedent_src = dedent(source)
|
||||
py_ast = ast.parse(dedent_src)
|
||||
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
|
||||
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
|
||||
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
|
||||
type_line = torch.jit.annotations.get_type_line(source)
|
||||
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
|
||||
fn_def = py_ast.body[0]
|
||||
parsed_def = parse_def(fn)
|
||||
type_line = torch.jit.annotations.get_type_line(parsed_def.source)
|
||||
fn_def = parsed_def.ast.body[0]
|
||||
|
||||
if is_classmethod:
|
||||
arg_name = fn_def.args.args[0].arg
|
||||
|
|
@ -288,7 +244,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
|||
if should_drop(fn):
|
||||
unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")
|
||||
if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef):
|
||||
raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
|
||||
raise RuntimeError(f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}")
|
||||
unused_def = unused_fn_def.body[0]
|
||||
fn_def.body = unused_def.body
|
||||
# kwarg/vararg not supported by `build_def`
|
||||
|
|
@ -305,7 +261,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
|||
qualname = get_qualified_name(fn)
|
||||
pdt_arg_types = type_trace_db.get_args_types(qualname)
|
||||
|
||||
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
|
||||
return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
|
||||
|
||||
# TODO: more robust handling of recognizing ignore context manager
|
||||
def is_torch_jit_ignore_context_manager(stmt):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import warnings
|
|||
from contextlib import closing, contextmanager
|
||||
from ._utils import _import_dotted_name
|
||||
from ._six import string_classes as _string_classes
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch._sources import get_source_lines_and_file
|
||||
from torch.types import Storage
|
||||
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
|
||||
import copyreg
|
||||
|
|
|
|||
Loading…
Reference in a new issue