diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index ad9757a5887..ce575811646 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -644,11 +644,19 @@ def jvp(f, primals, tangents, *, strict=False):
         _grad_decrement_nesting()
         JVP_NESTING -= 1
 
+def safe_unflatten(tensor, dim, shape):
+    if len(shape) == 0:
+        assert tensor.numel() == 1
+        return tensor.squeeze()
+    return tensor.unflatten(dim, shape)
+
 def jacfwd(f, argnums=0):
     def wrapper_fn(*args):
         f_wrapper, primals = _argnums_partial(f, args, argnums)
-        primals_numels = tuple(p.numel() for p in primals)
-        basis = _construct_standard_basis_for(primals, primals_numels)
+        flat_primals, primals_spec = tree_flatten(primals)
+        flat_primals_numels = tuple(p.numel() for p in flat_primals)
+        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
+        basis = tree_unflatten(flat_basis, primals_spec)
 
         def push_jvp(basis):
             _, jvp_out = jvp(f_wrapper, primals, basis)
@@ -660,11 +668,13 @@ def jacfwd(f, argnums=0):
         jac_outs, spec = tree_flatten(results)
         jac_outs_ins = tuple(
             tuple(
-                jac_out_in.unflatten(-1, primal.shape)
-                for primal, jac_out_in in zip(primals, jac_out.movedim(0, -1).split(primals_numels, dim=-1))
+                safe_unflatten(jac_out_in, -1, primal.shape)
+                for primal, jac_out_in in
+                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
             )
             for jac_out in jac_outs
         )
+        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)
 
         if isinstance(argnums, int):
             jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index 8d69fdc5a06..022475ce5bc 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -1071,7 +1071,7 @@ class TestJac(TestCase):
         self.assertTrue(isinstance(z['right'], tuple))
         self.assertEqual(z, expected)
 
-    @FIXME_jacrev_only
+    @jacrev_and_jacfwd
     def test_multiple_inputs_pytree(self, device, jacapi):
         def f(a, b, c):
             a0, a1 = a
@@ -1096,6 +1096,21 @@ class TestJac(TestCase):
         expected = (torch.tensor(1., device=device), torch.tensor(2., device=device))
         self.assertEqual(result, expected)
 
+    @jacrev_and_jacfwd
+    def test_dimensionality(self, device, jacapi):
+        def f(x):
+            return x
+
+        x = torch.randn([], device=device)
+        result = jacapi(f)(x)
+        self.assertEqual(result.dim(), 0)
+        self.assertEqual(result, torch.ones_like(x))
+
+        x = torch.randn([1], device=device)
+        result = jacapi(f)(x)
+        self.assertEqual(result.dim(), 2)
+        self.assertEqual(result, x.new_ones(1, 1))
+
     @FIXME_jacrev_only
     def test_multiple_inputs_outputs_pytree(self, device, jacapi):
         def f(a, b, c):
@@ -1213,7 +1228,7 @@ class TestJac(TestCase):
     @jacrev_and_jacfwd
     def test_argnums_defaults_to_zero(self, device, jacapi):
         def f(x, y):
-             return x * 2 + y * 3
+            return x * 2 + y * 3
 
         x = torch.randn(3, device=device)
         y = torch.randn(3, device=device)
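
A minimal usage sketch of what the patch enables, mirroring the new tests above. This is not part of the patch; it assumes functorch's top-level `jacfwd` export, and the pytree function `f` here is illustrative:

    import torch
    from functorch import jacfwd  # assumes the public functorch API

    # Pytree (tuple) input: with the tree_flatten/tree_unflatten changes,
    # the jacobian w.r.t. `a` comes back as a tuple mirroring `a`'s structure.
    def f(a, b):
        a0, a1 = a
        return a0 * 2 + a1 * 3 + b

    a = (torch.randn(3), torch.randn(3))
    b = torch.randn(3)
    jac_a = jacfwd(f, argnums=0)(a, b)
    assert isinstance(jac_a, tuple) and jac_a[0].shape == (3, 3)

    # 0-dim input: safe_unflatten squeezes the size-1 basis dimension,
    # so the jacobian of the identity at a scalar is itself 0-dim
    # (this is exactly what test_dimensionality checks).
    x = torch.randn([])
    assert jacfwd(lambda t: t)(x).dim() == 0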