[functorch] jacfwd accepts pytree inputs (pytorch/functorch#300)

Test Plan:
- new tests
This commit is contained in:
Richard Zou 2021-11-29 22:28:56 -05:00 committed by Jon Janzen
parent 057aaa3b51
commit 30350bfcaf
2 changed files with 31 additions and 6 deletions

View file

@ -644,11 +644,19 @@ def jvp(f, primals, tangents, *, strict=False):
_grad_decrement_nesting()
JVP_NESTING -= 1
def safe_unflatten(tensor, dim, shape):
if len(shape) == 0:
assert tensor.numel() == 1
return tensor.squeeze()
return tensor.unflatten(dim, shape)
def jacfwd(f, argnums=0):
def wrapper_fn(*args):
f_wrapper, primals = _argnums_partial(f, args, argnums)
primals_numels = tuple(p.numel() for p in primals)
basis = _construct_standard_basis_for(primals, primals_numels)
flat_primals, primals_spec = tree_flatten(primals)
flat_primals_numels = tuple(p.numel() for p in flat_primals)
flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
basis = tree_unflatten(flat_basis, primals_spec)
def push_jvp(basis):
_, jvp_out = jvp(f_wrapper, primals, basis)
@ -660,11 +668,13 @@ def jacfwd(f, argnums=0):
jac_outs, spec = tree_flatten(results)
jac_outs_ins = tuple(
tuple(
jac_out_in.unflatten(-1, primal.shape)
for primal, jac_out_in in zip(primals, jac_out.movedim(0, -1).split(primals_numels, dim=-1))
safe_unflatten(jac_out_in, -1, primal.shape)
for primal, jac_out_in in
zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
)
for jac_out in jac_outs
)
jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)
if isinstance(argnums, int):
jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)

View file

@ -1071,7 +1071,7 @@ class TestJac(TestCase):
self.assertTrue(isinstance(z['right'], tuple))
self.assertEqual(z, expected)
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_multiple_inputs_pytree(self, device, jacapi):
def f(a, b, c):
a0, a1 = a
@ -1096,6 +1096,21 @@ class TestJac(TestCase):
expected = (torch.tensor(1., device=device), torch.tensor(2., device=device))
self.assertEqual(result, expected)
@jacrev_and_jacfwd
def test_dimensionality(self, device, jacapi):
def f(x):
return x
x = torch.randn([], device=device)
result = jacapi(f)(x)
self.assertEqual(result.dim(), 0)
self.assertEqual(result, torch.ones_like(x))
x = torch.randn([1], device=device)
result = jacapi(f)(x)
self.assertEqual(result.dim(), 2)
self.assertEqual(result, x.new_ones(1, 1))
@FIXME_jacrev_only
def test_multiple_inputs_outputs_pytree(self, device, jacapi):
def f(a, b, c):
@ -1213,7 +1228,7 @@ class TestJac(TestCase):
@jacrev_and_jacfwd
def test_argnums_defaults_to_zero(self, device, jacapi):
def f(x, y):
return x * 2 + y * 3
return x * 2 + y * 3
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)