Add parametrization version of weight_norm (#103001)
This is done in the ordinary way, but also:

* Deprecation warning for the old API, and a migration guide
* Backwards compatibility for state_dict loading the old weight_norm
* Test for pickling and deepcopy, which was the motivating reason weight_norm is still used by HuggingFace Wav2Vec2

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103001
Approved by: https://github.com/albanD
This commit is contained in:
parent 3a38acf18f
commit ba962fefea

3 changed files with 159 additions and 1 deletion
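For orientation, here is a minimal sketch of the new API this PR adds, assembled from the diffs below; the Linear(20, 40) shapes follow the docstring example, and nothing here goes beyond what the diff itself shows:

import torch
import torch.nn as nn

# New parametrization-based API (this PR); the weight is expressed as a
# magnitude (original0) and a direction (original1) under m.parametrizations.
m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40))
print(m.parametrizations.weight.original0.size())  # torch.Size([40, 1])  (magnitude)
print(m.parametrizations.weight.original1.size())  # torch.Size([40, 20]) (direction)

# Undo the reparametrization, baking the current weight back into m.weight.
torch.nn.utils.parametrize.remove_parametrizations(m, "weight")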
@@ -1,6 +1,7 @@
# Owner(s): ["module: nn"]
from copy import deepcopy
from itertools import product
import pickle

import torch

@@ -1519,6 +1520,61 @@ class TestNNParametrization(NNTestCase):
        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

    def test_weight_norm_parametrization(self):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype)
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.parametrizations.weight_norm(m)
            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(m(input), expected_output)

            # remove weight norm
            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertEqual(m(input), expected_output)

            # test with dim=1
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(m(input), expected_output)

            # test with dim=None
            m = nn.Linear(4, 5).to(dtype=dtype)
            expected_output = m(input)
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)

    def test_weight_norm_state_dict_compat(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.weight_norm(m)
        old_dict = m.state_dict()

        m2 = nn.Linear(4, 5)
        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
        m2.load_state_dict(old_dict)

        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    def test_weight_norm_pickle(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        with self.assertRaisesRegex(RuntimeError, 'state_dict'):
            pickle.dumps(m)

    def test_weight_norm_deepcopy(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        m2 = deepcopy(m)
        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))


instantiate_parametrized_tests(TestNNParametrization)

if __name__ == '__main__':

@@ -8,7 +8,7 @@ from .. import functional as F

from typing import Optional

-__all__ = ['orthogonal', 'spectral_norm']
+__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):

@@ -285,6 +285,84 @@ def orthogonal(module: Module,
    return module


class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v


def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])

    """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)
            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module


class _SpectralNorm(Module):
    def __init__(
        self,

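As a sanity check on the _WeightNorm class above, a small sketch using the same internal helpers the diff uses (torch.norm_except_dim and torch._weight_norm): right_inverse splits a weight into magnitude and direction, and forward reassembles them, so the decomposition round-trips:

import torch

w = torch.randn(5, 4)
g = torch.norm_except_dim(w, 2, 0)   # magnitude per dim-0 slice, shape (5, 1)
v = w                                # direction is stored unnormalized
# forward computes g * v / ||v||; since g equals ||v|| here, we recover w
assert torch.allclose(torch._weight_norm(v, g, 0), w, atol=1e-6)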
@@ -4,6 +4,7 @@ Weight Normalization from https://arxiv.org/abs/1602.07868
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
import warnings
from ..modules import Module

__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']

@@ -26,6 +27,8 @@ class WeightNorm:

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "

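A short sketch of what the deprecation warning added above means for existing callers; the old API keeps working, it just warns:

import warnings
import torch
import torch.nn as nn

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = torch.nn.utils.weight_norm(nn.Linear(4, 5))  # still functional
assert any("deprecated" in str(w.message) for w in caught)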
@@ -87,6 +90,27 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API. The new ``weight_norm`` is compatible
        with ``state_dict`` generated from old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively. If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

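To illustrate the last migration bullet, a minimal sketch of restoring once-per-forward recomputation with the caching context manager mentioned above:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
x = torch.randn(3, 4)

# By default m.weight is recomputed from original0/original1 on every access;
# inside parametrize.cached() it is computed once and reused.
with parametrize.cached():
    y = m(x)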