From f500cb43bb2242eb14df44c6020a5aae5fe595b3 Mon Sep 17 00:00:00 2001
From: rzou
Date: Thu, 3 Oct 2024 16:58:53 -0700
Subject: [PATCH] Fix torch.library.register_vmap (#137306)

We didn't support multiple levels of vmap. The main problem is that,
during the batching rule, we need to exclude the vmap dispatch key
(FuncTorchBatched), like our C++ batching rules do.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137306
Approved by: https://github.com/Chillee
---
 test/test_custom_ops.py               | 10 ++++++++++
 torch/_functorch/autograd_function.py | 18 +++++++++++++++---
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 30c82717448..cb6ff55f3f4 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -3558,6 +3558,11 @@ Please use `add.register_fake` to add an fake impl.""",
         self.assertTrue(called)
         self.assertEqual(result, x * y)
 
+        x = torch.randn(3)
+        y = torch.randn(3)
+        result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+        self.assertEqual(result, y.unsqueeze(-1) * x)
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_vmap_op_decorator(self):
         @torch.library.custom_op("mylib::f", mutates_args=())
@@ -3584,6 +3589,11 @@ Please use `add.register_fake` to add an fake impl.""",
         self.assertTrue(called)
         self.assertEqual(result, x * y)
 
+        x = torch.randn(3)
+        y = torch.randn(2)
+        result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+        self.assertEqual(result, y.unsqueeze(-1) * x)
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_vmap_register_multiple_times(self):
         @torch.library.custom_op("mylib::f", mutates_args=())
diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py
index cb501e2c924..0d66cb7a50c 100644
--- a/torch/_functorch/autograd_function.py
+++ b/torch/_functorch/autograd_function.py
@@ -325,18 +325,30 @@ def custom_function_call_vmap_helper(
         batch_size=interpreter.batch_size(),
         randomness=interpreter.randomness(),
     )
+    # We're either in the autograd.Function case (vmap staticmethod)
+    # or the torch.library.register_vmap case.
+    autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta)
+
+    def lower_to_next():
+        if autograd_function_case:
+            return interpreter.lower()
+        else:
+            return torch._C._ExcludeDispatchKeyGuard(
+                torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
+            )
+
     unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
     # If none of the tensors are batched at the current level, then we skip the
     # current level. This saves the user from needing to handle this case in
     # their vmap staticmethod (and is consistent with our C++ batching rule API)
     if pytree.tree_all(lambda dim: dim is None, in_dims):
-        with interpreter.lower():
-            if isinstance(op, torch.autograd.function.FunctionMeta):
+        with lower_to_next():
+            if autograd_function_case:
                 return custom_function_call(op, *operands)
             else:
                 return op(*operands, **kwargs)
 
-    with interpreter.lower():
+    with lower_to_next():
         result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
         validate_vmap_returns_tuple_of_two_elements(result)
         unwrapped_output, out_dims = result
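
Note for reviewers: below is a minimal end-to-end sketch of the pattern the new
tests exercise, i.e. a custom op with a register_vmap rule called under two
nested levels of vmap. The op name "mylib::mul" and the rule body are
illustrative assumptions, not part of this patch; only the nested
torch.vmap(torch.vmap(...)) call mirrors the tests above.

    import torch

    # Hypothetical custom op standing in for "mylib::f" from the tests.
    @torch.library.custom_op("mylib::mul", mutates_args=())
    def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    @torch.library.register_vmap("mylib::mul")
    def _(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Give both operands a leading batch dim, then broadcast-multiply.
        x = x.movedim(x_bdim, 0) if x_bdim is not None else x.unsqueeze(0)
        y = y.movedim(y_bdim, 0) if y_bdim is not None else y.unsqueeze(0)
        return x * y, 0

    x, y = torch.randn(3), torch.randn(2)
    # Two nested vmap levels: the outer one batches y, the inner one batches x.
    # Before this patch, this call failed because FuncTorchBatched was not
    # excluded while the register_vmap rule ran.
    out = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
    assert torch.allclose(out, y.unsqueeze(-1) * x)  # out[j, i] == y[j] * x[i]

Per the commit message, lower_to_next() mirrors what the C++ batching rules do:
while a Python batching rule runs, the FuncTorchBatched key is excluded so that
ops on the unwrapped operands do not re-enter the same vmap level, while lower
levels are still handled by the functorch interpreter stack.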