diff --git a/test/jit/mydecorator.py b/test/jit/mydecorator.py new file mode 100644 index 00000000000..18c84b92103 --- /dev/null +++ b/test/jit/mydecorator.py @@ -0,0 +1,20 @@ +r""" +Decorator used in test_decorator.py. We define it in a +separate file on purpose to test that the names in different modules +are resolved correctly. +""" + +import functools + + +def my_decorator(func): + """Dummy decorator that removes itself when torchscripting""" + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + return func(*args, **kwargs) + + # torch.jit.script() uses __prepare_scriptable__ to remove the decorator + wrapped_func.__prepare_scriptable__ = lambda: func + + return wrapped_func diff --git a/test/jit/myfunction_a.py b/test/jit/myfunction_a.py new file mode 100644 index 00000000000..34efea69dcb --- /dev/null +++ b/test/jit/myfunction_a.py @@ -0,0 +1,13 @@ +""" +Helper function used in test_decorator.py. We define it in a +separate file on purpose to test that the names in different modules +are resolved correctly. +""" + +from jit.mydecorator import my_decorator +from jit.myfunction_b import my_function_b + + +@my_decorator +def my_function_a(x: float) -> float: + return my_function_b(x) + 1 diff --git a/test/jit/myfunction_b.py b/test/jit/myfunction_b.py new file mode 100644 index 00000000000..6407672defe --- /dev/null +++ b/test/jit/myfunction_b.py @@ -0,0 +1,16 @@ +r""" +Helper function used in test_decorator.py. We define it in a +separate file on purpose to test that the names in different modules +are resolved correctly. +""" + +from jit.mydecorator import my_decorator + + +@my_decorator +def my_function_b(x: float) -> float: + return my_function_c(x) + 2 + + +def my_function_c(x: float) -> float: + return x + 3 diff --git a/test/jit/test_decorator.py b/test/jit/test_decorator.py new file mode 100644 index 00000000000..132935c37a7 --- /dev/null +++ b/test/jit/test_decorator.py @@ -0,0 +1,27 @@ +# Owner(s): ["oncall: jit"] +# flake8: noqa + +import sys +import unittest +from enum import Enum +from typing import List, Optional + +import torch +from torch.testing._internal.jit_utils import JitTestCase + +from jit.myfunction_a import my_function_a + + +class TestDecorator(JitTestCase): + def test_decorator(self): + # Note: JitTestCase.checkScript() does not work with decorators + # self.checkScript(my_function_a, (1.0,)) + # Error: + # RuntimeError: expected def but found '@' here: + # @my_decorator + # ~ <--- HERE + # def my_function_a(x: float) -> float: + # Do a simple torch.jit.script() test instead + fn = my_function_a + fx = torch.jit.script(fn) + self.assertEqual(fn(1.0), fx(1.0)) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 716b2ef002d..7795dd7e518 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -996,6 +996,10 @@ def try_compile_fn(fn, loc): f"Consider manually annotating `{fn}` with @torch.jit.script." ) + # The object returned by __prepare_scriptable__ might have a different closure. + # Resolve it here to get the right resolution callback. + fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator] + # We don't have the actual scope where the function was defined, but we can # extract the necessary info from the closed over variables on the function # object