From ba962fefeac0955c31b4e1b9e64ef4ace74ea67f Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Mon, 5 Jun 2023 18:54:50 -0700
Subject: [PATCH] Add parametrization version of weight_norm (#103001)

This is done in the ordinary way, but also:

* Deprecation warning for the old API, and a migration guide
* Backwards compatibility for state_dict loading the old weight_norm
* Tests for pickling and deepcopy, which were the motivating reason for this change

weight_norm is still used by HuggingFace Wav2Vec2.

Signed-off-by: Edward Z. Yang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103001
Approved by: https://github.com/albanD
---
 test/nn/test_parametrization.py    | 56 +++++++++++++++++++++
 torch/nn/utils/parametrizations.py | 80 +++++++++++++++++++++++++++++-
 torch/nn/utils/weight_norm.py      | 24 +++++++++
 3 files changed, 159 insertions(+), 1 deletion(-)

diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py
index 0ba361d310d..1e7d8c77ebf 100644
--- a/test/nn/test_parametrization.py
+++ b/test/nn/test_parametrization.py
@@ -1,6 +1,7 @@
 # Owner(s): ["module: nn"]
 from copy import deepcopy
 from itertools import product
+import pickle
 
 import torch
@@ -1519,6 +1520,61 @@ class TestNNParametrization(NNTestCase):
 
         torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
 
+    def test_weight_norm_parametrization(self):
+        for dtype in [torch.float, torch.bfloat16]:
+            input = torch.randn(3, 4, dtype=dtype)
+            m = nn.Linear(4, 5).to(dtype=dtype)
+            expected_output = m(input)
+
+            # add weight normalization
+            m = torch.nn.utils.parametrizations.weight_norm(m)
+            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
+            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
+            self.assertEqual(m(input), expected_output)
+
+            # remove weight norm
+            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
+            self.assertFalse(hasattr(m, "parametrizations"))
+            self.assertEqual(m(input), expected_output)
+
+            # test with dim=1
+            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
+            self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size())
+            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
+            self.assertEqual(m(input), expected_output)
+
+            # test with dim=None
+            m = nn.Linear(4, 5).to(dtype=dtype)
+            expected_output = m(input)
+            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
+            self.assertEqual(m(input), expected_output)
+
+    def test_weight_norm_state_dict_compat(self):
+        m = nn.Linear(4, 5)
+        m = torch.nn.utils.weight_norm(m)
+        old_dict = m.state_dict()
+
+        m2 = nn.Linear(4, 5)
+        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
+        m2.load_state_dict(old_dict)
+
+        input = torch.randn(3, 4)
+        self.assertEqual(m(input), m2(input))
+
+    def test_weight_norm_pickle(self):
+        m = nn.Linear(4, 5)
+        m = torch.nn.utils.parametrizations.weight_norm(m)
+        with self.assertRaisesRegex(RuntimeError, 'state_dict'):
+            pickle.dumps(m)
+
+    def test_weight_norm_deepcopy(self):
+        m = nn.Linear(4, 5)
+        m = torch.nn.utils.parametrizations.weight_norm(m)
+        m2 = deepcopy(m)
+        input = torch.randn(3, 4)
+        self.assertEqual(m(input), m2(input))
+
 
 instantiate_parametrized_tests(TestNNParametrization)
 
 if __name__ == '__main__':
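As a quick check of what the new tests exercise, here is a hedged sketch (for reviewers, not part of the patch; variable names are illustrative) that recomputes the parametrized weight by hand from original0 (the magnitude g) and original1 (the direction v), assuming only the public entry point added in this patch:

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrizations

    # With dim=0 (the default), g has one entry per output row of the weight
    # and v has the same shape as the weight itself.
    m = parametrizations.weight_norm(nn.Linear(4, 5))
    g = m.parametrizations.weight.original0   # shape (5, 1)
    v = m.parametrizations.weight.original1   # shape (5, 4)

    # The weight is recomputed on access as w = g * v / ||v|| (row-wise norm).
    expected = g * v / v.norm(dim=1, keepdim=True)
    torch.testing.assert_close(m.weight, expected)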
diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py
index 62dcf291892..9c32876b0eb 100644
--- a/torch/nn/utils/parametrizations.py
+++ b/torch/nn/utils/parametrizations.py
@@ -8,7 +8,7 @@
 from .. import functional as F
 
 from typing import Optional
 
-__all__ = ['orthogonal', 'spectral_norm']
+__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']
 
 
 def _is_orthogonal(Q, eps=None):
@@ -285,6 +285,84 @@ def orthogonal(module: Module,
     return module
 
 
+class _WeightNorm(Module):
+    def __init__(
+        self,
+        dim: Optional[int] = 0,
+    ) -> None:
+        super().__init__()
+        if dim is None:
+            dim = -1
+        self.dim = dim
+
+    def forward(self, weight_g, weight_v):
+        return torch._weight_norm(weight_v, weight_g, self.dim)
+
+    def right_inverse(self, weight):
+        weight_g = torch.norm_except_dim(weight, 2, self.dim)
+        weight_v = weight
+
+        return weight_g, weight_v
+
+
+def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
+    r"""Applies weight normalization to a parameter in the given module.
+
+    .. math::
+         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
+
+    Weight normalization is a reparameterization that decouples the magnitude
+    of a weight tensor from its direction. This replaces the parameter specified
+    by :attr:`name` with two parameters: one specifying the magnitude
+    and one specifying the direction.
+
+    By default, with ``dim=0``, the norm is computed independently per output
+    channel/plane. To compute a norm over the entire weight tensor, use
+    ``dim=None``.
+
+    See https://arxiv.org/abs/1602.07868
+
+    Args:
+        module (Module): containing module
+        name (str, optional): name of weight parameter
+        dim (int, optional): dimension over which to compute the norm
+
+    Returns:
+        The original module with the weight norm parametrization registered
+
+    Example::
+
+        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
+        >>> m
+        ParametrizedLinear(
+          in_features=20, out_features=40, bias=True
+          (parametrizations): ModuleDict(
+            (weight): ParametrizationList(
+              (0): _WeightNorm()
+            )
+          )
+        )
+        >>> m.parametrizations.weight.original0.size()
+        torch.Size([40, 1])
+        >>> m.parametrizations.weight.original1.size()
+        torch.Size([40, 20])
+
+    """
+    _weight_norm = _WeightNorm(dim)
+    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)
+
+    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        g_key = f"{prefix}{name}_g"
+        v_key = f"{prefix}{name}_v"
+        if g_key in state_dict and v_key in state_dict:
+            original0 = state_dict.pop(g_key)
+            original1 = state_dict.pop(v_key)
+
+            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
+            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1
+
+    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
+    return module
+
+
 class _SpectralNorm(Module):
     def __init__(
         self,
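For context on the _weight_norm_compat_hook added above, a minimal usage sketch (not part of the diff; the legacy_weight_norm alias and variable names are illustrative) showing a state_dict produced by the deprecated API loading into the parametrized module:

    import torch
    import torch.nn as nn
    from torch.nn.utils import weight_norm as legacy_weight_norm
    from torch.nn.utils import parametrizations

    # The legacy API stores the decomposition under "weight_g" / "weight_v".
    old = legacy_weight_norm(nn.Linear(4, 5))
    state = old.state_dict()

    # At load time the pre-hook renames those keys to
    # parametrizations.weight.original0 / original1, so this load succeeds.
    new = parametrizations.weight_norm(nn.Linear(4, 5))
    new.load_state_dict(state)

    x = torch.randn(2, 4)
    torch.testing.assert_close(old(x), new(x))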
diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py
index a56a5b15018..07f07070248 100644
--- a/torch/nn/utils/weight_norm.py
+++ b/torch/nn/utils/weight_norm.py
@@ -4,6 +4,7 @@ Weight Normalization from https://arxiv.org/abs/1602.07868
 from torch.nn.parameter import Parameter, UninitializedParameter
 from torch import _weight_norm, norm_except_dim
 from typing import Any, TypeVar
+import warnings
 from ..modules import Module
 
 __all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']
@@ -26,6 +27,8 @@ class WeightNorm:
 
     @staticmethod
     def apply(module, name: str, dim: int) -> 'WeightNorm':
+        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
+
         for k, hook in module._forward_pre_hooks.items():
             if isinstance(hook, WeightNorm) and hook.name == name:
                 raise RuntimeError("Cannot register two weight_norm hooks on "
@@ -87,6 +90,27 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
 
     See https://arxiv.org/abs/1602.07868
 
+    .. warning::
+
+        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
+        which uses the modern parametrization API. The new ``weight_norm`` is compatible
+        with ``state_dict`` generated from old ``weight_norm``.
+
+        Migration guide:
+
+        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
+          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
+          respectively. If this is bothering you, please comment on
+          https://github.com/pytorch/pytorch/issues/102999
+
+        * To remove the weight normalization reparametrization, use
+          :func:`torch.nn.utils.parametrize.remove_parametrizations`.
+
+        * The weight is no longer recomputed once at module forward; instead, it will
+          be recomputed on every access. To restore the old behavior, use
+          :func:`torch.nn.utils.parametrize.cached` before invoking the module
+          in question.
+
     Args:
         module (Module): containing module
         name (str, optional): name of weight parameter
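As a usage note on the last migration bullet, a small sketch (not part of the diff) of the caching context manager that restores compute-once behavior while the module is only being evaluated:

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrizations, parametrize

    layer = parametrizations.weight_norm(nn.Linear(8, 8))
    inputs = [torch.randn(2, 8) for _ in range(16)]

    # Outside the context manager, layer.weight is recomputed on every access.
    # Inside parametrize.cached(), it is computed once and reused.
    with parametrize.cached():
        outputs = [layer(x) for x in inputs]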