From 5a81508bb6d265197d4c29bc4f2be546dc0662ab Mon Sep 17 00:00:00 2001
From: Driss Guessous
Date: Thu, 30 Mar 2023 08:14:36 +0000
Subject: [PATCH] Add NestedTensor ops: logical_not, logical_not_, masked_fill
 (#97934)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Summary
### 🤖 Generated by Copilot at 7954302

This pull request adds support for the `logical_not` and `masked_fill` operations on nested tensors, i.e. tensors whose entries are themselves tensors of possibly different shapes. It modifies `native_functions.yaml` to dispatch these operations to the nested tensor backend, implements them in `NestedTensorBinaryOps.cpp` and `NestedTensorUnaryOps.cpp`, documents them in `nested.rst`, and adds tests in `test_nestedtensor.py`.

## Description
### 🤖 Generated by Copilot at 7954302

* Implement the `logical_not` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
  - Add `NestedTensor_logical_not` and `NestedTensor_logical_not_` entries to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172))
  - Define the `NestedTensor_logical_not` and `NestedTensor_logical_not_` functions in `NestedTensorUnaryOps.cpp` using `map_nt` and `get_buffer` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56))
  - Document `torch.logical_not` for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203))
  - Add a `logical_not` subtest to the `test_activations` method of the `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
* Implement the `masked_fill` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197),
  [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
  - Add a `NestedTensor_masked_fill` entry to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439))
  - Define the `NestedTensor_masked_fill` function in `NestedTensorBinaryOps.cpp` using `NestedTensor_elementwise_Tensor` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224))
  - Document `torch.Tensor.masked_fill` for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197))
  - Add a test case for `masked_fill` to the `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688))
  - Add a test case for the backward pass of `masked_fill` to the `TestNestedTensorAutograd` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
* Improve the error message for unsupported elementwise binary operations on nested/dense tensor pairs ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))
  - Modify the `NestedTensor_elementwise_Tensor` function in `NestedTensorBinaryOps.cpp` to include the operation name in the error message ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97934
Approved by: https://github.com/cpuhrsch
---
 aten/src/ATen/native/native_functions.yaml   |  3 ++
 .../native/nested/NestedTensorBinaryOps.cpp  | 18 +++++++++--
 .../native/nested/NestedTensorUnaryOps.cpp   | 12 ++++++++
 docs/source/nested.rst                       |  2 ++
 test/test_nestedtensor.py                    | 30 ++++++++++++++++++-
 5 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a8eae570c8b..893715750e5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1161,6 +1161,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: logical_not
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not
   tags: [core, pointwise]
 
 - func: logical_not_(Tensor(a!) self) -> Tensor(a!)
@@ -1168,6 +1169,7 @@
   variants: method
   dispatch:
     CompositeExplicitAutograd: logical_not_
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not_
   tags: pointwise
 
 - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@@ -7434,6 +7436,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: masked_fill
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_masked_fill
   tags: pointwise
 
 - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
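To make the dispatch registrations above concrete: the `NestedTensorCPU, NestedTensorCUDA` entries are what route a plain `torch.logical_not` call on a nested tensor to the new kernels; without them the dispatcher has no nested implementation to select. A minimal usage sketch (illustrative only, assuming a build that contains this patch):

```python
import torch

# A boolean nested tensor with two components of different lengths.
nt = torch.nested.nested_tensor([torch.tensor([True, False]),
                                 torch.tensor([False, False, True])])

# Dispatches to NestedTensor_logical_not via the new entries above.
inverted = torch.logical_not(nt)

# In-place variant; dispatches to NestedTensor_logical_not_.
nt.logical_not_()
```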
diff --git a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp
index 6f2b36c6bea..bf1dfe6aedf 100644
--- a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp
@@ -139,11 +139,15 @@ Tensor NestedTensor_elementwise_Tensor(
       } else if (op_name == "mul") {
         nested_dense_elementwise_stub(self.device().type(), result, self, other_, NESTED_DENSE_OP::MUL);
       } else {
-        TORCH_CHECK(false, "Unsupported nested dense elementwise op");
+        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
       }
       return result;
     }
-    TORCH_CHECK(false, "Expected both self and other to be nested, but got a nested self and non-nested other.");
+    TORCH_CHECK(
+        false,
+        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
+        op_name,
+        ".");
   }
 
   NestedTensorImpl* self_impl = nullptr;
@@ -207,6 +211,16 @@ Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
 Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
 }
 
+Tensor NestedTensor_masked_fill(
+    const Tensor& self,
+    const Tensor& mask,
+    const Scalar& value) {
+  return NestedTensor_elementwise_Tensor(
+      self, mask, "masked_fill", false /* supports_striding*/, [value](const Tensor& b1, const Tensor& b2) {
+        return at::masked_fill(b1, b2, value);
+      });
+}
+
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(
diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
index 662da4a183f..9cfc53b60a2 100644
--- a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
@@ -42,6 +42,18 @@ Tensor& NestedTensor_sgn_(Tensor& self) {
   return self;
 }
 
+Tensor& NestedTensor_logical_not_(Tensor& self){
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
+  buffer.logical_not_();
+  return self;
+}
+
+Tensor NestedTensor_logical_not(const Tensor& self) {
+  return map_nt(self, at::logical_not);
+}
+
 Tensor& NestedTensor_relu_(Tensor& self) {
   auto self_ptr = get_nested_tensor_impl(self);
   check_numel_equals_buffer_size(self_ptr);
diff --git a/docs/source/nested.rst b/docs/source/nested.rst
index c0874f3b670..779f3c6ac00 100644
--- a/docs/source/nested.rst
+++ b/docs/source/nested.rst
@@ -194,11 +194,13 @@ NestedTensor and any constraints they have.
    :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
    :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0."
    :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
+   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
    :func:`torch.relu`; "Behavior is the same as on regular tensors."
    :func:`torch.gelu`; "Behavior is the same as on regular tensors."
    :func:`torch.silu`; "Behavior is the same as on regular tensors."
    :func:`torch.abs`; "Behavior is the same as on regular tensors."
    :func:`torch.sgn`; "Behavior is the same as on regular tensors."
+   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
    :func:`torch.neg`; "Behavior is the same as on regular tensors."
    :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
    :func:`torch.add`; "Supports elementwise addition of two nested tensors.
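Implementation-wise, the kernels above follow the usual nested-tensor pattern: the out-of-place `logical_not` delegates to `map_nt`, which applies the op to the underlying contiguous buffer and rewraps the result with the same nested sizes, while the in-place variant mutates the buffer directly after `check_numel_equals_buffer_size` rules out buffers with extra storage; `masked_fill` reuses the generic `NestedTensor_elementwise_Tensor` path, which pairs up the components of `self` and `mask`. A rough Python analogue of the unary pattern, for intuition only (the helper name is hypothetical, not the actual C++ code path):

```python
import torch

def map_nt_sketch(nt, op):
    # For a pointwise op, mapping over the components is equivalent to
    # applying the op once to the flat buffer, since shapes are unchanged.
    return torch.nested.nested_tensor([op(t) for t in nt.unbind()])

nt = torch.nested.nested_tensor([torch.zeros(2, 3), torch.ones(4, 3)])
print(map_nt_sketch(nt, torch.logical_not))
```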
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 1e56eb67075..8e69b1c104b 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -674,6 +674,19 @@ class TestNestedTensorDeviceType(TestCase):
         for i, inp in enumerate(inputs):
             self.assertEqual(emb(inp), ys[i])
 
+
+    @skipMeta
+    @torch.inference_mode()
+    @dtypes(*floating_types_and_half())
+    def test_masked_fill(self, device, dtype):
+        # nested tensor * nested tensor
+        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
+        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
+        ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())])
+        out = nt.masked_fill(mask, 0)
+        self.assertEqual(ref, out)
+
+
     @dtypes(torch.float, torch.float16)
     def test_to_padded_tensor_simple(self, device, dtype):
         t = torch.randn(4, 4, 4, device=device, dtype=dtype)
@@ -851,7 +864,8 @@
                           subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'),
                           subtest(torch.abs, name="abs"),
                           subtest(torch.abs_, name="abs_"),
-                          subtest(torch.sgn, name="sgn")])
+                          subtest(torch.sgn, name="sgn"),
+                          subtest(torch.logical_not, name='logical_not'),])
     def test_activations(self, device, func):
         nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
         nested_result = func(nt)
@@ -2499,6 +2513,20 @@
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)])
         self.assertEqual(nt.grad, expected_grad)
 
+    def test_masked_fill_backward(self, device):
+        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            mask = nt.detach().clone().to(bool)
+            out = nt.masked_fill(mask, 0)
+            out = torch.nested.to_padded_tensor(out, 0)
+            return out
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
     def test_gelu_backward(self, device):
         a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
         b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
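End to end, the new `masked_fill` behaves like a per-component `masked_fill`, which is what the tests above assert. A small usage sketch mirroring them (illustrative only, assuming a build that contains this patch):

```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

# The mask must be a boolean nested tensor with the same nested structure.
mask = torch.nested.nested_tensor([t < 0 for t in nt.unbind()])

out = nt.masked_fill(mask, 0.0)

# Reference: apply masked_fill component by component.
ref = torch.nested.nested_tensor(
    [t.masked_fill(m, 0.0) for t, m in zip(nt.unbind(), mask.unbind())])
for o, r in zip(out.unbind(), ref.unbind()):
    assert torch.equal(o, r)
```

The backward pass is exercised separately by the new `test_masked_fill_backward` test through `gradcheck`.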