Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
Add NestedTensor ops: logical_not, logical_not_, masked_fill (#97934)
# Summary

<!-- copilot:summary -->
### <samp>🤖 Generated by Copilot at 7954302</samp>

This pull request adds support for `logical_not` and `masked_fill` operations on nested tensors, which are tensors that can have tensors as elements. It modifies the `native_functions.yaml` file to dispatch these operations to the nested tensor backend, implements the logic for these operations in `NestedTensorBinaryOps.cpp` and `NestedTensorUnaryOps.cpp`, adds documentation in `nested.rst`, and adds tests in `test_nestedtensor.py`.

## Description

<!-- copilot:walkthrough -->
### <samp>🤖 Generated by Copilot at 7954302</samp>

* Implement the `logical_not` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
  - Add `NestedTensor_logical_not` and `NestedTensor_logical_not_` entries to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1164), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R1172))
  - Define `NestedTensor_logical_not` and `NestedTensor_logical_not_` in `NestedTensorUnaryOps.cpp` using `map_nt` and `get_buffer` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f7c94671810b3ce652f9ad5458518cb7bbd67e8bf7e84e0a2fba641d878ba7c5R45-R56))
  - Document `torch.logical_not` for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR203))
  - Add a subtest for `logical_not` to the `test_activations` method of the `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0L854-R867))
* Implement the `masked_fill` operation on nested tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688), [link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
  - Add a `NestedTensor_masked_fill` entry to `native_functions.yaml` for CPU and CUDA dispatch ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991R7439))
  - Define `NestedTensor_masked_fill` in `NestedTensorBinaryOps.cpp` using `NestedTensor_elementwise_Tensor` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L210-R224))
  - Document `torch.Tensor.masked_fill` for nested tensors in `nested.rst` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-c8b131d009badb3f92031b2aaa6e7f93a793f13caee278ea78e1c57d78c0399eR197))
  - Add a test case for `masked_fill` to the `TestNestedTensorDeviceType` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R677-R688))
  - Add a test case for the backward pass of `masked_fill` to the `TestNestedTensorAutograd` class in `test_nestedtensor.py` ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-6eef496a8ec635930b6e52507358e069c80021f3535b8737d39e14ffc38950c0R2515-R2528))
* Improve the error message for unsupported elementwise binary operations on nested dense tensors ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))
  - Modify `NestedTensor_elementwise_Tensor` in `NestedTensorBinaryOps.cpp` to include the operation name in the error message ([link](https://github.com/pytorch/pytorch/pull/97934/files?diff=unified&w=0#diff-f847e41e3d373230df0b25574e993ec0e6b699bf16796b3df9ae9fb518048e25L142-R150))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97934
Approved by: https://github.com/cpuhrsch
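To make the semantics of the two new ops concrete, here is a minimal pure-Python model of their per-component behavior, with a nested tensor modeled as a plain list of flat lists. This is an illustration only, not the ATen implementation.

```python
# Pure-Python model of the per-component semantics of the two new ops.
# A "nested tensor" is modeled as a list of flat lists (one per component);
# real nested tensors hold ragged torch.Tensor constituents.

def logical_not(component):
    """Elementwise logical NOT over one flat component."""
    return [not bool(x) for x in component]

def masked_fill(component, mask, value):
    """Replace entries of `component` where `mask` is True with `value`."""
    return [value if m else x for x, m in zip(component, mask)]

# Nested ops apply component-by-component over irregularly sized pieces.
nested = [[1, 0, 2], [0, 0]]
nested_mask = [[True, False, True], [False, True]]

filled = [masked_fill(t, m, -1) for t, m in zip(nested, nested_mask)]
inverted = [logical_not(t) for t in nested]
```

With the real API, the equivalent calls would be `nt.masked_fill(mask_nt, -1)` and `torch.logical_not(nt)` on `torch.nested` tensors.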
This commit is contained in:
parent
f92cae4849
commit
5a81508bb6
5 changed files with 62 additions and 3 deletions
`native_functions.yaml`

```diff
@@ -1161,6 +1161,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: logical_not
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not
   tags: [core, pointwise]
 
 - func: logical_not_(Tensor(a!) self) -> Tensor(a!)
@@ -1168,6 +1169,7 @@
   variants: method
   dispatch:
     CompositeExplicitAutograd: logical_not_
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not_
  tags: pointwise
 
 - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@@ -7434,6 +7436,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: masked_fill
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_masked_fill
   tags: pointwise
 
 - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
```
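Each YAML `dispatch:` block maps backend dispatch keys to kernel names, so the same op routes to different kernels depending on the input's backend. A toy model of that routing, purely for illustration (the key and kernel names mirror the entries above, but the mechanism is PyTorch's C++ dispatcher, not a dict):

```python
# Toy dispatch table: op name -> dispatch key -> kernel.
# Dense inputs take the CompositeExplicitAutograd kernel; the new entries
# route nested inputs (here: lists of lists) to a nested-aware kernel.
KERNELS = {
    "logical_not": {
        "CompositeExplicitAutograd": lambda xs: [not bool(x) for x in xs],
        "NestedTensorCPU": lambda nts: [[not bool(x) for x in t] for t in nts],
    },
}

def dispatch(op_name, key, arg):
    """Look up and invoke the kernel registered for (op, backend key)."""
    return KERNELS[op_name][key](arg)

dense_out = dispatch("logical_not", "CompositeExplicitAutograd", [1, 0])
nested_out = dispatch("logical_not", "NestedTensorCPU", [[1], [0, 2]])
```

Without the new `NestedTensorCPU, NestedTensorCUDA` entries, a nested input would have no registered kernel and the call would fail at dispatch time.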
`NestedTensorBinaryOps.cpp`

```diff
@@ -139,11 +139,15 @@ Tensor NestedTensor_elementwise_Tensor(
     } else if (op_name == "mul") {
       nested_dense_elementwise_stub(self.device().type(), result, self, other_, NESTED_DENSE_OP::MUL);
     } else {
-      TORCH_CHECK(false, "Unsupported nested dense elementwise op");
+      TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
     }
     return result;
   }
-  TORCH_CHECK(false, "Expected both self and other to be nested, but got a nested self and non-nested other.");
+  TORCH_CHECK(
+      false,
+      "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
+      op_name,
+      ".");
 }
 
 NestedTensorImpl* self_impl = nullptr;
@@ -207,6 +211,16 @@ Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
 Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
 }
+
+Tensor NestedTensor_masked_fill(
+    const Tensor& self,
+    const Tensor& mask,
+    const Scalar& value) {
+  return NestedTensor_elementwise_Tensor(
+      self, mask, "masked_fill", false /* supports_striding */, [value](const Tensor& b1, const Tensor& b2) {
+        return at::masked_fill(b1, b2, value);
+      });
+}
 
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(
```
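The pattern in this file is worth noting: `masked_fill` is not written from scratch but expressed as a generic per-component elementwise op, with `op_name` threaded through so error messages can name the failing operation. A pure-Python sketch of that structure (the helper and its checks are simplified stand-ins, not the C++ logic):

```python
# Generic elementwise helper, loosely modeled on
# NestedTensor_elementwise_Tensor: validates inputs, names the op in
# errors, and maps a binary function over paired components.

def elementwise(self_nt, other_nt, op_name, fn):
    if len(self_nt) != len(other_nt):
        # The op name is included in diagnostics, mirroring the PR's
        # improved error messages.
        raise ValueError(f"Component count mismatch for op: {op_name}.")
    return [fn(a, b) for a, b in zip(self_nt, other_nt)]

def masked_fill(self_nt, mask_nt, value):
    # masked_fill is just elementwise() with a fill lambda, like the
    # C++ version delegating to at::masked_fill per component.
    return elementwise(
        self_nt, mask_nt, "masked_fill",
        lambda t, m: [value if mm else x for x, mm in zip(t, m)])

out = masked_fill([[1, 2], [3]], [[True, False], [True]], 0)
```

Reusing one helper keeps validation and error reporting consistent across all nested binary ops.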
`NestedTensorUnaryOps.cpp`

```diff
@@ -42,6 +42,18 @@ Tensor& NestedTensor_sgn_(Tensor& self) {
   return self;
 }
 
+Tensor& NestedTensor_logical_not_(Tensor& self) {
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
+  buffer.logical_not_();
+  return self;
+}
+
+Tensor NestedTensor_logical_not(const Tensor& self) {
+  return map_nt(self, at::logical_not);
+}
+
 Tensor& NestedTensor_relu_(Tensor& self) {
   auto self_ptr = get_nested_tensor_impl(self);
   check_numel_equals_buffer_size(self_ptr);
```
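Both variants exploit the nested tensor's storage layout: all components share one contiguous buffer, so an elementwise unary op can act on that buffer directly (in place) or via a mapped copy (`map_nt`). A toy model of the buffer-based in-place path, for illustration only:

```python
# Toy nested tensor: one flat buffer plus per-component lengths, loosely
# modeling the layout that lets NestedTensor_logical_not_ mutate the
# buffer in place. Not the ATen internals.

class ToyNested:
    def __init__(self, components):
        self.buffer = [x for c in components for x in c]  # flat storage
        self.lengths = [len(c) for c in components]       # ragged shape

    def unbind(self):
        """Recover the per-component views from buffer + lengths."""
        out, i = [], 0
        for n in self.lengths:
            out.append(self.buffer[i:i + n])
            i += n
        return out

def logical_not_(nt):
    # In-place variant: rewrite the flat buffer; the ragged shape
    # metadata is untouched, so all components update at once.
    nt.buffer = [not bool(x) for x in nt.buffer]
    return nt

nt = ToyNested([[1, 0], [2]])
logical_not_(nt)
```

The `check_numel_equals_buffer_size` guard in the real code ensures the buffer is not larger than the logical elements (e.g. no padding/striding), which is what makes the whole-buffer shortcut valid.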
`nested.rst`

```diff
@@ -194,11 +194,13 @@ NestedTensor and any constraints they have.
    :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
    :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0."
    :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
+   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
    :func:`torch.relu`; "Behavior is the same as on regular tensors."
    :func:`torch.gelu`; "Behavior is the same as on regular tensors."
    :func:`torch.silu`; "Behavior is the same as on regular tensors."
    :func:`torch.abs`; "Behavior is the same as on regular tensors."
    :func:`torch.sgn`; "Behavior is the same as on regular tensors."
+   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
    :func:`torch.neg`; "Behavior is the same as on regular tensors."
    :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
    :func:`torch.add`; "Supports elementwise addition of two nested tensors.
```
`test_nestedtensor.py`

```diff
@@ -674,6 +674,19 @@ class TestNestedTensorDeviceType(TestCase):
         for i, inp in enumerate(inputs):
             self.assertEqual(emb(inp), ys[i])
 
+
+    @skipMeta
+    @torch.inference_mode()
+    @dtypes(*floating_types_and_half())
+    def test_masked_fill(self, device, dtype):
+        # nested tensor * nested tensor
+        (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
+        mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
+        ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())])
+        out = nt.masked_fill(mask, 0)
+        self.assertEqual(ref, out)
+
+
     @dtypes(torch.float, torch.float16)
     def test_to_padded_tensor_simple(self, device, dtype):
         t = torch.randn(4, 4, 4, device=device, dtype=dtype)
@@ -851,7 +864,8 @@
         subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'),
         subtest(torch.abs, name="abs"),
         subtest(torch.abs_, name="abs_"),
-        subtest(torch.sgn, name="sgn")])
+        subtest(torch.sgn, name="sgn"),
+        subtest(torch.logical_not, name='logical_not'),])
     def test_activations(self, device, func):
         nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
         nested_result = func(nt)
@@ -2499,6 +2513,20 @@ class TestNestedTensorAutograd(TestCase):
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)])
         self.assertEqual(nt.grad, expected_grad)
 
+    def test_masked_fill_backward(self, device):
+        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            mask = nt.detach().clone().to(bool)
+            out = nt.masked_fill(mask, 0)
+            out = torch.nested.to_padded_tensor(out, 0)
+            return out
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
     def test_gelu_backward(self, device):
         a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
         b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
```
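The autograd test relies on `gradcheck`, which compares analytic gradients against numerical finite differences. For `masked_fill` the expected analytic gradient is simple: zero where the mask is True (the filled value is a constant) and one elsewhere. A small self-contained check of that property, as an illustration of what `gradcheck` verifies:

```python
# Check the masked_fill gradient property numerically: d(sum(out))/dx_i
# is 0 where mask[i] is True and 1 where it is False. Pure-Python
# central finite differences; gradcheck does this (and more) in torch.

def masked_fill(xs, mask, value):
    return [value if m else x for x, m in zip(xs, mask)]

def loss(xs, mask):
    return sum(masked_fill(xs, mask, 0.0))  # scalar objective

xs = [0.5, -1.0, 2.0]
mask = [False, True, False]
eps = 1e-6

numeric = []
for i in range(len(xs)):
    hi = xs.copy(); hi[i] += eps
    lo = xs.copy(); lo[i] -= eps
    numeric.append((loss(hi, mask) - loss(lo, mask)) / (2 * eps))

analytic = [0.0 if m else 1.0 for m in mask]  # masked entries get no grad
```

This is also why the test's `expected_grad`-style reasoning holds: masked positions contribute nothing to any downstream loss.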