mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[functorch] jacfwd accepts pytree inputs (pytorch/functorch#300)
Test Plan: - new tests
This commit is contained in:
parent
057aaa3b51
commit
30350bfcaf
2 changed files with 31 additions and 6 deletions
|
|
@ -644,11 +644,19 @@ def jvp(f, primals, tangents, *, strict=False):
|
|||
_grad_decrement_nesting()
|
||||
JVP_NESTING -= 1
|
||||
|
||||
def safe_unflatten(tensor, dim, shape):
|
||||
if len(shape) == 0:
|
||||
assert tensor.numel() == 1
|
||||
return tensor.squeeze()
|
||||
return tensor.unflatten(dim, shape)
|
||||
|
||||
def jacfwd(f, argnums=0):
|
||||
def wrapper_fn(*args):
|
||||
f_wrapper, primals = _argnums_partial(f, args, argnums)
|
||||
primals_numels = tuple(p.numel() for p in primals)
|
||||
basis = _construct_standard_basis_for(primals, primals_numels)
|
||||
flat_primals, primals_spec = tree_flatten(primals)
|
||||
flat_primals_numels = tuple(p.numel() for p in flat_primals)
|
||||
flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
|
||||
basis = tree_unflatten(flat_basis, primals_spec)
|
||||
|
||||
def push_jvp(basis):
|
||||
_, jvp_out = jvp(f_wrapper, primals, basis)
|
||||
|
|
@ -660,11 +668,13 @@ def jacfwd(f, argnums=0):
|
|||
jac_outs, spec = tree_flatten(results)
|
||||
jac_outs_ins = tuple(
|
||||
tuple(
|
||||
jac_out_in.unflatten(-1, primal.shape)
|
||||
for primal, jac_out_in in zip(primals, jac_out.movedim(0, -1).split(primals_numels, dim=-1))
|
||||
safe_unflatten(jac_out_in, -1, primal.shape)
|
||||
for primal, jac_out_in in
|
||||
zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
|
||||
)
|
||||
for jac_out in jac_outs
|
||||
)
|
||||
jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)
|
||||
|
||||
if isinstance(argnums, int):
|
||||
jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
|
||||
|
|
|
|||
|
|
@ -1071,7 +1071,7 @@ class TestJac(TestCase):
|
|||
self.assertTrue(isinstance(z['right'], tuple))
|
||||
self.assertEqual(z, expected)
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_multiple_inputs_pytree(self, device, jacapi):
|
||||
def f(a, b, c):
|
||||
a0, a1 = a
|
||||
|
|
@ -1096,6 +1096,21 @@ class TestJac(TestCase):
|
|||
expected = (torch.tensor(1., device=device), torch.tensor(2., device=device))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@jacrev_and_jacfwd
|
||||
def test_dimensionality(self, device, jacapi):
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
x = torch.randn([], device=device)
|
||||
result = jacapi(f)(x)
|
||||
self.assertEqual(result.dim(), 0)
|
||||
self.assertEqual(result, torch.ones_like(x))
|
||||
|
||||
x = torch.randn([1], device=device)
|
||||
result = jacapi(f)(x)
|
||||
self.assertEqual(result.dim(), 2)
|
||||
self.assertEqual(result, x.new_ones(1, 1))
|
||||
|
||||
@FIXME_jacrev_only
|
||||
def test_multiple_inputs_outputs_pytree(self, device, jacapi):
|
||||
def f(a, b, c):
|
||||
|
|
@ -1213,7 +1228,7 @@ class TestJac(TestCase):
|
|||
@jacrev_and_jacfwd
|
||||
def test_argnums_defaults_to_zero(self, device, jacapi):
|
||||
def f(x, y):
|
||||
return x * 2 + y * 3
|
||||
return x * 2 + y * 3
|
||||
|
||||
x = torch.randn(3, device=device)
|
||||
y = torch.randn(3, device=device)
|
||||
|
|
|
|||
Loading…
Reference in a new issue