mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Improve error message for instance norm when channels is incorrect (#94624)
Fixes https://github.com/pytorch/pytorch/issues/90514 Pull Request resolved: https://github.com/pytorch/pytorch/pull/94624 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
436993d52b
commit
7ff9612e34
2 changed files with 30 additions and 0 deletions
|
|
@ -8119,6 +8119,23 @@ class TestNNDeviceType(NNTestCase):
|
|||
if self.device_type == 'cuda':
|
||||
self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
|
||||
|
||||
@parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
|
||||
@parametrize_test("no_batch_dim", [True, False])
|
||||
@parametrize_test("affine", [True, False])
|
||||
def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
|
||||
inst_norm = instance_norm_cls(4, affine=affine)
|
||||
size = [2] * inst_norm._get_no_batch_dim()
|
||||
if not no_batch_dim:
|
||||
size = [3] + size
|
||||
t = torch.randn(size)
|
||||
if affine:
|
||||
with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
|
||||
inst_norm(t)
|
||||
else:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
inst_norm(t)
|
||||
self.assertIn("which is not used because affine=False", str(w[0].message))
|
||||
|
||||
def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
|
||||
x = torch.rand(10)[None, :, None]
|
||||
with self.assertRaises(ValueError):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
|
||||
import warnings
|
||||
from torch import Tensor
|
||||
|
||||
from .batchnorm import _LazyNormBase, _NormBase
|
||||
|
|
@ -68,6 +70,17 @@ class _InstanceNorm(_NormBase):
|
|||
def forward(self, input: Tensor) -> Tensor:
|
||||
self._check_input_dim(input)
|
||||
|
||||
feature_dim = input.dim() - self._get_no_batch_dim()
|
||||
if input.size(feature_dim) != self.num_features:
|
||||
if self.affine:
|
||||
raise ValueError(
|
||||
f"expected input's size at dim={feature_dim} to match num_features"
|
||||
f" ({self.num_features}), but got: {input.size(feature_dim)}.")
|
||||
else:
|
||||
warnings.warn(f"input's size at dim={feature_dim} does not match num_features. "
|
||||
"You can silence this warning by not passing in num_features, "
|
||||
"which is not used because affine=False")
|
||||
|
||||
if input.dim() == self._get_no_batch_dim():
|
||||
return self._handle_no_batch_input(input)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue