mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add Dropout1d module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79545 Approved by: https://github.com/ngimel, https://github.com/albanD
This commit is contained in:
parent
081ff9602a
commit
2d73c8e6e0
10 changed files with 114 additions and 2 deletions
|
|
@ -120,6 +120,7 @@ Dropout functions
|
|||
dropout
|
||||
alpha_dropout
|
||||
feature_alpha_dropout
|
||||
dropout1d
|
||||
dropout2d
|
||||
dropout3d
|
||||
|
||||
|
|
|
|||
|
|
@ -250,6 +250,7 @@ Dropout Layers
|
|||
:template: classtemplate.rst
|
||||
|
||||
nn.Dropout
|
||||
nn.Dropout1d
|
||||
nn.Dropout2d
|
||||
nn.Dropout3d
|
||||
nn.AlphaDropout
|
||||
|
|
|
|||
|
|
@ -3896,6 +3896,7 @@ class TestFunctionalTracing(JitTestCase):
|
|||
"cross_entropy": CONTROL_FLOW,
|
||||
"ctc_loss": CONTROL_FLOW,
|
||||
"dropout": CONTROL_FLOW,
|
||||
"dropout1d": CONTROL_FLOW,
|
||||
"dropout2d": CONTROL_FLOW,
|
||||
"dropout3d": CONTROL_FLOW,
|
||||
"elu": CONTROL_FLOW,
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ def build_constructor_arg_db():
|
|||
torch.nn.CosineSimilarity: ((), {}),
|
||||
torch.nn.CrossEntropyLoss: ((), {}),
|
||||
torch.nn.CrossMapLRN2d: ((5,), {}),
|
||||
torch.nn.Dropout1d: ((), {}),
|
||||
torch.nn.Dropout2d: ((), {}),
|
||||
torch.nn.Dropout3d: ((), {}),
|
||||
torch.nn.Dropout: ((), {}),
|
||||
|
|
|
|||
|
|
@ -7356,6 +7356,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
v = torch.ones(1)
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
|
||||
self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
|
||||
|
|
@ -14389,6 +14391,28 @@ class TestNNDeviceType(NNTestCase):
|
|||
for b, c in product(range(B), range(C)):
|
||||
self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))
|
||||
|
||||
@expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA
|
||||
def test_Dropout1d(self, device):
|
||||
N, C, L = random.randint(10, 15), random.randint(10, 15), random.randint(10, 15)
|
||||
input = torch.empty(N, C, L)
|
||||
self._test_dropout(nn.Dropout1d, device, input)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected 2D or 3D input, but received a 4D input"):
|
||||
nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected 2D or 3D input, but received a 1D input"):
|
||||
nn.Dropout1d(p=0.5)(torch.rand(2, device=device))
|
||||
|
||||
# no batch dims
|
||||
input = torch.rand(50, 2, device=device)
|
||||
self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
|
||||
self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)
|
||||
|
||||
# check that complete channels are dropped
|
||||
input = torch.ones(10, 4, 2, device=device)
|
||||
self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
|
||||
self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)
|
||||
|
||||
@expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA
|
||||
def test_Dropout2d(self, device):
|
||||
b = random.randint(1, 5)
|
||||
|
|
|
|||
|
|
@ -1264,6 +1264,44 @@ def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace
|
|||
return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training)
|
||||
|
||||
|
||||
def dropout1d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor:
|
||||
r"""
|
||||
Randomly zero out entire channels (a channel is a 1D feature map,
|
||||
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
||||
batched input is a 1D tensor :math:`\text{input}[i, j]`) of the input tensor).
|
||||
Each channel will be zeroed out independently on every forward call with
|
||||
probability :attr:`p` using samples from a Bernoulli distribution.
|
||||
|
||||
See :class:`~torch.nn.Dropout1d` for details.
|
||||
|
||||
Args:
|
||||
p: probability of a channel to be zeroed. Default: 0.5
|
||||
training: apply dropout if is ``True``. Default: ``True``
|
||||
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
|
||||
"""
|
||||
if has_torch_function_unary(input):
|
||||
return handle_torch_function(dropout1d, (input,), input, p=p, training=training, inplace=inplace)
|
||||
if p < 0.0 or p > 1.0:
|
||||
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
||||
inp_dim = input.dim()
|
||||
if inp_dim not in (2, 3):
|
||||
raise RuntimeError(f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. "
|
||||
"Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
|
||||
"spatial dimension, a channel dimension, and an optional batch dimension "
|
||||
"(i.e. 2D or 3D inputs).")
|
||||
|
||||
is_batched = inp_dim == 3
|
||||
if not is_batched:
|
||||
input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
|
||||
|
||||
result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training)
|
||||
|
||||
if not is_batched:
|
||||
result = result.squeeze_(0) if inplace else result.squeeze(0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor:
|
||||
r"""
|
||||
Randomly zero out entire channels (a channel is a 2D feature map,
|
||||
|
|
|
|||
|
|
@ -98,6 +98,9 @@ def dropout(input: Tensor, p: float = ..., training: bool = ..., inplace: bool =
|
|||
def alpha_dropout(input: Tensor, p: float = ..., training: bool = ..., inplace: bool = ...) -> Tensor: ...
|
||||
|
||||
|
||||
def dropout1d(input: Tensor, p: float = ..., training: bool = ..., inplace: bool = ...) -> Tensor: ...
|
||||
|
||||
|
||||
def dropout2d(input: Tensor, p: float = ..., training: bool = ..., inplace: bool = ...) -> Tensor: ...
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
|
|||
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
|
||||
LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
|
||||
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
|
||||
from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
|
||||
from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
|
||||
from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
|
||||
ReplicationPad3d, ZeroPad2d, ConstantPad1d, ConstantPad2d, ConstantPad3d
|
||||
from .sparse import Embedding, EmbeddingBag
|
||||
|
|
@ -49,7 +49,7 @@ __all__ = [
|
|||
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
|
||||
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
|
||||
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
|
||||
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
||||
'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
||||
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
|
||||
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
|
||||
'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
|
||||
|
|
|
|||
|
|
@ -58,6 +58,48 @@ class Dropout(_DropoutNd):
|
|||
return F.dropout(input, self.p, self.training, self.inplace)
|
||||
|
||||
|
||||
class Dropout1d(_DropoutNd):
|
||||
r"""Randomly zero out entire channels (a channel is a 1D feature map,
|
||||
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
||||
batched input is a 1D tensor :math:`\text{input}[i, j]`).
|
||||
Each channel will be zeroed out independently on every forward call with
|
||||
probability :attr:`p` using samples from a Bernoulli distribution.
|
||||
|
||||
Usually the input comes from :class:`nn.Conv1d` modules.
|
||||
|
||||
As described in the paper
|
||||
`Efficient Object Localization Using Convolutional Networks`_ ,
|
||||
if adjacent pixels within feature maps are strongly correlated
|
||||
(as is normally the case in early convolution layers) then i.i.d. dropout
|
||||
will not regularize the activations and will otherwise just result
|
||||
in an effective learning rate decrease.
|
||||
|
||||
In this case, :func:`nn.Dropout1d` will help promote independence between
|
||||
feature maps and should be used instead.
|
||||
|
||||
Args:
|
||||
p (float, optional): probability of an element to be zero-ed.
|
||||
inplace (bool, optional): If set to ``True``, will do this operation
|
||||
in-place
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C, L)` or :math:`(C, L)`.
|
||||
- Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> m = nn.Dropout1d(p=0.2)
|
||||
>>> input = torch.randn(20, 16, 32)
|
||||
>>> output = m(input)
|
||||
|
||||
.. _Efficient Object Localization Using Convolutional Networks:
|
||||
https://arxiv.org/abs/1411.4280
|
||||
"""
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.dropout1d(input, self.p, self.training, self.inplace)
|
||||
|
||||
|
||||
class Dropout2d(_DropoutNd):
|
||||
r"""Randomly zero out entire channels (a channel is a 2D feature map,
|
||||
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
||||
|
|
|
|||
|
|
@ -752,6 +752,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
|
||||
reduction='mean', zero_infinity=False: -1),
|
||||
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue