Add forward and backward support for silu to NestedTensors (#97181)

# Summary
Add forward and backward support for silu to NestedTensors
- Add forward support to silu
- Add forward support to silu_
- Add backward support to silu
- Add to NT docs
- Add tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97181
Approved by: https://github.com/cpuhrsch, https://github.com/jbschlosser
This commit is contained in:
Driss Guessous 2023-03-20 23:46:07 +00:00 committed by PyTorch MergeBot
parent 9a5fed1bd0
commit a269e5fa04
5 changed files with 41 additions and 1 deletions

View file

@ -4859,10 +4859,14 @@
- func: silu(Tensor self) -> Tensor
structured_delegate: silu.out
python_module: nn
dispatch:
NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu
- func: silu_(Tensor(a!) self) -> Tensor(a!)
structured_delegate: silu.out
python_module: nn
dispatch:
NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu_
- func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -4885,6 +4889,7 @@
python_module: nn
dispatch:
CompositeImplicitAutograd: math_silu_backward
NestedTensorCPU, NestedTensorCUDA: silu_backward_nested
- func: mish(Tensor self) -> Tensor
structured_delegate: mish.out

View file

@ -185,6 +185,12 @@ Tensor threshold_backwards_nested(const Tensor& grad_output, const Tensor& input
return map_nt_binary(grad_output, input, partial_relu_backward);
}
// Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self){
auto partial_silu_backward = [](auto && PH1, auto && PH2) { return at::silu_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2)); };
return map_nt_binary(grad_output, self, partial_silu_backward);
}
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
const Tensor& grad,
const Tensor& input,

View file

@ -76,5 +76,17 @@ Tensor& zero_nested_(Tensor& self) {
return self;
}
Tensor NestedTensor_silu(const Tensor& self){
return map_nt(self, at::silu);
}
Tensor& NestedTensor_silu_(Tensor& self){
auto self_ptr = get_nested_tensor_impl(self);
check_numel_equals_buffer_size(self_ptr);
auto buffer = self_ptr->get_buffer();
at::silu_(buffer);
return self;
}
} // namespace native
} // namespace at

View file

@ -196,6 +196,7 @@ NestedTensor and any constraints they have.
:func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
:func:`torch.relu`; "Behavior is the same as on regular tensors."
:func:`torch.gelu`; "Behavior is the same as on regular tensors."
:func:`torch.silu`; "Behavior is the same as on regular tensors."
:func:`torch.neg`; "Behavior is the same as on regular tensors."
:func:`torch.add`; "Supports elementwise addition of two nested tensors.
Supports addition of a scalar to a nested tensor."

View file

@ -1,6 +1,7 @@
# Owner(s): ["module: nestedtensor"]
import unittest
from functools import partial
import numpy as np
import torch
@ -845,7 +846,9 @@ class TestNestedTensorDeviceType(TestCase):
subtest(torch._C._nn.gelu_, name='gelu_'),
subtest(torch.tanh, name='tanh'),
subtest(torch.tanh_, name='tanh_'),
subtest(torch.neg, name='neg')])
subtest(torch.neg, name='neg'),
subtest(torch.nn.functional.silu, name='silu'),
subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'), ])
def test_activations(self, device, func):
nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
nested_result = func(nt)
@ -2401,6 +2404,19 @@ class TestNestedTensorAutograd(TestCase):
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_selu_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
nt_relu = torch.nn.functional.silu(nt)
return torch.nested.to_padded_tensor(nt_relu, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# Previously would error when input NT doesn't require grad
# NotImplementedError: Cannot access storage of UndefinedTensorImpl
def test_layer_norm_backward_edge_case(self, device):