diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 4d34f5b045c..69ba42eba68 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -233,6 +233,7 @@ test_dynamo_shard() { --exclude-distributed-tests \ --exclude \ test_autograd \ + test_jit \ test_proxy_tensor \ test_quantization \ test_public_bindings \ diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 811bc4869ed..a20962e0a0c 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -5157,6 +5157,26 @@ def fn(): self.assertTrue(isinstance(compile_out, torch.Size)) self.assertEqual(eager_out, compile_out) + def test_nested_function_resuming_with_correct_globals(self): + # https://github.com/pytorch/pytorch/issues/99665 + try: + from .utils import outer_func + except ImportError: + from utils import outer_func + + def gn(x, y): + return x + y + + def fn(x, y): + return outer_func(gn)(x, y) + + x = torch.rand([3]) + y = torch.rand([3]) + opt_fn = torch.compile(backend="eager")(fn) + ref = fn(x, y) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + class CustomFunc1(torch.autograd.Function): @staticmethod diff --git a/test/dynamo/utils.py b/test/dynamo/utils.py new file mode 100644 index 00000000000..54cacd080fd --- /dev/null +++ b/test/dynamo/utils.py @@ -0,0 +1,17 @@ +# Owner(s): ["module: dynamo"] + +import torch +import torch._dynamo + + +def inner_func(): + return torch.is_grad_enabled() + + +def outer_func(func): + def wrapped(*args): + a = func(*args) + torch._dynamo.graph_break() + return torch.sin(a + 1), inner_func() + + return wrapped diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index a88b2663913..5df19187180 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -35,6 +35,7 @@ if __name__ == '__main__': "\tpython test/test_jit.py TESTNAME\n\n" "instead.") +@skipIfTorchDynamo("Not a suitable test for TorchDynamo") class TestTracer(JitTestCase): @unittest.skipIf(not RUN_CUDA, "requires CUDA") def test_large_nbr_kernel_args(self): @@ -1990,6 +1991,7 @@ class TestTracer(JitTestCase): self.assertEqual(model(**input_dict), traced_model(**input_dict)) +@skipIfTorchDynamo("Not a suitable test for TorchDynamo") class TestMixTracingScripting(JitTestCase): def test_trace_script(self): @torch.jit.script diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 28ca5aa5a29..ec21547dd5d 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -27,6 +27,7 @@ from torch.testing._internal.common_utils import ( numpy_to_torch_dtype_dict, TEST_SCIPY, set_default_dtype, + skipIfTorchDynamo, ) from torch.testing._internal.common_device_type import ( expectedFailureMeta, @@ -1852,6 +1853,7 @@ class TestBinaryUfuncs(TestCase): _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide) @onlyNativeDeviceTypes + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_div_and_floordiv_script_vs_python(self, device): # Creates jitted functions of two tensors def _wrapped_div(a, b): @@ -1924,6 +1926,7 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t)) @onlyNativeDeviceTypes + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_idiv_and_ifloordiv_vs_python(self, device): def _wrapped_idiv_tensor(a, b): a /= b diff --git a/test/test_indexing.py b/test/test_indexing.py index 38bddda4469..551327cd93c 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -12,7 +12,7 @@ import numpy as np from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, TEST_WITH_TORCHDYNAMO) + TestCase, run_tests, skipIfTorchDynamo) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyNativeDeviceTypes, skipXLA) @@ -738,10 +738,7 @@ class TestIndexing(TestCase): self.assertEqual(y, torch.ones(size=(10, 10), device=device)) self.assertEqual(len(w), 2) - @unittest.skipIf( - TEST_WITH_TORCHDYNAMO, - "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" - ) + @skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472") def test_index_put_accumulate_large_tensor(self, device): # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 @@ -839,6 +836,7 @@ class TestIndexing(TestCase): self.assertEqual(out_cuda.cpu(), out_cpu) @onlyCUDA + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_index_put_accumulate_with_optional_tensors(self, device): # TODO: replace with a better solution. # Currently, here using torchscript to put None into indices. @@ -935,6 +933,7 @@ class TestIndexing(TestCase): r = v[c > 0] self.assertEqual(r.shape, (num_ones, 3)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_jit_indexing(self, device): def fn1(x): x[x < 50] = 1.0 diff --git a/test/test_native_functions.py b/test/test_native_functions.py index ba7889e10f4..c95b4a221ea 100644 --- a/test/test_native_functions.py +++ b/test/test_native_functions.py @@ -2,7 +2,7 @@ from typing import Optional, List import torch -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo # End-to-end tests of features in native_functions.yaml @@ -81,6 +81,7 @@ class TestNativeFunctions(TestCase): return torch._C._nn._test_optional_floatlist(values, const) return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_optional_floatlist(self): self.do_test_optional_floatlist_with_module(FloatListWrapperModule()) self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule())) @@ -134,6 +135,7 @@ class TestNativeFunctions(TestCase): return torch._C._nn._test_optional_intlist(values, const) return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_optional_intlist(self): self.do_test_optional_intlist_with_module(IntListWrapperModule()) self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule())) @@ -187,6 +189,7 @@ class TestNativeFunctions(TestCase): return torch._C._nn._test_optional_filled_intlist(values, const) return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_optional_filled_intlist(self): def f(n: int): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 9af810fa6f2..280f4674c15 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -3,7 +3,6 @@ import enum import functools import inspect import itertools -import sys import types from typing import Dict, List @@ -11,11 +10,7 @@ import torch from .. import variables from ..allowed_functions import is_allowed, is_builtin_callable -from ..bytecode_transformation import ( - create_call_function, - create_instruction, - create_rot_n, -) +from ..bytecode_transformation import create_call_function, create_rot_n from ..exc import unimplemented from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource from ..utils import istensor, istype, make_cell @@ -89,6 +84,26 @@ def init_cellvars(parent, result, code): return closure_cells +def _create_nested_fn( + code, f_globals, name, defaults, closure, kwdefaults, annotations +): + from types import FunctionType + + func = FunctionType(code, f_globals, name, defaults, closure) + func.__kwdefaults__ = kwdefaults + + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert annotations is None or isinstance(annotations, dict) + func.__annotations__ = annotations + + return func + + class BaseUserFunctionVariable(VariableTracker): def get_filename(self): return self.get_code().co_filename @@ -460,17 +475,27 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): parent.symbolic_locals[var] = child.symbolic_locals[var] def reconstruct(self, codegen): - flags = 0x00 + codegen.load_import_from(__name__, "_create_nested_fn") + codegen(self.code) + codegen.extend_output([codegen._create_load_const(self.f_globals)]) + codegen(self.fn_name) + if self.defaults: - flags |= 0x01 codegen(self.defaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.closure: + codegen(self.closure) + else: + codegen.extend_output([codegen.create_load_const(None)]) + if self.kwdefaults: - flags |= 0x02 codegen(self.kwdefaults) - if isinstance( - self.annotations, (variables.ConstDictVariable, variables.TupleVariable) - ): - flags |= 0x04 + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.annotations: try: if isinstance(self.annotations, variables.ConstDictVariable): annotations = { @@ -484,13 +509,10 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): codegen.extend_output([codegen._create_load_const(annotations)]) except NotImplementedError: codegen(self.annotations) - if self.closure: - flags |= 0x08 - codegen(self.closure) - codegen(self.code) - if sys.version_info < (3, 11): - codegen(self.fn_name) - codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)]) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + codegen.extend_output(create_call_function(7, push_null=True)) if self.wraps_source: codegen.load_import_from("functools", "wraps")