Dynamo x autograd.Function supports setup_context (#124802)

Fixes part of #118397

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124802
Approved by: https://github.com/zou3519
Yanbo Liang 2024-04-27 04:57:13 +00:00 committed by PyTorch MergeBot
parent a866bfff45
commit ce503c1b40
7 changed files with 79 additions and 20 deletions
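For context, the setup_context style splits an autograd.Function into a pure forward and a separate context-population step. Before this change, Dynamo bailed out with "NYI - autograd.Function with custom setup_context method"; after it, code like the following minimal sketch (class and function names are illustrative, not from the commit) can be captured, as exercised by the updated test below:

import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # In this style, forward does not receive ctx.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving tensors for backward happens here instead of in forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 3 * x ** 2 * grad_output

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return Cube.apply(x)

out = fn(torch.randn(3, requires_grad=True))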

@@ -253,11 +253,11 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     def test_linear_setup_context(self):
         model = ModuleLinear()
-        opt_model = torch._dynamo.optimize("eager")(model)
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
         input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
-        optim_result = opt_model(input, weight)
         eager_result = model(input, weight)
+        optim_result = opt_model(input, weight)
         self.assertEqual(optim_result, eager_result)

     def test_materialize_grad(self):
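The test now compiles with nopython=True, so any graph break on the setup_context path fails the test instead of silently falling back to eager. As a quick aside, this flag is the torch._dynamo.optimize counterpart of torch.compile(fullgraph=True); a tiny sketch on an arbitrary function:

import torch

def f(x):
    return torch.sin(x) + 1

# nopython=True makes Dynamo raise on anything it cannot trace,
# rather than falling back to eager execution for that region.
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
opt_f(torch.randn(4))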

@@ -3205,6 +3205,7 @@ MOD_INLINELIST = {
     "torch._dynamo.comptime",
     "torch._dynamo.polyfill",
     "torch._functorch.vmap",
+    "torch._functorch.autograd_function",
     "torch._library.custom_ops",
     "torch._functorch.eager_transforms",
     "torch._inductor.test_operators",

@@ -1628,13 +1628,12 @@ class AutogradFunctionApplyVariable(VariableTracker):
         fwd_src = AttrSource(self.parent_source, member="forward")
         ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
         if isinstance(self.fwd_graph, types.FunctionType):
-            fwd_fn = UserFunctionVariable(self.fwd_graph, source=fwd_src)
+            fwd_fn = UserFunctionVariable(self.fwd_graph)
             fwd_args = [ctx, *args]
         elif isinstance(self.fwd_graph, types.MethodType):
             fwd_fn = UserMethodVariable(
                 self.fwd_graph.__func__,
                 UserDefinedClassVariable(self.fwd_graph.__class__),
-                source=fwd_src,
             )
             fwd_args = [fwd_fn.obj, ctx, *args]
         else:

@@ -357,14 +357,19 @@ class AutogradFunctionVariable(VariableTracker):
             and torch.is_grad_enabled()
             and config.capture_autograd_function
         ):
-            # Note - this is the same check used in autograd/function.py, except inverted.
-            # If we want to support functorch transforms here, we will need to enable this.
-            if (
-                self.fn_cls.setup_context
-                != torch.autograd.function._SingleLevelFunction.setup_context
-            ):
-                unimplemented(
-                    "NYI - autograd.Function with custom setup_context method"
-                )
+            from torch._functorch.autograd_function import (
+                autograd_function_forward_rewritten,
+            )
+            from torch.autograd.function import _is_setup_context_defined
+
+            forward_fn = self.fn_cls.forward
+
+            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
+            if is_setup_ctx_defined:
+                # If setup_context is defined, we generate a new forward function which includes
+                # the original forward and setup_context function, and trace the new forward function.
+                forward_fn = autograd_function_forward_rewritten(
+                    self.fn_cls.forward, self.fn_cls.setup_context
+                )

             vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
@@ -383,12 +388,25 @@ class AutogradFunctionVariable(VariableTracker):
                 tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
             )

-            return AutogradFunctionApplyVariable(
-                self.fn_cls.forward,
+            val = AutogradFunctionApplyVariable(
+                forward_fn,
                 self.fn_cls.backward,
-                source,
+                source=AttrSource(source, member="apply"),
             ).call_function(tx, args, kwargs)
+            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
+            # the forward function, as we don't want to generate guards for new_forward.__closure__
+            # if forward is rewritten by autograd_function_forward_rewritten.
+            # But we still need to generate correct guards for the original forward and setup_context
+            # functions, so we have to add guards manually.
+            if self.source:
+                fwd_src = AttrSource(self.source, "forward")
+                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+                if is_setup_ctx_defined:
+                    setup_ctx_src = AttrSource(self.source, "setup_context")
+                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+            return val

         if self.source:
             source = AttrSource(self.source, "forward")
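The comment above concerns guard coverage: because the (possibly rewritten) forward is wrapped without a source, guards on the user's original forward and setup_context are installed by hand. The intent, roughly, is that redefining either attribute on the Function class invalidates the compiled graph instead of silently reusing it. An illustrative sketch, not a test from this commit:

import torch

class AddOne(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x + 1

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

@torch.compile(backend="eager")
def fn(x):
    return AddOne.apply(x)

fn(torch.randn(2, requires_grad=True))

# Swapping out forward changes the guarded function object, so the next
# call should trigger a recompile rather than reuse the traced graph.
AddOne.forward = staticmethod(lambda x: x + 2)
fn(torch.randn(2, requires_grad=True))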
@@ -443,7 +461,32 @@ class AutogradFunctionVariable(VariableTracker):
             return self.call_apply(tx, args, kwargs)

         else:
-            unimplemented(f"Unsupported method: {name}")
+            from .. import trace_rules
+
+            source = AttrSource(self.source, name) if self.source is not None else None
+            try:
+                obj = inspect.getattr_static(self.fn_cls, name)
+            except AttributeError:
+                obj = None
+
+            if isinstance(obj, staticmethod):
+                func = obj.__get__(self.fn_cls)
+                if source is not None:
+                    return (
+                        trace_rules.lookup(func)
+                        .create_with_source(func, source=source)
+                        .call_function(tx, args, kwargs)
+                    )
+                else:
+                    return trace_rules.lookup(func)(func).call_function(
+                        tx, args, kwargs
+                    )
+            elif isinstance(obj, classmethod):
+                return variables.UserMethodVariable(
+                    obj.__func__, self, source=source
+                ).call_function(tx, args, kwargs)
+            else:
+                unimplemented(f"Unsupported method: {name}")


 @dataclasses.dataclass
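The new else branch dispatches calls to any other attribute of the Function class to its underlying staticmethod or classmethod. A rough sketch of the kind of pattern this is meant to cover during tracing; the helper method below is invented for illustration:

import torch

class ScaledRelu(torch.autograd.Function):
    @staticmethod
    def scale():  # extra helper staticmethod, not part of the autograd API
        return 2.0

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Calling a staticmethod on the Function class goes through
        # AutogradFunctionVariable.call_method while Dynamo traces.
        return torch.relu(x) * ScaledRelu.scale()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0) * ScaledRelu.scale()

@torch.compile(backend="eager")
def fn(x):
    return ScaledRelu.apply(x)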

@@ -2,6 +2,7 @@
 import collections
 import contextlib
+import enum
 import functools
 import importlib
 import inspect
@@ -107,7 +108,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
     def var_getattr(self, tx, name: str) -> "VariableTracker":
         from .. import trace_rules
-        from . import ConstantVariable
+        from . import ConstantVariable, EnumVariable
         from .builder import VariableBuilder

         if name == "__name__":
@@ -144,14 +145,16 @@ class UserDefinedClassVariable(UserDefinedVariable):
         if self.value is collections.OrderedDict and name == "fromkeys":
             return super().var_getattr(tx, name)

-        if name in getattr(self.value, "__dict__", {}) or (
+        if ConstantVariable.is_literal(obj):
+            return ConstantVariable.create(obj)
+        elif isinstance(obj, enum.Enum):
+            return EnumVariable(obj)
+        elif name in getattr(self.value, "__dict__", {}) or (
             self.value.__module__.startswith("torch.")
             or self.value.__module__ == "torch"
         ):
             if source:
                 return VariableBuilder(tx, source)(obj)
-            elif ConstantVariable.is_literal(obj):
-                return ConstantVariable.create(obj)

         return super().var_getattr(tx, name)
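With this change, reading a literal or an enum member off a class attribute resolves to a constant (or an EnumVariable) even when there is no source to guard on. A small sketch of the kind of access this covers; the classes are illustrative, not from the commit:

import enum
import torch

class Mode(enum.Enum):
    FAST = 1
    ACCURATE = 2

class Settings:
    mode = Mode.FAST   # enum member as a class attribute
    scale = 3          # plain literal as a class attribute

@torch.compile(backend="eager")
def fn(x):
    if Settings.mode is Mode.FAST:
        return x * Settings.scale
    return x / Settings.scale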

@@ -682,6 +682,15 @@ def reductify_leaf(
     return grad_input


+def autograd_function_forward_rewritten(original_forward, original_setup_context):
+    def new_forward(ctx, *args, **kwargs):
+        output = original_forward(*args, **kwargs)
+        original_setup_context(ctx, args, output)
+        return output
+
+    return new_forward
+
+
 class AutogradFunctionApply(HigherOrderOperator):
     def __init__(self):
         super().__init__("autograd_function_apply")
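The helper simply composes the user's setup_context-style forward with setup_context into a single ctx-taking forward, which is what Dynamo then traces. A hedged usage sketch; the Scale class is invented, and FunctionCtx only stands in for the real ctx object:

import torch
from torch._functorch.autograd_function import autograd_function_forward_rewritten

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(x, factor):
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None

new_forward = autograd_function_forward_rewritten(Scale.forward, Scale.setup_context)
ctx = torch.autograd.function.FunctionCtx()
out = new_forward(ctx, torch.randn(3), 2.0)  # runs forward, then setup_context(ctx, args, out)
assert ctx.factor == 2.0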

@@ -561,7 +561,7 @@ class Function(_SingleLevelFunction):
             return bound_args.args

-        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
         if is_setup_ctx_defined:
             args = bind_default_args(cls.forward, *args, **kwargs)
@@ -585,6 +585,10 @@ class Function(_SingleLevelFunction):
         return (ctx._autograd_function_id,)


+def _is_setup_context_defined(fn):
+    return fn != _SingleLevelFunction.setup_context
+
+
 def once_differentiable(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args):
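For completeness, a quick sketch of what the extracted helper reports for a Function that overrides setup_context versus one that uses the inherited default (both classes are invented for illustration):

import torch
from torch.autograd.function import _is_setup_context_defined

class WithCtx(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

class WithoutCtx(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

assert _is_setup_context_defined(WithCtx.setup_context)         # overridden
assert not _is_setup_context_defined(WithoutCtx.setup_context)  # inherited default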