From 70f2adaec338bd33d93c373ba4f791b33adc4aaa Mon Sep 17 00:00:00 2001
From: FFFrog
Date: Tue, 19 Sep 2023 16:23:52 +0000
Subject: [PATCH] Setup_context does not contain default values of forward()
 (#108561)

Fixes #108529

As the title states.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108561
Approved by: https://github.com/soulitzer
---
 test/functorch/test_eager_transforms.py       |  2 +-
 test/test_autograd.py                         | 50 +++++++++++++++++++
 torch/autograd/function.py                    | 14 +++++-
 .../testing/_internal/autograd_function_db.py | 43 ++++++++++++++++
 4 files changed, 107 insertions(+), 2 deletions(-)

diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index fd5693ed594..616aadcee94 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -1323,7 +1323,7 @@ class TestAutogradFunctionVmapAPI(TestCase):
     def test_in_dims_multiple_inputs(self, device):
         class Id(torch.autograd.Function):
             @staticmethod
-            def forward(input):
+            def forward(x, y):
                 pass
 
             @staticmethod
diff --git a/test/test_autograd.py b/test/test_autograd.py
index d8bd7f061fb..ad98d7031c8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -8460,6 +8460,56 @@ get_out().sum().backward()
         _test_op(torch.view_as_complex, torch.rand(2, 2), ())
         _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
 
+    def test_setup_context_when_forward_has_default_args(self):
+        class PowFunction(Function):
+            @staticmethod
+            def forward(x, y=3):
+                return torch.pow(x, y)
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @staticmethod
+            def backward(ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        class PowFunctionWithClassmethod(Function):
+            @classmethod
+            def forward(cls, x, y=3):
+                return torch.pow(x, y)
+
+            @classmethod
+            def setup_context(cls, ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @classmethod
+            def backward(cls, ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        x = torch.tensor(2.0, requires_grad=True)
+
+        y = torch.tensor(8.0)
+        y_expected = torch.tensor(12.0)
+
+        y1 = PowFunction.apply(x)
+        y1_expected, = torch.autograd.grad(y1, x)
+
+        y2 = PowFunctionWithClassmethod.apply(x)
+        y2_expected, = torch.autograd.grad(y2, x)
+
+        self.assertEqual(y, y1)
+        self.assertEqual(y_expected, y1_expected)
+        self.assertEqual(y, y2)
+        self.assertEqual(y_expected, y2_expected)
 
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 080f1a4e944..92060f1e3f0 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import warnings
 from collections import OrderedDict
 from typing import Any, List, Optional, Tuple
@@ -533,12 +534,23 @@ class Function(_SingleLevelFunction):
 
     @classmethod
     def apply(cls, *args, **kwargs):
+        def bind_default_args(func, *args, **kwargs):
+            signature = inspect.signature(func)
+            bound_args = signature.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+
+            return bound_args.args
+
+        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        if is_setup_ctx_defined:
+            args = bind_default_args(cls.forward, *args, **kwargs)
+
         if not torch._C._are_functorch_transforms_active():
             # See NOTE: [functorch vjp and autograd interaction]
             args = _functorch.utils.unwrap_dead_wrappers(args)
             return super().apply(*args, **kwargs)  # type: ignore[misc]
 
-        if cls.setup_context == _SingleLevelFunction.setup_context:
+        if not is_setup_ctx_defined:
             raise RuntimeError(
                 "In order to use an autograd.Function with functorch transforms "
                 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
diff --git a/torch/testing/_internal/autograd_function_db.py b/torch/testing/_internal/autograd_function_db.py
index 5042d712482..0a16cae3aca 100644
--- a/torch/testing/_internal/autograd_function_db.py
+++ b/torch/testing/_internal/autograd_function_db.py
@@ -463,6 +463,40 @@ class ZeroGradientsGenVmap(torch.autograd.Function):
             torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
         )
 
+
+def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5))
+
+
+class ForwardHasDefaultArgs(torch.autograd.Function):
+    @staticmethod
+    def forward(x, idx=(2,)):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, idx):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 1)
+        return ForwardHasDefaultArgs.apply(x, idx), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
+
+
 autograd_function_db = [
     OpInfo(
         'NumpyCubeAutogradFunction',
@@ -584,4 +618,13 @@ autograd_function_db = [
         dtypes=all_types_and(torch.bool, torch.half),
         supports_out=False,
     ),
+    OpInfo(
+        'ForwardHasDefaultArgsAutogradFunction',
+        op=ForwardHasDefaultArgs.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_forward_default_args,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
 ]
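
For context on the mechanism the fix relies on (an editorial sketch, not part of the patch): before this change, Function.apply forwarded only the arguments the caller actually passed, so when forward declared a default such as y=3 and the caller omitted it, setup_context(ctx, inputs, output) received an inputs tuple without the default value and an unpack like `x, y = inputs` would fail. The patch closes that gap by binding the call against forward's signature before dispatch. A minimal standalone illustration of the same stdlib inspect-based binding follows; the forward below is a placeholder, not the patched PyTorch code path:

    import inspect

    def forward(x, y=3):
        return x ** y

    def bind_default_args(func, *args, **kwargs):
        # Bind the caller's arguments to func's signature, then fill in
        # every parameter the caller omitted with its declared default.
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        return bound.args

    print(bind_default_args(forward, 2.0))  # -> (2.0, 3)

With the defaults bound, setup_context always sees the full (x, y) tuple even when the user writes PowFunction.apply(x). Note that the patch applies the binding only when the Function overrides setup_context; the legacy combined forward(ctx, ...) style never routes inputs through setup_context and so is left untouched.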