Add NestedTensor ops: logical_not, logical_not_, masked_fill (#97934)

# Summary
### <samp>🤖 Generated by Copilot at 7954302</samp>

This pull request adds support for the `logical_not`, `logical_not_`, and `masked_fill` operations on nested tensors, which batch together constituent tensors of possibly different shapes. It modifies `native_functions.yaml` to dispatch these operations to the nested tensor backend, implements them in `NestedTensorBinaryOps.cpp` and `NestedTensorUnaryOps.cpp`, documents them in `nested.rst`, and adds tests in `test_nestedtensor.py`.

## Description

*  Implement `logical_not` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
  - Add `NestedTensor_logical_not` and `NestedTensor_logical_not_` functions to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172))
  - Define `NestedTensor_logical_not` and `NestedTensor_logical_not_` functions in `NestedTensorUnaryOps.cpp` using `map_nt` and `get_buffer` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56))
  - Document `torch.logical_not` function for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203))
  - Add subtest for `logical_not` function in `test_activations` method in `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
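  To illustrate the behavior the subtest covers: `torch.logical_not` on a nested tensor applies the operation component-wise, matching regular-tensor semantics. This is a hedged sketch (values are illustrative, not taken from the test suite) and assumes a PyTorch build that includes this PR:

  ```python
  import torch

  # Sketch of the component-wise semantics this PR enables; assumes a
  # PyTorch build that includes the NestedTensor_logical_not kernels.
  nt = torch.nested.nested_tensor([
      torch.tensor([True, False]),
      torch.tensor([False, False, True]),
  ])

  out = torch.logical_not(nt)  # negates each constituent tensor

  expected = [torch.tensor([False, True]), torch.tensor([True, True, False])]
  for e, o in zip(expected, out.unbind()):
      assert torch.equal(e, o)
  ```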
* Implement `masked_fill` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
  - Add `NestedTensor_masked_fill` function to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439))
  - Define `NestedTensor_masked_fill` function in `NestedTensorBinaryOps.cpp` using `NestedTensor_elementwise_Tensor` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224))
  - Document `torch.Tensor.masked_fill` function for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197))
  - Add test case for `masked_fill` function in `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688))
  - Add test case for backward pass of `masked_fill` function in `TestNestedTensorAutograd` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
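  For reference, the tests above exercise semantics along these lines: `masked_fill` on a nested tensor fills each component as if it were a regular tensor. This is a hedged sketch with illustrative shapes and values, assuming a PyTorch build that includes this PR:

  ```python
  import torch

  # Illustrative sketch of nested masked_fill: the fill is applied
  # component-wise, as if each constituent tensor were filled on its own.
  # Assumes a PyTorch build that includes this PR's dispatch entries.
  t0 = torch.tensor([[1.0, -2.0], [3.0, -4.0]])
  t1 = torch.tensor([[-1.0, 5.0, -6.0]])

  nt = torch.nested.nested_tensor([t0, t1])
  mask = torch.nested.nested_tensor([t0 < 0, t1 < 0])

  out = nt.masked_fill(mask, 0.0)

  # Reference: the same fill applied to the dense components one by one.
  ref = [t.masked_fill(t < 0, 0.0) for t in (t0, t1)]
  for r, o in zip(ref, out.unbind()):
      assert torch.equal(r, o)
  ```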
* Improve error message for unsupported element-wise binary operations on nested dense tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))
  - Modify `NestedTensor_elementwise_Tensor` function in `NestedTensorBinaryOps.cpp` to include operation name in error message ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97934
Approved by: https://github.com/cpuhrsch
Driss Guessous 2023-03-30 08:14:36 +00:00 committed by PyTorch MergeBot
parent f92cae4849
commit 5a81508bb6
5 changed files with 62 additions and 3 deletions

aten/src/ATen/native/native_functions.yaml

@@ -1161,6 +1161,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: logical_not
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not
   tags: [core, pointwise]

 - func: logical_not_(Tensor(a!) self) -> Tensor(a!)
@@ -1168,6 +1169,7 @@
   variants: method
   dispatch:
     CompositeExplicitAutograd: logical_not_
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not_
   tags: pointwise

 - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@@ -7434,6 +7436,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: masked_fill
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_masked_fill
   tags: pointwise

 - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)

aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp

@@ -139,11 +139,15 @@ Tensor NestedTensor_elementwise_Tensor(
       } else if (op_name == "mul") {
         nested_dense_elementwise_stub(self.device().type(), result, self, other_, NESTED_DENSE_OP::MUL);
       } else {
-        TORCH_CHECK(false, "Unsupported nested dense elementwise op");
+        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
       }
       return result;
     }
-    TORCH_CHECK(false, "Expected both self and other to be nested, but got a nested self and non-nested other.");
+    TORCH_CHECK(
+        false,
+        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
+        op_name,
+        ".");
   }

   NestedTensorImpl* self_impl = nullptr;
@@ -207,6 +211,16 @@ Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
 Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
 }

+Tensor NestedTensor_masked_fill(
+    const Tensor& self,
+    const Tensor& mask,
+    const Scalar& value) {
+  return NestedTensor_elementwise_Tensor(
+      self, mask, "masked_fill", false /* supports_striding*/, [value](const Tensor& b1, const Tensor& b2) {
+        return at::masked_fill(b1, b2, value);
+      });
+}
+
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(

aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp

@@ -42,6 +42,18 @@ Tensor& NestedTensor_sgn_(Tensor& self) {
   return self;
 }

+Tensor& NestedTensor_logical_not_(Tensor& self){
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
+  buffer.logical_not_();
+  return self;
+}
+
+Tensor NestedTensor_logical_not(const Tensor& self) {
+  return map_nt(self, at::logical_not);
+}
+
 Tensor& NestedTensor_relu_(Tensor& self) {
   auto self_ptr = get_nested_tensor_impl(self);
   check_numel_equals_buffer_size(self_ptr);

docs/source/nested.rst

@@ -194,11 +194,13 @@ NestedTensor and any constraints they have.
    :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
    :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0."
    :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
+   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
    :func:`torch.relu`; "Behavior is the same as on regular tensors."
    :func:`torch.gelu`; "Behavior is the same as on regular tensors."
    :func:`torch.silu`; "Behavior is the same as on regular tensors."
    :func:`torch.abs`; "Behavior is the same as on regular tensors."
    :func:`torch.sgn`; "Behavior is the same as on regular tensors."
+   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
    :func:`torch.neg`; "Behavior is the same as on regular tensors."
    :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
    :func:`torch.add`; "Supports elementwise addition of two nested tensors.

test/test_nestedtensor.py

@@ -674,6 +674,19 @@ class TestNestedTensorDeviceType(TestCase):
         for i, inp in enumerate(inputs):
             self.assertEqual(emb(inp), ys[i])

+    @skipMeta
+    @torch.inference_mode()
+    @dtypes(*floating_types_and_half())
+    def test_masked_fill(self, device, dtype):
+        # nested tensor * nested tensor
+        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
+        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
+        ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())])
+        out = nt.masked_fill(mask, 0)
+        self.assertEqual(ref, out)
+
     @dtypes(torch.float, torch.float16)
     def test_to_padded_tensor_simple(self, device, dtype):
         t = torch.randn(4, 4, 4, device=device, dtype=dtype)
@@ -851,7 +864,8 @@
         subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'),
         subtest(torch.abs, name="abs"),
         subtest(torch.abs_, name="abs_"),
-        subtest(torch.sgn, name="sgn")])
+        subtest(torch.sgn, name="sgn"),
+        subtest(torch.logical_not, name='logical_not'),])
     def test_activations(self, device, func):
         nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
         nested_result = func(nt)
@@ -2499,6 +2513,20 @@
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)])
         self.assertEqual(nt.grad, expected_grad)

+    def test_masked_fill_backward(self, device):
+        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            mask = nt.detach().clone().to(bool)
+            out = nt.masked_fill(mask, 0)
+            out = torch.nested.to_padded_tensor(out, 0)
+            return out
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
     def test_gelu_backward(self, device):
         a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
         b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)