Fix torch.library.register_vmap (#137306)

We didn't support multiple levels of vmap. The main problem was that, while
running the Python batching rule, we need to exclude the vmap dispatch key
(FuncTorchBatched), just like our C++ batching rules do.
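
For illustration, a minimal sketch of the case this enables, in the spirit of
the new tests (the op name mylib::my_mul and the rule body are illustrative;
the rule follows the register_vmap contract of taking (info, in_dims, *args)
and returning (output, out_dims)):

    import torch

    @torch.library.custom_op("mylib::my_mul", mutates_args=())
    def my_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    @torch.library.register_vmap("mylib::my_mul")
    def _(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Move any batch dim to the back; give unbatched operands a
        # broadcasting dim so the elementwise multiply lines up.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        return (x * y).movedim(-1, 0), 0

    x, y = torch.randn(3), torch.randn(2)
    # Two levels of vmap: the outer maps over y, the inner over x.
    out = torch.vmap(torch.vmap(my_mul, in_dims=(0, None)), in_dims=(None, 0))(x, y)
    assert torch.allclose(out, y.unsqueeze(-1) * x)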

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137306
Approved by: https://github.com/Chillee
Author: rzou, 2024-10-03 16:58:53 -07:00; committed by PyTorch MergeBot
parent cfc51c858a
commit f500cb43bb
2 changed files with 25 additions and 3 deletions


@@ -3558,6 +3558,11 @@ Please use `add.register_fake` to add an fake impl.""",
         self.assertTrue(called)
         self.assertEqual(result, x * y)
+        x = torch.randn(3)
+        y = torch.randn(3)
+        result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+        self.assertEqual(result, y.unsqueeze(-1) * x)
+
 
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_vmap_op_decorator(self):
         @torch.library.custom_op("mylib::f", mutates_args=())
@@ -3584,6 +3589,11 @@ Please use `add.register_fake` to add an fake impl.""",
         self.assertTrue(called)
         self.assertEqual(result, x * y)
+        x = torch.randn(3)
+        y = torch.randn(2)
+        result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+        self.assertEqual(result, y.unsqueeze(-1) * x)
+
 
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_vmap_register_multiple_times(self):
         @torch.library.custom_op("mylib::f", mutates_args=())
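
For reference, in these assertions the outer vmap maps over y and the inner
vmap maps over x, so result[i][j] = f(x[j], y[i]) = x[j] * y[i], which is
exactly y.unsqueeze(-1) * x (shape (2, 3) in the second test). A quick
illustrative check:

    x, y = torch.randn(3), torch.randn(2)
    expected = torch.stack([x * yi for yi in y])  # expected[i][j] == x[j] * y[i]
    assert torch.allclose(expected, y.unsqueeze(-1) * x)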


@@ -325,18 +325,30 @@ def custom_function_call_vmap_helper(
         batch_size=interpreter.batch_size(),
         randomness=interpreter.randomness(),
     )
+    # We're either in the autograd.Function case (vmap staticmethod)
+    # or the torch.library.register_vmap case.
+    autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta)
+
+    def lower_to_next():
+        if autograd_function_case:
+            return interpreter.lower()
+        else:
+            return torch._C._ExcludeDispatchKeyGuard(
+                torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
+            )
+
     unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
     # If none of the tensors are batched at the current level, then we skip the
     # current level. This saves the user from needing to handle this case in
     # their vmap staticmethod (and is consistent with our C++ batching rule API)
     if pytree.tree_all(lambda dim: dim is None, in_dims):
-        with interpreter.lower():
-            if isinstance(op, torch.autograd.function.FunctionMeta):
+        with lower_to_next():
+            if autograd_function_case:
                 return custom_function_call(op, *operands)
             else:
                 return op(*operands, **kwargs)
 
-    with interpreter.lower():
+    with lower_to_next():
         result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
         validate_vmap_returns_tuple_of_two_elements(result)
         unwrapped_output, out_dims = result
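
Design note (illustrative): on the autograd.Function path, interpreter.lower()
temporarily pops the current functorch interpreter so the vmap staticmethod
runs one level down. The register_vmap path instead mirrors the C++ batching
rules: it excludes FuncTorchBatched for the duration of the user's rule, so
operators called inside the rule don't re-enter a batching rule at the same
level. The guard is the same one used in the diff above:

    with torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
    ):
        ...  # the user's vmap rule runs here; the key is restored on exit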