[contextlib] Wrapping a function with set_grad_enabled will consume its global mutation (#113359)

Fixes https://github.com/pytorch/pytorch/issues/113298

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113359
Approved by: https://github.com/soulitzer, https://github.com/jansel
This commit is contained in:
Jon Chuang 2023-11-09 19:16:15 +00:00 committed by PyTorch MergeBot
parent 0381d8ce68
commit 5ccd22502f
4 changed files with 41 additions and 8 deletions

View file

@ -1072,10 +1072,9 @@ class GraphModule(torch.nn.Module):
inner_func = torch.set_grad_enabled(mode)(inner_func)
# decorator will mutate global state even if it wraps a function
# This behaviour is not desirable and may change in the future
# https://github.com/pytorch/pytorch/issues/113298
assert torch.is_grad_enabled() == mode
# Consuming set_grad_enabled by calling it on a function
# should not mutate global state
assert torch.is_grad_enabled() == mode_inverse
with torch.set_grad_enabled(mode_inverse):
return inner_func(x)

View file

@ -3735,6 +3735,36 @@ class TestAutograd(TestCase):
y = x * 2
self.assertTrue(y.requires_grad)
def test_set_grad_enabled_wraps(self):
for decorator in [True, False]:
with torch.enable_grad():
self.assertTrue(torch.is_grad_enabled())
if decorator:
# This should not mutate the global grad mode!
@torch.set_grad_enabled(False)
def inner_func(x):
return x.sin()
else:
def inner_func(x):
return x.sin()
# This is non-idiomatic usage!
# More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
obj = torch.set_grad_enabled(False)
self.assertTrue(not torch.is_grad_enabled())
# this will consume the set_grad_enabled global mutation!
inner_func = obj(inner_func)
self.assertTrue(torch.is_grad_enabled())
self.assertTrue(torch.is_grad_enabled())
x = torch.zeros(1, requires_grad=True)
self.assertTrue(
not inner_func(x).requires_grad
)
def test_simple_reentrant(self):
y_data = torch.randn(2, 2)

View file

@ -181,8 +181,7 @@ class GradModeVariable(ContextWrappingVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
):
# TODO(jon-chuang): uncomment once https://github.com/pytorch/pytorch/issues/113298 is fixed
# self._call_func(tx, self.initial_values) # undo eager initialization
self._call_func(tx, self.initial_values) # undo eager initialization
return super().call_function(tx, args, kwargs)
def _call_func(self, tx, values):

View file

@ -5,6 +5,7 @@ import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
__all__ = [
@ -181,11 +182,15 @@ class set_grad_enabled(_DecoratorContextManager):
def __init__(self, mode: bool) -> None:
self.prev = torch.is_grad_enabled()
torch._C._set_grad_enabled(mode)
self.mode = mode
torch._C._set_grad_enabled(mode)
def __call__(self, orig_func: F) -> F:
torch._C._set_grad_enabled(self.prev)
return super().__call__(orig_func)
def __enter__(self) -> None:
pass
torch._C._set_grad_enabled(self.mode)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)