diff --git a/docs/source/library.rst b/docs/source/library.rst index c3f78839e90..d67b5497e31 100644 --- a/docs/source/library.rst +++ b/docs/source/library.rst @@ -42,6 +42,7 @@ via PyTorch's C++ operator registration APIs). .. autofunction:: register_kernel .. autofunction:: register_autograd .. autofunction:: register_fake +.. autofunction:: register_vmap .. autofunction:: impl_abstract .. autofunction:: get_ctx .. autofunction:: register_torch_dispatch diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 1f25e6f0f79..3cc61b84a52 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -18,6 +18,7 @@ from torch.testing._internal.autograd_function_db import autograd_function_db from torch.testing._internal.common_device_type import toleranceOverride from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db from torch.testing._internal.common_modules import module_db +from torch.testing._internal.custom_op_db import custom_op_db IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1" @@ -38,8 +39,26 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values): flat_out, out_spec = pytree.tree_flatten(out) outs.append(flat_out) + # use the same out_dim for all outputs + if isinstance(out_dim, int): + flat_out_dim = [out_dim for _ in flat_out] + else: + flat_out_dim, _ = pytree.tree_flatten(out_dim) + outs = zip(*outs) - result = [torch.stack(out_lst) for out_lst in outs] + + result = [] + for i, out_lst in enumerate(outs): + if flat_out_dim[i] is not None: + if not all(isinstance(x, torch.Tensor) for x in out_lst): + raise ValueError( + f"vmap `{op}` must only return " + "Tensors. Did you mean to set out_dims= to None for output?" + ) + result.append(torch.stack(out_lst)) + else: + # not batched over, result should be the same for all batches + result.append(out_lst[0]) return pytree.tree_unflatten(result, out_spec) @@ -317,9 +336,9 @@ def _compute_quantities_for_vmap_test( inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims) outer_in_dims = (0,) + in_dims batched_args, kwarg_values = maybe_clone_inputs() - vmapvmap_output = vmap(vmap(f, inner_in_dims), outer_in_dims)( - dummy, *batched_args, **kwarg_values - ) + vmapvmap_output = vmap( + vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim + )(dummy, *batched_args, **kwarg_values) yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected) @@ -440,7 +459,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None): def skipOps(test_case_name, base_test_name, to_skip): - all_opinfos = op_db + additional_op_db + autograd_function_db + all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db for decorate_meta in to_skip: matching_opinfos = [ o diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index aa0c6c5c578..4a68b394be4 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -1617,6 +1617,28 @@ class TestAutogradFunctionVmapAPI(TestCase): with self.assertRaisesRegex(RuntimeError, "returned an incompatible"): result = vmap(Zeros.apply)(x) + def test_kwarg_only_tensors(self, device): + with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): + + class MyClass(torch.autograd.Function): + @staticmethod + def forward(x, *, y): + return x + y + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, x, *, y): + assert in_dims == (0,) + return x + y, 0 + + x = torch.randn(3) + y = torch.randn(3) + + vmap(MyClass.apply)(x, y=y) + @markDynamoStrictTest class TestVmapOfGrad(TestCase): diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index f43d2da257a..fb10f222052 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -68,6 +68,7 @@ from torch.testing._internal.common_utils import ( unMarkDynamoStrictTest, xfailIfTorchDynamo, ) +from torch.testing._internal.custom_op_db import custom_op_db from torch.utils import _pytree as pytree @@ -3937,10 +3938,17 @@ def discover_variants(opinfo): @unMarkDynamoStrictTest class TestVmapOperatorsOpInfo(TestCase): def vmap_outplace_test( - self, func, args, kwargs, in_dims, check_shape_only=False, postprocess_fn=None + self, + func, + args, + kwargs, + in_dims, + check_shape_only=False, + postprocess_fn=None, + out_dim=0, ): for vmap_out, loop_out in compute_quantities_for_vmap_test( - func, args, kwargs, in_dims + func, args, kwargs, in_dims, out_dim=out_dim ): if postprocess_fn is not None: loop_out = postprocess_fn(loop_out) @@ -3950,7 +3958,9 @@ class TestVmapOperatorsOpInfo(TestCase): continue self.assertEqual(vmap_out, loop_out) - def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None): + def vmap_inplace_test( + self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0 + ): # NB: This test assumes that the first argument is being modified. # This is OK because it's what every other OpInfo-based test assumes, # but it is going to need a more robust solution eventually. @@ -3963,13 +3973,19 @@ class TestVmapOperatorsOpInfo(TestCase): args, kwargs, in_dims, + out_dim=out_dim, compute_loop_out=False, clone_inputs=True, ): pass return for vmap_out, loop_out in compute_quantities_for_vmap_test( - func, args, kwargs, in_dims, clone_inputs=True + func, + args, + kwargs, + in_dims, + clone_inputs=True, + out_dim=out_dim, ): if postprocess_fn is not None: loop_out = postprocess_fn(loop_out) @@ -4027,6 +4043,13 @@ class TestVmapOperatorsOpInfo(TestCase): continue kwargs = sample_input.kwargs is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) + out_dim = 0 + if op.name == "NumpySplitCopyWithIntCustomOp": + # special case for this custom op + def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim): + return [0 for _ in range(len(splits) + 1)], None + + out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args) for batched_args, in_dims, _ in generate_vmap_inputs( args, {}, is_batch_norm_and_training=is_batch_norm_and_training ): @@ -4038,6 +4061,7 @@ class TestVmapOperatorsOpInfo(TestCase): in_dims, check_shape_only, postprocess_fn, + out_dim=out_dim, ) if op.name in skip_inplace: continue @@ -4109,6 +4133,9 @@ class TestVmapOperatorsOpInfo(TestCase): "linalg.eigh", "" ), # not always return the same result for the same input, see test_linalg_eigh for manual test skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format + # UnimplementedError: data-dependent operators cannot be vmapped + xfail("NumpyNonzeroCustomOp"), + xfail("NumpyNMSCustomOp"), # ---------------------------------------------------------------------- # ---------------------------- BUGS ------------------------------------ # entries in here don't work and need to be fixed. @@ -4187,7 +4214,10 @@ class TestVmapOperatorsOpInfo(TestCase): } @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 - @ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one) + @ops( + op_db + additional_op_db + autograd_function_db + custom_op_db, + dtypes=OpDTypes.any_one, + ) @opsToleranceOverride( "TestVmapOperatorsOpInfo", "test_vmap_exhaustive", @@ -4248,7 +4278,10 @@ class TestVmapOperatorsOpInfo(TestCase): ) @with_tf32_off - @ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one) + @ops( + op_db + additional_op_db + autograd_function_db + custom_op_db, + dtypes=OpDTypes.any_one, + ) @opsToleranceOverride( "TestVmapOperatorsOpInfo", "test_op_has_batch_rule", diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 48057d243a4..19e1b241b71 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -2327,6 +2327,12 @@ class TestCustomOpAPI(TestCase): setup_context=lambda ctx, inputs, keyword_only_inputs, output: None, ) + with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): + torch.library.register_vmap( + "_torch_testing::foo", + lambda info, in_dims, x, *, y: (x, 0), + ) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_autograd_kwargonly_low_level(self): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: @@ -3382,6 +3388,246 @@ Please use `add.register_fake` to add an fake impl.""", with f.set_kernel_enabled("cpu", enabled=False): self.assertEqual(f(x), x + 1) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_register_vmap_kwargonly_low_level(self): + with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: + lib.define("foo(Tensor x, *, float y) -> Tensor") + called = False + + def foo_impl(x, *, y): + return x * y + + lib.impl("foo", foo_impl, "CPU") + + def vmap(info, in_dims, x, *, y): + nonlocal called + called = True + return x * y, 0 + + torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) + + x = torch.ones(3) + result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14) + self.assertTrue(called) + self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14])) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_register_vmap_defaults(self): + with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: + lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") + + def foo_impl(w, x=2, *, y=3, z): + return w * x * y * z + + lib.impl("foo", foo_impl, "CPU") + + called = False + + def vmap(info, in_dims, w, x=2, *, y=3, z): + nonlocal called + called = True + return w * x * y * z, 0 + + torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) + + w = torch.ones(3) + result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42) + self.assertTrue(called) + self.assertEqual(result, w * 2 * 3 * 42) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_library_register_vmap(self): + for mode in ["function", "qualname", "opoverload", "c_opdef"]: + + @torch.library.custom_op("mylib::f", mutates_args=()) + def f(x: Tensor, y: Tensor) -> Tensor: + return x * y + + called = False + + def fvmap(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + + if mode == "function": + torch.library.register_vmap( + f, + fvmap, + ) + elif mode == "qualname": + torch.library.register_vmap( + "mylib::f", + fvmap, + ) + elif mode == "opoverload": + torch.library.register_vmap( + torch.ops.mylib.f.default, + fvmap, + ) + elif mode == "c_opdef": + f.register_vmap( + fvmap, + ) + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x * y) + + called = False + result = torch.vmap(f, out_dims=1)(x, y) + self.assertEqual(result, (x * y).T) + self.assertTrue(called) + + called = False + result = torch.vmap(f, in_dims=1)(x, y) + self.assertEqual(result, (x * y).T) + self.assertTrue(called) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_library_register_vmap_library_decorator(self): + @torch.library.custom_op("mylib::f", mutates_args=()) + def f(x: Tensor, y: Tensor) -> Tensor: + return x * y + + called = False + + @torch.library.register_vmap("mylib::f") + def fvmap(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x * y) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_library_register_vmap_op_decorator(self): + @torch.library.custom_op("mylib::f", mutates_args=()) + def f(x: Tensor, y: Tensor) -> Tensor: + return x * y + + called = False + + @f.register_vmap + def fvmap(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x * y) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_library_register_vmap_register_multiple_times(self): + @torch.library.custom_op("mylib::f", mutates_args=()) + def f(x: Tensor, y: Tensor) -> Tensor: + return x * y + + called = False + + @f.register_vmap + def fvmap(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x * y) + called = False + + @f.register_vmap + def fvmap2(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x + y + result = result.movedim(-1, 0) + return result, 0 + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x + y) + + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_library_register_vmap_register_multiple_times_2(self): + @torch.library.custom_op("mylib::f", mutates_args=()) + def f(x: Tensor, y: Tensor) -> Tensor: + return x * y + + called = False + + @torch.library.register_vmap("mylib::f") + def fvmap(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x * y) + called = False + + @torch.library.register_vmap("mylib::f") + def fvmap2(info, in_dims, x, y): + nonlocal called + called = True + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x + y + result = result.movedim(-1, 0) + return result, 0 + + result = torch.vmap(f)(x, y) + self.assertTrue(called) + self.assertEqual(result, x + y) + class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index b827fb20424..f80b7dee55b 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -274,7 +274,17 @@ def validate_vmap_returns_tuple_of_two_elements(result): @custom_function_call.py_impl(TransformType.Vmap) -def custom_function_call_vmap(interpreter, autograd_function, *operands): +def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs): + if any( + isinstance(val, torch.Tensor) + for val in torch.utils._pytree.tree_flatten(kwargs)[0] + ): + raise NotImplementedError( + f"Run vmap on autograd.Function with kwarg-only Tensor args. " + f"Please do not pass kwarg-only Tensors to autograd.Function. " + f"Got: {kwargs}" + ) + if autograd_function.generate_vmap_rule: if has_overriden_vmap_rule(autograd_function): # TODO: Update link to stable once that's out @@ -302,22 +312,32 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands): f"https://pytorch.org/docs/main/notes/extending.func.html" ) + return custom_function_call_vmap_helper( + interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs + ) + + +def custom_function_call_vmap_helper( + interpreter, vmap_function, op, *operands, **kwargs +): current_level = interpreter.level() info = VmapInfo( batch_size=interpreter.batch_size(), randomness=interpreter.randomness(), ) unwrapped_operands, in_dims = unwrap_batched(operands, current_level) - # If none of the tensors are batched at the current level, then we skip the # current level. This saves the user from needing to handle this case in # their vmap staticmethod (and is consistent with our C++ batching rule API) if pytree.tree_all(lambda dim: dim is None, in_dims): with interpreter.lower(): - return custom_function_call(autograd_function, *operands) + if isinstance(op, torch.autograd.function.FunctionMeta): + return custom_function_call(op, *operands) + else: + return op(*operands, **kwargs) with interpreter.lower(): - result = autograd_function.vmap(info, in_dims, *unwrapped_operands) + result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs) validate_vmap_returns_tuple_of_two_elements(result) unwrapped_output, out_dims = result diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 153ce907185..57f813e52d2 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -180,6 +180,7 @@ class CustomOpDef: self._setup_context_fn: Optional[Callable] = None self._backward_fn: Optional[Callable] = None self._torch_dispatch_fns: Dict[type, Callable] = {} + self._vmap_fn: Optional[Callable] = None self._lib = get_library_allowing_overwrite(self._namespace, self._name) self._register_to_dispatcher() @@ -662,6 +663,103 @@ class CustomOpDef: def __call__(self, *args, **kwargs): return self._opoverload(*args, **kwargs) + def register_vmap( + self, + func: Optional[Callable] = None, + ): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator. + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> numpy_cube.register_vmap(numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @numpy_mul.register_vmap + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + """ + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def register(func): + need_register = self._vmap_fn is None + self._vmap_fn = func + + if need_register: + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, self._vmap_fn, self._opoverload, *args, **kwargs + ) + + self._lib.impl( + self._name, wrapped_func, "FuncTorchBatched", with_keyset=True + ) + + if func is None: + return register + else: + return register(func) + # NOTE: [Supporting decorator and non-decorator usage] # diff --git a/torch/library.py b/torch/library.py index 3dc64a6eb25..73f71f7eb95 100644 --- a/torch/library.py +++ b/torch/library.py @@ -954,6 +954,136 @@ def register_torch_dispatch( return register(func) +def register_vmap( + op: _op_identifier, + func: Optional[Callable] = None, + /, + *, + lib=None, +): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator (see examples). + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + We do not support kwarg-only Tensor args. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @torch.library.register_vmap("mylib::numpy_mul") + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + + .. note:: + The vmap function should aim to preserve the semantics of the entire custom operator. + That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``. + + If your custom operator has any custom behavior in the backward pass, please + keep this in mind. + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}") + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_vmap(func) + assert isinstance(op, str) + qualname = op + op = torch._library.utils.lookup_op(qualname) + schema = op._schema + if _library.utils.has_kwarg_only_tensors(schema): + raise NotImplementedError( + f"register_vmap with kwarg-only Tensor args. In the original " + f"definition of the op, please make your tensors not kwarg-only. " + f"Got: {schema}" + ) + + def register(func): + nonlocal op, lib + + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, func, op, *args, **kwargs + ) + + lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True) + + if func is None: + return register + else: + return register(func) + + # If the op was defined in C++, then we want to make sure there was an # m.set_python_module(module, ...) call and that the module is the # same as the module that called torch.library.register_fake. diff --git a/torch/testing/_internal/custom_op_db.py b/torch/testing/_internal/custom_op_db.py index 71a2a8f1065..a2c0b439e7c 100644 --- a/torch/testing/_internal/custom_op_db.py +++ b/torch/testing/_internal/custom_op_db.py @@ -50,6 +50,12 @@ def numpy_cube_backward(ctx, grad_out, grad_dx): numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context) +def numpy_cube_vmap(info, in_dims, x): + result = numpy_cube(x) + return result, (in_dims[0], in_dims[0]) + +numpy_cube.register_vmap(numpy_cube_vmap) + @torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=()) def numpy_mul(x: Tensor, y: Tensor) -> Tensor: return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) @@ -70,6 +76,16 @@ def numpy_mul_backward(ctx, grad_out): numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context) +def numpy_mul_vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = x * y + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul.register_vmap(numpy_mul_vmap) + @torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=()) def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor: return torch.tensor(to_numpy(x) * scalar, device=x.device) @@ -87,6 +103,15 @@ def numpy_mul_scalar_backward(ctx, grad_out): numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context) +def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar): + x_bdim, = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + result = x * scalar + result = result.movedim(-1, 0) + return result, 0 + +numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap) + @torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=()) def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]: device = x.device @@ -116,6 +141,14 @@ def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv): numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context) +def numpy_sort_vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + dim = dim if dim >= 0 else dim + x.dim() - 1 + result = numpy_sort(x, dim + 1) + return result, (0, 0, 0) + +numpy_sort.register_vmap(numpy_sort_vmap) @torch.library.custom_op("_torch_testing::numpy_take", mutates_args=()) def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor: @@ -144,6 +177,26 @@ def numpy_take_backward(ctx, grad_out): numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context) +def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return numpy_take(x, ind, ind_inv, dim + 1), 0 + +numpy_take.register_vmap(numpy_take_vmap) + @torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=()) def numpy_nonzero(x: Tensor) -> Tensor: x_np = to_numpy(x) @@ -170,6 +223,11 @@ def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs): yield SampleInput(result, args=()) +def numpy_nonzero_vmap(info, in_dims, x): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nonzero.register_vmap(numpy_nonzero_vmap) + @torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=()) def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor: return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device) @@ -186,6 +244,16 @@ def numpy_view_copy_backward(ctx, grad_out): numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context) +def numpy_view_copy_vmap(info, in_dims, x, shape): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + x_shape = x.shape[0] + batch_shape = (x_shape, *shape) + result = numpy_view_copy(x, batch_shape) + return result, 0 + +numpy_view_copy.register_vmap(numpy_view_copy_vmap) + def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) result = make_arg(2, 3, 4, low=0.9, high=2) @@ -222,6 +290,13 @@ def numpy_cat_backward(ctx, grad_out): numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context) +def numpy_cat_vmap(info, in_dims, x, dim): + x_bdim, = in_dims + result = numpy_cat(x, dim) + return result, x_bdim + +numpy_cat.register_vmap(numpy_cat_vmap) + def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) r0 = make_arg(2, 3, 4, low=0.9, high=2) @@ -249,6 +324,14 @@ def numpy_split_copy_backward(ctx, grad_out): numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context) +def numpy_split_copy_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result = numpy_split_copy(x, splits, dim + 1) + return result, 0 + +numpy_split_copy.register_vmap(numpy_split_copy_vmap) + def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) x = make_arg(2, 9, low=0.9, high=2) @@ -275,6 +358,14 @@ numpy_split_copy_with_int.register_autograd( numpy_split_copy_with_int_backward, setup_context=numpy_split_copy_with_int_setup_context) +def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim): + x_bdim, _ , _ = in_dims + x = x.movedim(x_bdim, 0) + result, len_split = numpy_split_copy_with_int(x, splits, dim + 1) + return (result, len_split), ([0 for _ in range(len(result))], None) + +numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap) + @torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=()) def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor: # Adapted from Ross Girshick's fast-rcnn implementation at @@ -331,6 +422,11 @@ def _(boxes, scores, iou_threshold): result = boxes.new_empty([i0], dtype=torch.int64) return result +def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold): + raise NotImplementedError("Operator is data-dependent and cannot be vmapped.") + +numpy_nms.register_vmap(numpy_nms_vmap) + def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial(make_tensor, device=device, dtype=dtype) N = 64