mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix torch.library.register_vmap (#137306)
We didn't support multiple levels of vmap. The main problem is, during the batching rule, we need to exclude the vmap dispatch key (FuncTorchBatched) like how our C++ batching rules do it. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/137306 Approved by: https://github.com/Chillee
This commit is contained in:
parent
cfc51c858a
commit
f500cb43bb
2 changed files with 25 additions and 3 deletions
|
|
@ -3558,6 +3558,11 @@ Please use `add.register_fake` to add an fake impl.""",
|
|||
self.assertTrue(called)
|
||||
self.assertEqual(result, x * y)
|
||||
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3)
|
||||
result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
|
||||
self.assertEqual(result, y.unsqueeze(-1) * x)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_op_decorator(self):
|
||||
@torch.library.custom_op("mylib::f", mutates_args=())
|
||||
|
|
@ -3584,6 +3589,11 @@ Please use `add.register_fake` to add an fake impl.""",
|
|||
self.assertTrue(called)
|
||||
self.assertEqual(result, x * y)
|
||||
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(2)
|
||||
result = torch.vmap(torch.vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
|
||||
self.assertEqual(result, y.unsqueeze(-1) * x)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_register_multiple_times(self):
|
||||
@torch.library.custom_op("mylib::f", mutates_args=())
|
||||
|
|
|
|||
|
|
@ -325,18 +325,30 @@ def custom_function_call_vmap_helper(
|
|||
batch_size=interpreter.batch_size(),
|
||||
randomness=interpreter.randomness(),
|
||||
)
|
||||
# We're either in the autograd.Function case (vmap staticmethod)
|
||||
# or the torch.library.register_vmap case.
|
||||
autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta)
|
||||
|
||||
def lower_to_next():
|
||||
if autograd_function_case:
|
||||
return interpreter.lower()
|
||||
else:
|
||||
return torch._C._ExcludeDispatchKeyGuard(
|
||||
torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
|
||||
)
|
||||
|
||||
unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
|
||||
# If none of the tensors are batched at the current level, then we skip the
|
||||
# current level. This saves the user from needing to handle this case in
|
||||
# their vmap staticmethod (and is consistent with our C++ batching rule API)
|
||||
if pytree.tree_all(lambda dim: dim is None, in_dims):
|
||||
with interpreter.lower():
|
||||
if isinstance(op, torch.autograd.function.FunctionMeta):
|
||||
with lower_to_next():
|
||||
if autograd_function_case:
|
||||
return custom_function_call(op, *operands)
|
||||
else:
|
||||
return op(*operands, **kwargs)
|
||||
|
||||
with interpreter.lower():
|
||||
with lower_to_next():
|
||||
result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
|
||||
validate_vmap_returns_tuple_of_two_elements(result)
|
||||
unwrapped_output, out_dims = result
|
||||
|
|
|
|||
Loading…
Reference in a new issue