Feat: Updated torch.nn.Modules.set_submodules() (#127714)

modified:   torch/nn/modules/module.py

Implemented feature request by #127712.
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127714
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Yang Cao 2024-06-27 06:38:53 +00:00 committed by PyTorch MergeBot
parent c9798d123b
commit 9f29a2291c
2 changed files with 71 additions and 0 deletions

View file

@ -1435,6 +1435,20 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
lambda: getattr(net, fn)(None, l))
def test_set_submodule(self):
net = nn.Module()
net.t = nn.Module()
l = nn.Linear(1, 2)
target = "t.l"
net.set_submodule(target, l)
self.assertEqual(net.get_submodule(target), l)
l2 = nn.Linear(2, 1)
net.set_submodule(target, l2)
self.assertEqual(net.get_submodule(target), l2)
self.assertRaises(ValueError, net.set_submodule, "", l)
self.assertRaises(AttributeError, net.set_submodule, "a.l", l)
self.assertRaises(AttributeError, net.set_submodule, "t.l.l2", l2)
def test_module_to_argparse(self):
net = nn.Sequential(nn.Linear(3, 3))
cpu = torch.device('cpu')

View file

@ -724,6 +724,63 @@ class Module:
return mod
def set_submodule(self, target: str, module: "Module") -> None:
"""
Set the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text
A(
(net_b): Module(
(net_c): Module(
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
)
(linear): Linear(in_features=100, out_features=200, bias=True)
)
)
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To overide the ``Conv2d`` with a new submodule ``Linear``, you
would call
``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
Raises:
ValueError: If the target string is empty
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
"""
if target == "":
raise ValueError("Cannot set the submodule without a target name!")
atoms: List[str] = target.split(".")
name = atoms.pop(-1)
mod: torch.nn.Module = self
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no attribute `" + item + "`"
)
mod = getattr(mod, item)
if type(mod) is not torch.nn.Module:
raise AttributeError("`" + item + "` is not an nn.Module")
setattr(mod, name, module)
def get_parameter(self, target: str) -> "Parameter":
"""Return the parameter given by ``target`` if it exists, otherwise throw an error.