diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index 316855adabc..cf6da178d27 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -1072,10 +1072,9 @@ class GraphModule(torch.nn.Module):
 
             inner_func = torch.set_grad_enabled(mode)(inner_func)
 
-            # decorator will mutate global state even if it wraps a function
-            # This behaviour is not desirable and may change in the future
-            # https://github.com/pytorch/pytorch/issues/113298
-            assert torch.is_grad_enabled() == mode
+            # Consuming set_grad_enabled by calling it on a function
+            # should not mutate global state
+            assert torch.is_grad_enabled() == mode_inverse
 
             with torch.set_grad_enabled(mode_inverse):
                 return inner_func(x)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 57aab2ca386..671ebc449de 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3735,6 +3735,36 @@ class TestAutograd(TestCase):
             y = x * 2
         self.assertTrue(y.requires_grad)
 
+    def test_set_grad_enabled_wraps(self):
+        for decorator in [True, False]:
+            with torch.enable_grad():
+                self.assertTrue(torch.is_grad_enabled())
+
+                if decorator:
+                    # This should not mutate the global grad mode!
+                    @torch.set_grad_enabled(False)
+                    def inner_func(x):
+                        return x.sin()
+                else:
+                    def inner_func(x):
+                        return x.sin()
+
+                    # This is non-idiomatic usage!
+                    # More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
+                    obj = torch.set_grad_enabled(False)
+                    self.assertTrue(not torch.is_grad_enabled())
+
+                    # this will consume the set_grad_enabled global mutation!
+                    inner_func = obj(inner_func)
+                    self.assertTrue(torch.is_grad_enabled())
+
+                self.assertTrue(torch.is_grad_enabled())
+
+                x = torch.zeros(1, requires_grad=True)
+                self.assertTrue(
+                    not inner_func(x).requires_grad
+                )
+
     def test_simple_reentrant(self):
         y_data = torch.randn(2, 2)
 
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index 7567a6e01ee..d8dd581d938 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -181,8 +181,7 @@ class GradModeVariable(ContextWrappingVariable):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ):
-        # TODO(jon-chuang): uncomment once https://github.com/pytorch/pytorch/issues/113298 is fixed
-        # self._call_func(tx, self.initial_values)  # undo eager initialization
+        self._call_func(tx, self.initial_values)  # undo eager initialization
         return super().call_function(tx, args, kwargs)
 
     def _call_func(self, tx, values):
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index fcd5d9ff414..9c23dded1ad 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -5,6 +5,7 @@ import torch
 from torch.utils._contextlib import (
     _DecoratorContextManager,
     _NoParamDecoratorContextManager,
+    F,
 )
 
 __all__ = [
@@ -181,11 +182,15 @@
 
     def __init__(self, mode: bool) -> None:
         self.prev = torch.is_grad_enabled()
-        torch._C._set_grad_enabled(mode)
         self.mode = mode
+        torch._C._set_grad_enabled(mode)
+
+    def __call__(self, orig_func: F) -> F:
+        torch._C._set_grad_enabled(self.prev)
+        return super().__call__(orig_func)
 
     def __enter__(self) -> None:
-        pass
+        torch._C._set_grad_enabled(self.mode)
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         torch._C._set_grad_enabled(self.prev)
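
Net effect of the grad_mode.py change, as exercised by the new tests: constructing `torch.set_grad_enabled(mode)` still eagerly flips the global grad mode, but consuming that object as a decorator restores the previous mode, so only the wrapped call runs under `mode`, while context-manager usage is unchanged. A minimal illustrative sketch (not part of the patch):

```python
import torch

# Plain construction still mutates global state eagerly.
torch.set_grad_enabled(True)
assert torch.is_grad_enabled()

# Decorator usage: wrapping a function undoes the eager global mutation;
# the surrounding grad mode is left untouched, but the wrapped body runs
# with grad disabled.
@torch.set_grad_enabled(False)
def f(x):
    return x.sin()

assert torch.is_grad_enabled()          # global mode unchanged
x = torch.ones(1, requires_grad=True)
assert not f(x).requires_grad           # body ran under no-grad

# Context-manager usage: disabled inside the block, restored on exit.
with torch.set_grad_enabled(False):
    assert not torch.is_grad_enabled()
assert torch.is_grad_enabled()
```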