From ce503c1b40207dab770c28cbd4568cd9e105277b Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Sat, 27 Apr 2024 04:57:13 +0000
Subject: [PATCH] Dynamo x autograd.Function supports setup_context (#124802)

Fixes part of #118397
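
For context, the user-visible pattern this enables is sketched below. This is
a minimal illustration, not code from this PR: the class name, shapes, and the
eager backend are placeholders.

    import torch

    class MyLinear(torch.autograd.Function):
        # Illustrative autograd.Function using the new-style forward, where
        # ctx is populated in a separate setup_context staticmethod.
        @staticmethod
        def forward(input, weight):
            return torch.mm(input, weight.t())

        @staticmethod
        def setup_context(ctx, inputs, output):
            input, weight = inputs
            ctx.save_for_backward(input, weight)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            # Gradients of (input @ weight.T) w.r.t. input and weight.
            return torch.mm(grad_output, weight), torch.mm(grad_output.t(), input)

    @torch.compile(backend="eager", fullgraph=True)
    def fn(input, weight):
        return MyLinear.apply(input, weight)

Before this change, Dynamo rejected such classes with "NYI -
autograd.Function with custom setup_context method" (the unimplemented() call
removed in torch/_dynamo/variables/misc.py below). Now forward and
setup_context are fused into a single traceable forward via
autograd_function_forward_rewritten, and FUNCTION_MATCH guards are installed
on the original forward and setup_context functions.
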
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124802
Approved by: https://github.com/zou3519
---
 test/dynamo/test_autograd_function.py       |  4 +-
 torch/_dynamo/trace_rules.py                |  1 +
 torch/_dynamo/variables/higher_order_ops.py |  3 +-
 torch/_dynamo/variables/misc.py             | 65 +++++++++++++++++----
 torch/_dynamo/variables/user_defined.py     | 11 ++--
 torch/_functorch/autograd_function.py       |  9 +++
 torch/autograd/function.py                  |  6 +-
 7 files changed, 79 insertions(+), 20 deletions(-)

diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py
index 492936d0a99..e30bb4bf172 100644
--- a/test/dynamo/test_autograd_function.py
+++ b/test/dynamo/test_autograd_function.py
@@ -253,11 +253,11 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
 
     def test_linear_setup_context(self):
         model = ModuleLinear()
-        opt_model = torch._dynamo.optimize("eager")(model)
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
         input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
-        optim_result = opt_model(input, weight)
         eager_result = model(input, weight)
+        optim_result = opt_model(input, weight)
         self.assertEqual(optim_result, eager_result)
 
     def test_materialize_grad(self):
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 2fb8dc7241b..2a940eb600b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -3205,6 +3205,7 @@ MOD_INLINELIST = {
     "torch._dynamo.comptime",
     "torch._dynamo.polyfill",
     "torch._functorch.vmap",
+    "torch._functorch.autograd_function",
     "torch._library.custom_ops",
     "torch._functorch.eager_transforms",
     "torch._inductor.test_operators",
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 26f1eeb91c6..e3a6ece18d7 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -1628,13 +1628,12 @@ class AutogradFunctionApplyVariable(VariableTracker):
         fwd_src = AttrSource(self.parent_source, member="forward")
         ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
         if isinstance(self.fwd_graph, types.FunctionType):
-            fwd_fn = UserFunctionVariable(self.fwd_graph, source=fwd_src)
+            fwd_fn = UserFunctionVariable(self.fwd_graph)
             fwd_args = [ctx, *args]
         elif isinstance(self.fwd_graph, types.MethodType):
             fwd_fn = UserMethodVariable(
                 self.fwd_graph.__func__,
                 UserDefinedClassVariable(self.fwd_graph.__class__),
-                source=fwd_src,
             )
             fwd_args = [fwd_fn.obj, ctx, *args]
         else:
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 83ddc372bd3..39de1c4a50e 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -357,14 +357,19 @@ class AutogradFunctionVariable(VariableTracker):
             and torch.is_grad_enabled()
             and config.capture_autograd_function
         ):
-            # Note - this is the same check used in autograd/function.py, except inverted.
-            # If we want to support functorch transforms here, we will need to enable this.
-            if (
-                self.fn_cls.setup_context
-                != torch.autograd.function._SingleLevelFunction.setup_context
-            ):
-                unimplemented(
-                    "NYI - autograd.Function with custom setup_context method"
+            from torch._functorch.autograd_function import (
+                autograd_function_forward_rewritten,
+            )
+            from torch.autograd.function import _is_setup_context_defined
+
+            forward_fn = self.fn_cls.forward
+
+            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
+            if is_setup_ctx_defined:
+                # If setup_context is defined, we generate a new forward function which includes
+                # the original forward and setup_context function, and trace the new forward function.
+                forward_fn = autograd_function_forward_rewritten(
+                    self.fn_cls.forward, self.fn_cls.setup_context
                 )
 
             vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
@@ -383,12 +388,25 @@
                 tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
             )
 
-            return AutogradFunctionApplyVariable(
-                self.fn_cls.forward,
+            val = AutogradFunctionApplyVariable(
+                forward_fn,
                 self.fn_cls.backward,
                 source,
                 source=AttrSource(source, member="apply"),
             ).call_function(tx, args, kwargs)
+            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
+            # the forward function, as we don't want to generate guards for new_forward.__closure__
+            # if forward is rewritten by autograd_function_forward_rewritten.
+            # But we still need to generate correct guards for the original forward and setup_context
+            # functions, so we have to add guards manually.
+            if self.source:
+                fwd_src = AttrSource(self.source, "forward")
+                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+                if is_setup_ctx_defined:
+                    setup_ctx_src = AttrSource(self.source, "setup_context")
+                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
+
+            return val
 
         if self.source:
             source = AttrSource(self.source, "forward")
@@ -443,7 +461,32 @@
                 return self.call_apply(tx, args, kwargs)
 
         else:
-            unimplemented(f"Unsupported method: {name}")
+            from .. import trace_rules
+
+            source = AttrSource(self.source, name) if self.source is not None else None
+            try:
+                obj = inspect.getattr_static(self.fn_cls, name)
+            except AttributeError:
+                obj = None
+
+            if isinstance(obj, staticmethod):
+                func = obj.__get__(self.fn_cls)
+                if source is not None:
+                    return (
+                        trace_rules.lookup(func)
+                        .create_with_source(func, source=source)
+                        .call_function(tx, args, kwargs)
+                    )
+                else:
+                    return trace_rules.lookup(func)(func).call_function(
+                        tx, args, kwargs
+                    )
+            elif isinstance(obj, classmethod):
+                return variables.UserMethodVariable(
+                    obj.__func__, self, source=source
+                ).call_function(tx, args, kwargs)
+            else:
+                unimplemented(f"Unsupported method: {name}")
 
 
 @dataclasses.dataclass
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 9db504cd16f..544773f08a1 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -2,6 +2,7 @@
 
 import collections
 import contextlib
+import enum
 import functools
 import importlib
 import inspect
@@ -107,7 +108,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
 
     def var_getattr(self, tx, name: str) -> "VariableTracker":
         from .. import trace_rules
-        from . import ConstantVariable
+        from . import ConstantVariable, EnumVariable
         from .builder import VariableBuilder
 
         if name == "__name__":
@@ -144,14 +145,16 @@ class UserDefinedClassVariable(UserDefinedVariable):
         if self.value is collections.OrderedDict and name == "fromkeys":
             return super().var_getattr(tx, name)
 
-        if name in getattr(self.value, "__dict__", {}) or (
+        if ConstantVariable.is_literal(obj):
+            return ConstantVariable.create(obj)
+        elif isinstance(obj, enum.Enum):
+            return EnumVariable(obj)
+        elif name in getattr(self.value, "__dict__", {}) or (
             self.value.__module__.startswith("torch.")
             or self.value.__module__ == "torch"
         ):
             if source:
                 return VariableBuilder(tx, source)(obj)
-            elif ConstantVariable.is_literal(obj):
-                return ConstantVariable.create(obj)
 
         return super().var_getattr(tx, name)
 
diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py
index 98ffe6dd165..03bfd710ae3 100644
--- a/torch/_functorch/autograd_function.py
+++ b/torch/_functorch/autograd_function.py
@@ -682,6 +682,15 @@ def reductify_leaf(
     return grad_input
 
 
+def autograd_function_forward_rewritten(original_forward, original_setup_context):
+    def new_forward(ctx, *args, **kwargs):
+        output = original_forward(*args, **kwargs)
+        original_setup_context(ctx, args, output)
+        return output
+
+    return new_forward
+
+
 class AutogradFunctionApply(HigherOrderOperator):
     def __init__(self):
         super().__init__("autograd_function_apply")
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 3ff96953b29..9c624ce5d14 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -561,7 +561,7 @@ class Function(_SingleLevelFunction):
 
             return bound_args.args
 
-    is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+    is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
     if is_setup_ctx_defined:
         args = bind_default_args(cls.forward, *args, **kwargs)
 
@@ -585,6 +585,10 @@ class Function(_SingleLevelFunction):
         return (ctx._autograd_function_id,)
 
 
+def _is_setup_context_defined(fn):
+    return fn != _SingleLevelFunction.setup_context
+
+
 def once_differentiable(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args):