From 5ba6bb7b2fb96fdc8fd8cd4fa6c0e8e4babab6ba Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 18 Apr 2024 09:53:32 -0700 Subject: [PATCH] Add swap_tensors path to nn parametrizations (#124130) Fixes #123859 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124130 Approved by: https://github.com/albanD --- ...deepcopy_after_parametrization_swap_False} | 0 ..._deepcopy_after_parametrization_swap_True} | 0 ...metrized_tensor_parametrization_swap_False | 0 ...ametrized_tensor_parametrization_swap_True | 0 ...t_initialization_parametrization_swap_True | 0 ...nd_remove_buffer_parametrization_swap_True | 0 ...nd_remove_nested_parametrization_swap_True | 0 ...ister_and_remove_parametrization_swap_True | 0 ...st_serialization_parametrization_swap_True | 0 ...zations_and_params_right_inverse_swap_True | 0 ...sfer_parametrizations_and_params_swap_True | 0 ...wrapper_subclass_parametrization_swap_True | 0 ...ation.test_new_spectral_norm_dim_swap_True | 0 test/nn/test_parametrization.py | 213 +++++++++++++++++- torch/nn/utils/parametrize.py | 28 ++- 15 files changed, 226 insertions(+), 15 deletions(-) rename test/dynamo_expected_failures/{TestNNParametrization.test_deepcopy_after_parametrization => TestNNParametrization.test_deepcopy_after_parametrization_swap_False} (100%) rename test/dynamo_expected_failures/{TestNNParametrization.test_errors_unparametrized_tensor_parametrization => TestNNParametrization.test_deepcopy_after_parametrization_swap_True} (100%) create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_initialization_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_buffer_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_nested_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_serialization_parametrization_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_right_inverse_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_swap_True create mode 100644 test/dynamo_expected_failures/TestNNParametrization.test_wrapper_subclass_parametrization_swap_True create mode 100644 test/dynamo_skips/TestNNParametrization.test_new_spectral_norm_dim_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization b/test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization_swap_False similarity index 100% rename from test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization rename to test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization_swap_False diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization b/test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization_swap_True similarity index 100% rename from 
test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization rename to test/dynamo_expected_failures/TestNNParametrization.test_deepcopy_after_parametrization_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_initialization_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_initialization_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_buffer_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_buffer_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_nested_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_nested_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_serialization_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_serialization_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_right_inverse_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_right_inverse_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_wrapper_subclass_parametrization_swap_True b/test/dynamo_expected_failures/TestNNParametrization.test_wrapper_subclass_parametrization_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_skips/TestNNParametrization.test_new_spectral_norm_dim_swap_True b/test/dynamo_skips/TestNNParametrization.test_new_spectral_norm_dim_swap_True new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index 8203c2e275d..1f7b569e864 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -9,6 +9,8 @@ import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init import 
torch.nn.utils.parametrize as parametrize +from torch import Tensor +from torch.__future__ import get_swap_module_params_on_conversion from torch.nn import Parameter from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -20,8 +22,10 @@ from torch.testing._internal.common_utils import ( set_default_dtype, skipIfNoLapack, skipIfTorchDynamo, + swap, TemporaryFileName, ) +from torch.testing._internal.two_tensor import TwoTensor class TestNNParametrization(NNTestCase): @@ -32,6 +36,7 @@ class TestNNParametrization(NNTestCase): # and remove the `@skipIfNoLapack` (see #70995) # torch/nn/utils/parametrize @skipIfNoLapack + @swap([True, False]) def test_register_and_remove_parametrization(self): r"""Test that it is possible to add a few parametrizations on a parameter or a buffer and that removing them restores the initial state @@ -94,8 +99,7 @@ class TestNNParametrization(NNTestCase): self.assertTrue(parametrize.is_parametrized(model, "weight")) self.assertFalse(parametrize.is_parametrized(model, "bias")) self.assertNotIn("weight", model._parameters) - A = model.weight - self.assertTrue(A.shape[0] == 1) + self.assertTrue(model.weight.shape[0] == 1) parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) self.assertEqual(model.weight, initial_model.weight) @@ -110,8 +114,7 @@ class TestNNParametrization(NNTestCase): self.assertTrue(parametrize.is_parametrized(model, "weight")) self.assertFalse(parametrize.is_parametrized(model, "bias")) self.assertNotIn("weight", model._parameters) - A = model.weight - self.assertTrue(A.shape[0] == 1) + self.assertTrue(model.weight.shape[0] == 1) parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) self.assertEqual(model.weight, initial_model.weight) @@ -128,6 +131,10 @@ class TestNNParametrization(NNTestCase): # Result should be skew-symmetric A = model.weight self.assertEqual(A, -A.T) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + del A # Remove and check consistency parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) @@ -145,6 +152,10 @@ class TestNNParametrization(NNTestCase): # Result should be skew-symmetric A = model.weight self.assertEqual(A, -A.T) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + del A # Remove and check consistency parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) @@ -159,6 +170,10 @@ class TestNNParametrization(NNTestCase): X = model.weight Id = torch.eye(X.size(0), device=X.device) self.assertEqual(X.T @ X, Id) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. 
+ del X # Structure tests self.assertTrue(hasattr(model, "parametrizations")) self.assertTrue(parametrize.is_parametrized(model)) @@ -246,6 +261,10 @@ class TestNNParametrization(NNTestCase): sgd.step() self.assertNotEqual(model.weight, weight_copy) self.assertNotEqual(model.bias, bias_copy) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + del weight_copy, bias_copy # Test leave_parametrized=True for _ in range(2): @@ -266,7 +285,12 @@ class TestNNParametrization(NNTestCase): sgd.step() self.assertNotEqual(model.weight, weight_copy) self.assertNotEqual(model.bias, bias_copy) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + del weight_copy, bias_copy + @swap([True, False]) def test_register_and_remove_nested_parametrization(self): r"""Test that it is possible to nest the parametrizations meaning that the original param is parametrized again @@ -288,6 +312,10 @@ class TestNNParametrization(NNTestCase): # Result should be skew-symmetric A = model.weight self.assertEqual(A, -A.T) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + del A # Add nested parametrization param_mod = model.parametrizations.weight @@ -316,6 +344,7 @@ class TestNNParametrization(NNTestCase): self.assertFalse(hasattr(model, "parametrizations")) self.assertEqual(model.__class__, nn.Linear) + @swap([True, False]) def test_register_and_remove_buffer_parametrization(self): r"""Test that it is possible to add and remove parametrizations on buffers""" @@ -354,6 +383,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_serialization_parametrization(self): r"""Test that it is possible to serialize a parametrized model via state_dict""" @@ -403,6 +433,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_initialization_parametrization(self): r"""Test that it is possible to initialize a parametrization when it implements a `right_inverse` method @@ -472,6 +503,7 @@ class TestNNParametrization(NNTestCase): self.assertEqual(model.weight, X) self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) + @swap([True, False]) def test_errors_unparametrized_tensor_parametrization(self): # Test errors when registering a parametrization on an unparametrized tensor module = nn.Linear(3, 4) @@ -621,6 +653,7 @@ class TestNNParametrization(NNTestCase): self.assertFalse(parametrize.is_parametrized(module)) self.assertEqual(module.weight, weight_init) + @swap([True, False]) def test_errors_parametrized_tensor_parametrization(self): # Test errors when registering a parametrization on a parametrized tensor @@ -702,6 +735,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_multiple_inputs_parametrization(self): # A parametrization with several outputs class RankOne(nn.Module): @@ -803,6 +837,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test 
using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_caching_parametrization(self): r"""Test the caching system of a parametrization""" @@ -830,6 +865,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_caching_parametrization_with_transfer_parametrizations_and_params(self): r"""Test that transferring parametrizations doesn't cause issues with caching""" @@ -862,6 +898,7 @@ class TestNNParametrization(NNTestCase): # test that the results are distinct objects for each module self.assertNotEqual(id(A), id(X)) + @swap([True, False]) def test_parametrization_same_training_mode(self): r"""Test training mode updated on parametrization registration""" @@ -878,6 +915,7 @@ class TestNNParametrization(NNTestCase): self.assertTrue(module.parametrizations.weight[0].training) self.assertTrue(module.parametrizations.weight[1].training) + @swap([True, False]) def test_type_before_parametrizations(self): r"""Test that type_before_parametrizations always retrieves original type""" @@ -895,6 +933,7 @@ class TestNNParametrization(NNTestCase): parametrize.type_before_parametrizations(model) == original_type ) + @swap([True, False]) def test_deepcopy_after_parametrization(self): r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" @@ -955,6 +994,7 @@ class TestNNParametrization(NNTestCase): parametrize.register_parametrization(model, "weight", AddOne()) check_deepcopy(model, deepcopy(model)) + @swap([True, False]) def test_transfer_parametrizations_and_params(self): r"""Test that all parametrizations and their associated parameters are transferred.""" @@ -994,6 +1034,10 @@ class TestNNParametrization(NNTestCase): # check that the transfer didn't affect the original value self.assertEqual(hold_weight, model.weight) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. 
+ del hold_weight # testing that changes to one set of parametrizations do not affect the other parametrize.remove_parametrizations(to_model, "weight") @@ -1018,6 +1062,7 @@ class TestNNParametrization(NNTestCase): # check that the new transfer didn't change the value for the from_module self.assertEqual(hold_test_param, model.test_param) + @swap([True, False]) def test_transfer_parametrizations_and_params_right_inverse(self): r"""Test that all parametrizations and their associated parameters are transferred.""" @@ -1047,6 +1092,7 @@ class TestNNParametrization(NNTestCase): # check that transfer doesn't affect the from_model weight self.assertEqual(hold_weight, model.weight) + @swap([True, False]) def test_transfer_parametrizations_and_params_single_param(self): r"""Test that all parametrizations and their associated parameters are transferred.""" @@ -1086,6 +1132,7 @@ class TestNNParametrization(NNTestCase): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @swap([True, False]) def test_transfer_parametrizations_and_params_many_to_one(self): # A parametrization with several outputs class RankOne(nn.Module): @@ -1152,6 +1199,7 @@ class TestNNParametrization(NNTestCase): # check that the new transfer didn't change the value for the from_module self.assertEqual(hold_test_param, model.test_param) + @swap([True, False]) def test_new_spectral_norm(self): with set_default_dtype(torch.double): input = torch.randn(3, 5) @@ -1289,16 +1337,30 @@ class TestNNParametrization(NNTestCase): # avoid doing another power iteration m, wrapped_m, _ = get_modules() pre_remove_out = wrapped_m(input) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. + pre_remove_out_ref = pre_remove_out.detach() + del pre_remove_out + else: + pre_remove_out_ref = pre_remove_out m.eval() m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") - self.assertEqual(wrapped_m(input), pre_remove_out) + self.assertEqual(wrapped_m(input), pre_remove_out_ref) torch.nn.utils.parametrizations.spectral_norm(m) for _ in range(3): pre_remove_out = wrapped_m(input) + if get_swap_module_params_on_conversion(): + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. 
+ pre_remove_out_ref = pre_remove_out.detach() + del pre_remove_out + else: + pre_remove_out_ref = pre_remove_out m.eval() m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight") - self.assertEqual(wrapped_m(input), pre_remove_out) + self.assertEqual(wrapped_m(input), pre_remove_out_ref) # TEST EVAL BEHAVIOR m, wrapped_m, spectral_norm_m = get_modules() @@ -1352,6 +1414,7 @@ class TestNNParametrization(NNTestCase): gradcheck(fn, (m.parametrizations.weight.original,)) + @swap([True, False]) def test_new_spectral_norm_load_state_dict(self): for activate_times in (0, 3): inp = torch.randn(2, 3) @@ -1431,6 +1494,7 @@ class TestNNParametrization(NNTestCase): snm.eval() self.assertEqual(out3_eval, snm(inp)) + @swap([True, False]) def test_new_spectral_norm_dim(self): inp = torch.randn(2, 3, 10, 12) m = nn.ConvTranspose2d(3, 4, (5, 6)) @@ -1443,6 +1507,7 @@ class TestNNParametrization(NNTestCase): snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape ) + @swap([True, False]) def test_new_spectral_norm_forward(self): input = torch.randn(3, 5) m = nn.Linear(5, 7) @@ -1461,6 +1526,7 @@ class TestNNParametrization(NNTestCase): expect_out = m(input) self.assertEqual(expect_out, out_hat) + @swap([True, False]) @skipIfTorchDynamo("Test does not work with TorchDynamo") def test_new_spectral_norm_value(self): # a test that the spectral norm (= top singular value) @@ -1477,6 +1543,7 @@ class TestNNParametrization(NNTestCase): self.assertEqual(m.weight.data, expected) @skipIfNoLapack + @swap([True, False]) def test_orthogonal_parametrization(self): # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization) @@ -1532,7 +1599,13 @@ class TestNNParametrization(NNTestCase): # We do not support householder for complex inputs # See Note [Householder complex] - w_init = m.weight.clone() + + # When using the swap_tensors path, this is needed so that the autograd + # graph is not alive anymore. 
+ if get_swap_module_params_on_conversion(): + w_init = m.weight.clone().detach() + else: + w_init = m.weight.clone() if parametrization == "householder" and m.weight.is_complex(): msg = "householder parametrization does not support complex tensors" with self.assertRaisesRegex(ValueError, msg): @@ -1605,6 +1678,7 @@ class TestNNParametrization(NNTestCase): assert_is_orthogonal(m.weight) @skipIfNoLapack + @swap([True, False]) def test_orthogonal_errors(self): m = nn.Linear(3, 4) with self.assertRaisesRegex(ValueError, "has to be one of"): @@ -1618,6 +1692,7 @@ class TestNNParametrization(NNTestCase): m.weight = torch.randn(5, 5) torch.nn.utils.parametrize.remove_parametrizations(m, "weight") + @swap([True, False]) def test_weight_norm_state_dict_compat(self): m = nn.Linear(4, 5) m = torch.nn.utils.weight_norm(m) @@ -1630,12 +1705,14 @@ class TestNNParametrization(NNTestCase): input = torch.randn(3, 4) self.assertEqual(m(input), m2(input)) + @swap([True, False]) def test_weight_norm_pickle(self): m = nn.Linear(4, 5) m = torch.nn.utils.parametrizations.weight_norm(m) with self.assertRaisesRegex(RuntimeError, "state_dict"): pickle.dumps(m) + @swap([True, False]) def test_weight_norm_deepcopy(self): m = nn.Linear(4, 5) m = torch.nn.utils.parametrizations.weight_norm(m) @@ -1643,8 +1720,130 @@ class TestNNParametrization(NNTestCase): input = torch.randn(3, 4) self.assertEqual(m(input), m2(input)) + @swap([True]) + def test_wrapper_subclass_parametrization(self): + class Subclassify(nn.Module): + def forward(self, X): + return TwoTensor(X, X) + + class UnSubclassify(nn.Module): + def forward(self, X): + return X.a + + class IdentityWithRightInverse(nn.Module): + def forward(self, X): + return X + + def right_inverse(self, X): + return TwoTensor(X, X) + + def _check_parametrization( + parametrization, + type_before_registration, + type_after_registration, + leave_parametrized=False, + type_after_right_inverse=None, + ): + model = nn.Linear(2, 2) + buf = torch.randn(2, 2) + model.register_buffer("buf", buf) + if ( + type_before_registration == TwoTensor + and type_after_registration == Tensor + ): + model._apply(lambda t: TwoTensor(t, t)) + initial_weight = model.weight.clone().detach() + initial_weight_id = id(model.weight) + initial_buf = model.buf.clone().detach() + initial_buf_id = id(model.buf) + type_original_weight = ( + type_before_registration + if type_after_right_inverse is None + else type_after_right_inverse + ) + type_original_buf = ( + Tensor if type_original_weight is nn.Parameter else type_original_weight + ) + type_after_removal_buf = ( + type_after_registration if leave_parametrized else type_original_buf + ) + if leave_parametrized: + if type_after_registration is Tensor: + type_after_removal_weight = nn.Parameter + else: + type_after_removal_weight = type_after_registration + else: + type_after_removal_weight = type_original_weight + + parametrize.register_parametrization(model, "weight", parametrization()) + parametrize.register_parametrization(model, "buf", parametrization()) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + # checks for weight + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertTrue( + isinstance(model.parametrizations.weight.original, nn.Parameter) + ) + self.assertTrue( + type(model.parametrizations.weight.original) is type_original_weight + ) + self.assertNotIn("weight", model._parameters) + 
self.assertTrue(type(model.weight) is type_after_registration) + # checks for buf + self.assertTrue(parametrize.is_parametrized(model, "buf")) + self.assertFalse( + isinstance(model.parametrizations.buf.original, nn.Parameter) + ) + self.assertTrue( + type(model.parametrizations.buf.original) is type_original_buf + ) + self.assertTrue(type(model.buf) is type_after_registration) + parametrize.remove_parametrizations( + model, "weight", leave_parametrized=leave_parametrized + ) + parametrize.remove_parametrizations( + model, "buf", leave_parametrized=leave_parametrized + ) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + # checks for weight + self.assertTrue(type(model.weight) is type_after_removal_weight) + self.assertTrue(isinstance(model.weight, nn.Parameter)) + self.assertEqual(id(model.weight), initial_weight_id) + # checks for buf + self.assertTrue(type(model.buf) is type_after_removal_buf) + self.assertFalse(isinstance(model.buf, nn.Parameter)) + self.assertEqual(id(model.buf), initial_buf_id) + if not leave_parametrized and type_after_right_inverse is None: + self.assertEqual(model.weight, initial_weight) + self.assertEqual(model.buf, initial_buf) + + _check_parametrization(Subclassify, nn.Parameter, TwoTensor) + _check_parametrization(UnSubclassify, TwoTensor, Tensor) + _check_parametrization( + IdentityWithRightInverse, + nn.Parameter, + TwoTensor, + type_after_right_inverse=TwoTensor, + ) + _check_parametrization( + Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True + ) + _check_parametrization( + UnSubclassify, TwoTensor, Tensor, leave_parametrized=True + ) + _check_parametrization( + IdentityWithRightInverse, + nn.Parameter, + TwoTensor, + leave_parametrized=True, + type_after_right_inverse=TwoTensor, + ) + class TestNNParametrizationDevice(NNTestCase): + @swap([True, False]) def test_weight_norm_parametrization(self, device): for dtype in [torch.float, torch.bfloat16]: input = torch.randn(3, 4, dtype=dtype, device=device) diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index e73aada232a..aa4f9656d52 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -1,6 +1,8 @@ import torch +from torch.__future__ import get_swap_module_params_on_conversion from torch.nn.modules.container import ModuleList, ModuleDict, Module from torch.nn.parameter import Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch import Tensor import collections @@ -64,6 +66,14 @@ def _register_parameter_or_buffer(module, name, X): else: module.register_buffer(name, X) +def _maybe_set(dest: Tensor, src: Tensor) -> None: + should_swap = get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest) + if should_swap: + if isinstance(dest, Parameter) and not isinstance(src, Parameter): + src = Parameter(src, requires_grad=dest.requires_grad) + torch.utils.swap_tensors(dest, src) + else: + dest.set_(src) # type: ignore[call-overload] class ParametrizationList(ModuleList): r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`. 
@@ -157,7 +167,7 @@ class ParametrizationList(ModuleList): # Set the original to original so that the user does not need to re-register the parameter # manually in the optimiser with torch.no_grad(): - original.set_(new) # type: ignore[call-overload] + _maybe_set(original, new) _register_parameter_or_buffer(self, "original", original) else: for i, originali in enumerate(new): @@ -231,7 +241,7 @@ class ParametrizationList(ModuleList): f"while `original` has dtype {self.original.dtype}" ) # We know that the result is going to have the same dtype - self.original.set_(value) # type: ignore[call-overload] + _maybe_set(self.original, value) else: if not isinstance(value, collections.abc.Sequence): raise ValueError( @@ -255,7 +265,7 @@ class ParametrizationList(ModuleList): f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " f"while `original{i}` has dtype {original_i.dtype}" ) - original_i.set_(tensor) + _maybe_set(original_i, tensor) def forward(self) -> Tensor: if torch.jit.is_scripting(): @@ -645,18 +655,20 @@ def remove_parametrizations( # This way the user does not need to update the optimizer with torch.no_grad(): if type(original) is torch.Tensor: - original.set_(t) + _maybe_set(original, t) else: try: - original.set_(t) + _maybe_set(original, t) except RuntimeError as e: # TODO: Fix this for tensor subclasses that are parameters: # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach(). raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True " "for a parameter that is an instance of a tensor subclass requires " - "set_() to be implemented correctly for the tensor subclass. Either " - "set leave_parametrized=False or provide a working implementation for " - "set_() in the tensor subclass.") from e + "set_() to be implemented correctly for the tensor subclass. " + "Alternatively, one can opt into the swap_tensors path. " + "Either set leave_parametrized=False, provide a working implementation " + "for set_() in the tensor subclass, or set " + "torch.__future__.set_swap_module_params_on_conversion(True).") from e else: if leave_parametrized: # We cannot use no_grad because we need to know whether one or more
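Reviewer note (not part of the patch): a minimal usage sketch of the opt-in that the updated error message points to. The Symmetric parametrization and the Linear(3, 3) layer are made up for illustration; set_swap_module_params_on_conversion, register_parametrization, and remove_parametrizations are the real APIs involved.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Opt into the swap_tensors path globally; ParametrizationList then writes back
# the original tensor via torch.utils.swap_tensors instead of Tensor.set_().
torch.__future__.set_swap_module_params_on_conversion(True)

class Symmetric(nn.Module):
    def forward(self, X):
        return X + X.T  # expose a symmetric view of the underlying parameter

m = nn.Linear(3, 3)
parametrize.register_parametrization(m, "weight", Symmetric())
assert torch.allclose(m.weight, m.weight.T)

# With the swap path enabled, leave_parametrized=True no longer relies on set_(),
# which is what previously failed for parameters that are tensor subclasses.
parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)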