Add parametrization version of weight_norm (#103001)

This is done in the ordinary way, but also:

* Deprecation warning for the old API, and a migration guide
* Backwards compatibility for loading a state_dict saved with the old weight_norm
* Tests for pickling and deepcopy, which were the motivating reason for the change

weight_norm is still used by HuggingFace Wav2Vec2.
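
For context, a minimal sketch of the migration in user code (hedged: the modules
below are illustrative, and loading the old checkpoint relies on the
compatibility hook added in this PR):

    import torch
    import torch.nn as nn

    # Old, now-deprecated API: a forward pre-hook exposing weight_g / weight_v.
    m_old = torch.nn.utils.weight_norm(nn.Linear(4, 5))

    # New API: a parametrization exposing parametrizations.weight.original0/1.
    m_new = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))

    # Old checkpoints still load; a pre-hook remaps the weight_g/weight_v keys.
    m_new.load_state_dict(m_old.state_dict())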

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103001
Approved by: https://github.com/albanD
Edward Z. Yang 2023-06-05 18:54:50 -07:00 committed by PyTorch MergeBot
parent 3a38acf18f
commit ba962fefea
3 changed files with 159 additions and 1 deletion

test/nn/test_parametrization.py

@@ -1,6 +1,7 @@
# Owner(s): ["module: nn"]

from copy import deepcopy
from itertools import product
import pickle

import torch
@@ -1519,6 +1520,61 @@ class TestNNParametrization(NNTestCase):
        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

    def test_weight_norm_parametrization(self):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype)
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.parametrizations.weight_norm(m)
            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(m(input), expected_output)

            # remove weight norm
            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertEqual(m(input), expected_output)

            # test with dim=1
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(m(input), expected_output)

            # test with dim=None
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)

    def test_weight_norm_state_dict_compat(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.weight_norm(m)
        old_dict = m.state_dict()

        m2 = nn.Linear(4, 5)
        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
        m2.load_state_dict(old_dict)

        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    def test_weight_norm_pickle(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        with self.assertRaisesRegex(RuntimeError, 'state_dict'):
            pickle.dumps(m)

    def test_weight_norm_deepcopy(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        m2 = deepcopy(m)
        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))
instantiate_parametrized_tests(TestNNParametrization)
if __name__ == '__main__':

torch/nn/utils/parametrizations.py

@@ -8,7 +8,7 @@ from .. import functional as F
from typing import Optional
-__all__ = ['orthogonal', 'spectral_norm']
+__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']
def _is_orthogonal(Q, eps=None):
@@ -285,6 +285,84 @@ def orthogonal(module: Module,
    return module
class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight
        return weight_g, weight_v
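
# Aside (a hedged sketch, not part of this diff): forward() and right_inverse()
# above are mutual inverses, so registering the parametrization leaves the
# initial value of the weight unchanged.
wn = _WeightNorm(dim=0)
w = torch.randn(5, 4)
g, v = wn.right_inverse(w)           # g: (5, 1) per-row magnitudes; v is w itself
assert torch.allclose(wn(g, v), w)   # g * v / ||v|| reconstructs w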
def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)

            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module
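
# Aside (a hedged sketch, not part of this diff): the compat pre-hook above is
# what lets a checkpoint saved with the deprecated API load into the new one.
old = torch.nn.utils.weight_norm(torch.nn.Linear(4, 5))  # keys: weight_g, weight_v, bias
new = weight_norm(torch.nn.Linear(4, 5))                 # keys: parametrizations.weight.original0/1, bias
new.load_state_dict(old.state_dict())                    # hook rewrites the old keys on the fly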
class _SpectralNorm(Module):
    def __init__(
        self,

torch/nn/utils/weight_norm.py

@@ -4,6 +4,7 @@ Weight Normalization from https://arxiv.org/abs/1602.07868
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
import warnings
from ..modules import Module
__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']
@@ -26,6 +27,8 @@ class WeightNorm:
    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
@@ -87,6 +90,27 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated.  Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API.  The new ``weight_norm`` is compatible
        with the ``state_dict`` generated from the old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively.  If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access.  To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question, as sketched below.
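
        A rough sketch of that pattern (``module`` and ``x`` stand in for your
        parametrized module and its input)::

            >>> import torch.nn.utils.parametrize as parametrize
            >>> with parametrize.cached():
            ...     out = module(x)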
    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter