Improve error message for instance norm when channels is incorrect (#94624)

Fixes https://github.com/pytorch/pytorch/issues/90514

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94624
Approved by: https://github.com/jbschlosser
This commit is contained in:
soulitzer 2023-03-03 16:04:03 -05:00 committed by PyTorch MergeBot
parent 436993d52b
commit 7ff9612e34
2 changed files with 30 additions and 0 deletions

View file

@@ -8119,6 +8119,23 @@ class TestNNDeviceType(NNTestCase):
if self.device_type == 'cuda':
self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
@parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
@parametrize_test("no_batch_dim", [True, False])
@parametrize_test("affine", [True, False])
def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
    """A channel-count mismatch must raise ValueError when affine=True and
    only warn (num_features is unused) when affine=False."""
    # Module expects 4 features; build an input whose channel dim is 2 instead.
    inst_norm = instance_norm_cls(4, affine=affine).to(device)
    size = [2] * inst_norm._get_no_batch_dim()
    if not no_batch_dim:
        size = [3] + size
    # Fix: allocate on the device under test — the original always used CPU,
    # so the CUDA parametrization never exercised the device.
    t = torch.randn(size, device=device)
    if affine:
        # With affine parameters sized by num_features, a mismatch is a hard error.
        with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
            inst_norm(t)
    else:
        with warnings.catch_warnings(record=True) as w:
            # Fix: without "always", warning dedup across parametrized runs can
            # leave w empty and make the w[0] access flaky.
            warnings.simplefilter("always")
            inst_norm(t)
        self.assertIn("which is not used because affine=False", str(w[0].message))
def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
x = torch.rand(10)[None, :, None]
with self.assertRaises(ValueError):

View file

@@ -1,3 +1,5 @@
import warnings
from torch import Tensor
from .batchnorm import _LazyNormBase, _NormBase
@@ -68,6 +70,17 @@ class _InstanceNorm(_NormBase):
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
feature_dim = input.dim() - self._get_no_batch_dim()
if input.size(feature_dim) != self.num_features:
if self.affine:
raise ValueError(
f"expected input's size at dim={feature_dim} to match num_features"
f" ({self.num_features}), but got: {input.size(feature_dim)}.")
else:
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
"You can silence this warning by not passing in num_features, "
"which is not used because affine=False")
if input.dim() == self._get_no_batch_dim():
return self._handle_no_batch_input(input)