[contextlib] Wrapping a function with set_grad_enabled will consume its global mutation (#113359)
Fixes https://github.com/pytorch/pytorch/issues/113298

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113359
Approved by: https://github.com/soulitzer, https://github.com/jansel
parent 0381d8ce68
commit 5ccd22502f

4 changed files with 41 additions and 8 deletions
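For context, a minimal eager-mode repro of the behavior this commit fixes (grad mode is assumed to start enabled, its default): previously, merely wrapping a function with set_grad_enabled left the global grad mode flipped; after this change the mode only applies inside the wrapped calls. A sketch, not part of the diff:

import torch

assert torch.is_grad_enabled()  # default: gradients enabled

# Using set_grad_enabled as a decorator should only affect the wrapped calls.
@torch.set_grad_enabled(False)
def inner_func(x):
    return x.sin()

# After this fix, the global grad mode is restored once the wrapper is built.
assert torch.is_grad_enabled()

x = torch.zeros(1, requires_grad=True)
y = inner_func(x)            # runs with gradients disabled
assert not y.requires_grad

z = x.sin()                  # outside the wrapper, gradients still flow
assert z.requires_grad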
test/dynamo/test_ctx_manager.py

@@ -1072,10 +1072,9 @@ class GraphModule(torch.nn.Module):
             inner_func = torch.set_grad_enabled(mode)(inner_func)
-            # decorator will mutate global state even if it wraps a function
-            # This behaviour is not desirable and may change in the future
-            # https://github.com/pytorch/pytorch/issues/113298
-            assert torch.is_grad_enabled() == mode
+            # Consuming set_grad_enabled by calling it on a function
+            # should not mutate global state
+            assert torch.is_grad_enabled() == mode_inverse
 
             with torch.set_grad_enabled(mode_inverse):
                 return inner_func(x)
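A minimal eager sketch of the semantics the updated test above asserts (a hypothetical standalone snippet, not the test itself): calling a set_grad_enabled object on a function consumes its eager mutation and restores the enclosing mode.

import torch

def inner_func(x):
    return x.sin()

with torch.set_grad_enabled(False):
    # Constructing set_grad_enabled(True) flips the global mode eagerly...
    # ...but applying it to a function undoes that flip.
    inner_func = torch.set_grad_enabled(True)(inner_func)
    assert not torch.is_grad_enabled()       # back to the enclosing mode

    y = inner_func(torch.ones(1, requires_grad=True))
    assert y.requires_grad                   # the wrapper re-enables grad inside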
test/test_autograd.py

@@ -3735,6 +3735,36 @@ class TestAutograd(TestCase):
         y = x * 2
         self.assertTrue(y.requires_grad)
 
+    def test_set_grad_enabled_wraps(self):
+        for decorator in [True, False]:
+            with torch.enable_grad():
+                self.assertTrue(torch.is_grad_enabled())
+
+                if decorator:
+                    # This should not mutate the global grad mode!
+                    @torch.set_grad_enabled(False)
+                    def inner_func(x):
+                        return x.sin()
+                else:
+                    def inner_func(x):
+                        return x.sin()
+
+                    # This is non-idiomatic usage!
+                    # More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
+                    obj = torch.set_grad_enabled(False)
+                    self.assertTrue(not torch.is_grad_enabled())
+
+                    # this will consume the set_grad_enabled global mutation!
+                    inner_func = obj(inner_func)
+                    self.assertTrue(torch.is_grad_enabled())
+
+                self.assertTrue(torch.is_grad_enabled())
+
+                x = torch.zeros(1, requires_grad=True)
+                self.assertTrue(
+                    not inner_func(x).requires_grad
+                )
+
     def test_simple_reentrant(self):
         y_data = torch.randn(2, 2)
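To try just the new autograd test in isolation (standard pytest filtering; assumes a PyTorch source checkout):

    pytest test/test_autograd.py -k test_set_grad_enabled_wraps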
torch/_dynamo/variables/ctx_manager.py

@@ -181,8 +181,7 @@ class GradModeVariable(ContextWrappingVariable):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ):
-        # TODO(jon-chuang): uncomment once https://github.com/pytorch/pytorch/issues/113298 is fixed
-        # self._call_func(tx, self.initial_values)  # undo eager initialization
+        self._call_func(tx, self.initial_values)  # undo eager initialization
         return super().call_function(tx, args, kwargs)
 
     def _call_func(self, tx, values):
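Why the dynamo change is safe now: GradModeVariable models set_grad_enabled, whose construction eagerly mutates the global grad mode; when the object is then consumed as a function wrapper, that eager mutation must be undone before delegating. A generic sketch of this undo-then-delegate pattern, with made-up names (state, EagerToggle) rather than dynamo's real classes:

# Illustrative sketch only: `state` and `EagerToggle` are hypothetical.
state = {"grad_enabled": True}

class EagerToggle:
    """Mutates global state eagerly at construction, like set_grad_enabled."""

    def __init__(self, value):
        self.prev = state["grad_enabled"]
        self.value = value
        state["grad_enabled"] = value  # eager mutation at construction time

    def __call__(self, fn):
        # Consumed as a wrapper: undo the eager initialization first...
        state["grad_enabled"] = self.prev

        def wrapped(*args, **kwargs):
            # ...then re-apply the mode only around each call.
            saved = state["grad_enabled"]
            state["grad_enabled"] = self.value
            try:
                return fn(*args, **kwargs)
            finally:
                state["grad_enabled"] = saved

        return wrapped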
torch/autograd/grad_mode.py

@@ -5,6 +5,7 @@ import torch
 from torch.utils._contextlib import (
     _DecoratorContextManager,
+    _NoParamDecoratorContextManager,
     F,
 )
 
 __all__ = [
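The added import, _NoParamDecoratorContextManager, is the helper that lets argument-free grad-mode managers such as no_grad be used as bare decorators (@torch.no_grad with no parentheses). Roughly, and hedged as a sketch rather than the exact upstream source, it works by intercepting __new__:

from torch.utils._contextlib import _DecoratorContextManager

class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a no-argument context manager to double as a bare decorator."""

    def __new__(cls, orig_func=None):
        if orig_func is None:
            return super().__new__(cls)  # normal: with no_grad(): ...
        return cls()(orig_func)          # bare @no_grad: wrap immediately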
@@ -181,11 +182,15 @@ class set_grad_enabled(_DecoratorContextManager):
 
     def __init__(self, mode: bool) -> None:
         self.prev = torch.is_grad_enabled()
-        torch._C._set_grad_enabled(mode)
         self.mode = mode
+        torch._C._set_grad_enabled(mode)
+
+    def __call__(self, orig_func: F) -> F:
+        torch._C._set_grad_enabled(self.prev)
+        return super().__call__(orig_func)
 
     def __enter__(self) -> None:
-        pass
+        torch._C._set_grad_enabled(self.mode)
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         torch._C._set_grad_enabled(self.prev)
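Putting the pieces together, the resulting behavior (a sketch assuming grad mode starts enabled, its default):

import torch

ctx = torch.set_grad_enabled(False)   # __init__ still applies the mode eagerly
assert not torch.is_grad_enabled()

def double(x):
    return x * 2

double = ctx(double)                  # __call__ now restores the previous mode
assert torch.is_grad_enabled()

x = torch.ones(1, requires_grad=True)
assert not double(x).requires_grad    # the mode is re-applied around each call
assert torch.is_grad_enabled()        # and restored again afterwards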