diff --git a/test/test_nn.py b/test/test_nn.py index 052857c001c..d59acb2e5b0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -14432,10 +14432,16 @@ class TestNNDeviceType(NNTestCase): with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"): nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device)) - # no batch dims - input = torch.rand(50, 2, 2, device=device) - self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input) - self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input) + # TODO: Uncomment these lines once no-batch-dim inputs are supported. + # For now, the historical dropout1d behavior is performed for 3D inputs. + # See https://github.com/pytorch/pytorch/issues/77081 + + # input = torch.rand(50, 2, 2, device=device) + # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input) + # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input) + + with self.assertWarnsRegex(UserWarning, "assuming that channel-wise 1D dropout behavior is desired"): + nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device)) # check that complete channels are dropped input = torch.ones(10, 4, 2, 2, device=device) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 07d04584384..2c668dda22e 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1330,15 +1330,19 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).") warnings.warn(warn_msg) - is_batched = inp_dim == 4 - if not is_batched: - input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing + # a 3D input will perform dropout1d behavior instead. This was done historically and the + # behavior is maintained here for now. + # See https://github.com/pytorch/pytorch/issues/77081 + if inp_dim == 3: + warnings.warn("dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead.") result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) - if not is_batched: - result = result.squeeze_(0) if inplace else result.squeeze(0) - return result diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index 2fff5ab4e7a..5f25aae7fa5 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -124,9 +124,16 @@ class Dropout2d(_DropoutNd): inplace (bool, optional): If set to ``True``, will do this operation in-place + .. warning :: + Due to historical reasons, this class will perform 1D channel-wise dropout + for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT + support inputs without a batch dimension of shape :math:`(C, H, W)`. This + behavior will change in a future release to interpret 3D inputs as no-batch-dim + inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`. + Shape: - - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. - - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input). + - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`. + - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input). Examples::