Setup_context does not contain default values of forward() (#108561)
Fixes #108529. As the title says, setup_context() does not receive the default values of forward()'s arguments; apply() now binds them before calling it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108561 Approved by: https://github.com/soulitzer
parent 1427b8149c
commit 70f2adaec3
4 changed files with 107 additions and 2 deletions
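For illustration, a minimal standalone sketch of the behavior this change enables (not taken from the commit; the Scale class and its factor argument are hypothetical, and it assumes a PyTorch build with the separate setup_context API plus this fix): a default declared on forward() is bound before setup_context() runs, so inputs always carries the full argument tuple.

import torch
from torch.autograd import Function

class Scale(Function):  # hypothetical example, not part of this diff
    @staticmethod
    def forward(x, factor=3.0):
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        # With this fix, factor == 3.0 here even when the caller omitted it.
        x, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None

x = torch.tensor(2.0, requires_grad=True)
y = Scale.apply(x)   # factor not passed; the default is bound for setup_context
y.backward()
print(x.grad)        # tensor(3.)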
@@ -1323,7 +1323,7 @@ class TestAutogradFunctionVmapAPI(TestCase):
     def test_in_dims_multiple_inputs(self, device):
         class Id(torch.autograd.Function):
             @staticmethod
-            def forward(input):
+            def forward(x, y):
                 pass
 
             @staticmethod
@@ -8460,6 +8460,56 @@ get_out().sum().backward()
         _test_op(torch.view_as_complex, torch.rand(2, 2), ())
         _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
 
+    def test_setup_context_when_forward_has_default_args(self):
+        class PowFunction(Function):
+            @staticmethod
+            def forward(x, y=3):
+                return torch.pow(x, y)
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @staticmethod
+            def backward(ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        class PowFunctionWithClassmethod(Function):
+            @classmethod
+            def forward(cls, x, y=3):
+                return torch.pow(x, y)
+
+            @classmethod
+            def setup_context(cls, ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @classmethod
+            def backward(cls, ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        x = torch.tensor(2.0, requires_grad=True)
+
+        y = torch.tensor(8.0)
+        y_expected = torch.tensor(12.0)
+
+        y1 = PowFunction.apply(x)
+        y1_expected, = torch.autograd.grad(y1, x)
+
+        y2 = PowFunctionWithClassmethod.apply(x)
+        y2_expected, = torch.autograd.grad(y2, x)
+
+        self.assertEqual(y, y1)
+        self.assertEqual(y_expected, y1_expected)
+        self.assertEqual(y, y2)
+        self.assertEqual(y_expected, y2_expected)
 
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import warnings
 from collections import OrderedDict
 from typing import Any, List, Optional, Tuple

@@ -533,12 +534,23 @@ class Function(_SingleLevelFunction):
 
     @classmethod
    def apply(cls, *args, **kwargs):
+        def bind_default_args(func, *args, **kwargs):
+            signature = inspect.signature(func)
+            bound_args = signature.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+
+            return bound_args.args
+
+        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        if is_setup_ctx_defined:
+            args = bind_default_args(cls.forward, *args, **kwargs)
+
         if not torch._C._are_functorch_transforms_active():
             # See NOTE: [functorch vjp and autograd interaction]
             args = _functorch.utils.unwrap_dead_wrappers(args)
             return super().apply(*args, **kwargs)  # type: ignore[misc]
 
-        if cls.setup_context == _SingleLevelFunction.setup_context:
+        if not is_setup_ctx_defined:
             raise RuntimeError(
                 "In order to use an autograd.Function with functorch transforms "
                 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
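The bind_default_args helper added to apply() relies only on the standard library. A standalone illustration of the mechanism (the forward function and the values below are made up for this example):

import inspect

def forward(x, y=3):
    return x

sig = inspect.signature(forward)
bound = sig.bind("x_value")   # y omitted by the caller
bound.apply_defaults()        # y=3 is filled in from the signature
print(bound.args)             # ('x_value', 3)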
@@ -463,6 +463,40 @@ class ZeroGradientsGenVmap(torch.autograd.Function):
             torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
         )
 
+
+def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5))
+
+
+class ForwardHasDefaultArgs(torch.autograd.Function):
+    @staticmethod
+    def forward(x, idx=(2,)):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, idx):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 1)
+        return ForwardHasDefaultArgs.apply(x, idx), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
+
+
 autograd_function_db = [
     OpInfo(
         'NumpyCubeAutogradFunction',
@@ -584,4 +618,13 @@ autograd_function_db = [
         dtypes=all_types_and(torch.bool, torch.half),
         supports_out=False,
     ),
+    OpInfo(
+        'ForwardHasDefaultArgsAutogradFunction',
+        op=ForwardHasDefaultArgs.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_forward_default_args,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
 ]
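As a hypothetical check (not part of this commit), with the ForwardHasDefaultArgs class from the hunk above in scope, apply() can now be called without the idx argument, since the default (2,) is bound before setup_context() runs:

import torch

x = torch.arange(12.0).reshape(3, 4).requires_grad_()
out = ForwardHasDefaultArgs.apply(x)   # idx defaults to (2,): selects row 2
out.sum().backward()
print(x.grad)                          # ones in row 2, zeros elsewhere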