diff --git a/test/test_nn.py b/test/test_nn.py
index fe7593a33fb..4d07b1c9823 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8119,6 +8119,23 @@ class TestNNDeviceType(NNTestCase):
         if self.device_type == 'cuda':
             self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
 
+    @parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
+    @parametrize_test("no_batch_dim", [True, False])
+    @parametrize_test("affine", [True, False])
+    def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
+        inst_norm = instance_norm_cls(4, affine=affine)
+        size = [2] * inst_norm._get_no_batch_dim()
+        if not no_batch_dim:
+            size = [3] + size
+        t = torch.randn(size)
+        if affine:
+            with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
+                inst_norm(t)
+        else:
+            with warnings.catch_warnings(record=True) as w:
+                inst_norm(t)
+            self.assertIn("which is not used because affine=False", str(w[0].message))
+
     def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
         x = torch.rand(10)[None, :, None]
         with self.assertRaises(ValueError):
diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py
index ceb34f310a2..97a70cde16e 100644
--- a/torch/nn/modules/instancenorm.py
+++ b/torch/nn/modules/instancenorm.py
@@ -1,3 +1,5 @@
+import warnings
+
 from torch import Tensor
 
 from .batchnorm import _LazyNormBase, _NormBase
@@ -68,6 +70,17 @@ class _InstanceNorm(_NormBase):
     def forward(self, input: Tensor) -> Tensor:
         self._check_input_dim(input)
 
+        feature_dim = input.dim() - self._get_no_batch_dim()
+        if input.size(feature_dim) != self.num_features:
+            if self.affine:
+                raise ValueError(
+                    f"expected input's size at dim={feature_dim} to match num_features"
+                    f" ({self.num_features}), but got: {input.size(feature_dim)}.")
+            else:
+                warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
+                              "You can silence this warning by not passing in num_features, "
+                              "which is not used because affine=False")
+
         if input.dim() == self._get_no_batch_dim():
             return self._handle_no_batch_input(input)
 
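
For illustration, a minimal sketch of the behavior this patch introduces. The module choice, tensor shapes, and the `num_features=4` vs. 2-channel mismatch are arbitrary examples, not part of the patch itself:

```python
import warnings

import torch
import torch.nn as nn

# num_features=4, but the input below has only 2 channels
# (the channel dim is dim=1 for batched InstanceNorm2d input).
x = torch.randn(3, 2, 5, 5)

# affine=True: the mismatch raises, since weight/bias are sized by num_features.
m = nn.InstanceNorm2d(4, affine=True)
try:
    m(x)
except ValueError as e:
    print(e)  # expected input's size at dim=1 to match num_features (4), but got: 2.

# affine=False: num_features is never actually used, so the mismatch only warns.
m = nn.InstanceNorm2d(4, affine=False)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    m(x)
print(w[0].message)  # ... which is not used because affine=False
```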