mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Feat: Updated torch.nn.Modules.set_submodules() (#127714)
modified: torch/nn/modules/module.py Implemented feature request by #127712. Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/127714 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent
c9798d123b
commit
9f29a2291c
2 changed files with 71 additions and 0 deletions
|
|
@ -1435,6 +1435,20 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
|
||||
lambda: getattr(net, fn)(None, l))
|
||||
|
||||
def test_set_submodule(self):
|
||||
net = nn.Module()
|
||||
net.t = nn.Module()
|
||||
l = nn.Linear(1, 2)
|
||||
target = "t.l"
|
||||
net.set_submodule(target, l)
|
||||
self.assertEqual(net.get_submodule(target), l)
|
||||
l2 = nn.Linear(2, 1)
|
||||
net.set_submodule(target, l2)
|
||||
self.assertEqual(net.get_submodule(target), l2)
|
||||
self.assertRaises(ValueError, net.set_submodule, "", l)
|
||||
self.assertRaises(AttributeError, net.set_submodule, "a.l", l)
|
||||
self.assertRaises(AttributeError, net.set_submodule, "t.l.l2", l2)
|
||||
|
||||
def test_module_to_argparse(self):
|
||||
net = nn.Sequential(nn.Linear(3, 3))
|
||||
cpu = torch.device('cpu')
|
||||
|
|
|
|||
|
|
@ -724,6 +724,63 @@ class Module:
|
|||
|
||||
return mod
|
||||
|
||||
def set_submodule(self, target: str, module: "Module") -> None:
|
||||
"""
|
||||
Set the submodule given by ``target`` if it exists, otherwise throw an error.
|
||||
|
||||
For example, let's say you have an ``nn.Module`` ``A`` that
|
||||
looks like this:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
A(
|
||||
(net_b): Module(
|
||||
(net_c): Module(
|
||||
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
|
||||
)
|
||||
(linear): Linear(in_features=100, out_features=200, bias=True)
|
||||
)
|
||||
)
|
||||
|
||||
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
|
||||
submodule ``net_b``, which itself has two submodules ``net_c``
|
||||
and ``linear``. ``net_c`` then has a submodule ``conv``.)
|
||||
|
||||
To overide the ``Conv2d`` with a new submodule ``Linear``, you
|
||||
would call
|
||||
``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
|
||||
|
||||
Args:
|
||||
target: The fully-qualified string name of the submodule
|
||||
to look for. (See above example for how to specify a
|
||||
fully-qualified string.)
|
||||
module: The module to set the submodule to.
|
||||
|
||||
Raises:
|
||||
ValueError: If the target string is empty
|
||||
AttributeError: If the target string references an invalid
|
||||
path or resolves to something that is not an
|
||||
``nn.Module``
|
||||
"""
|
||||
if target == "":
|
||||
raise ValueError("Cannot set the submodule without a target name!")
|
||||
|
||||
atoms: List[str] = target.split(".")
|
||||
name = atoms.pop(-1)
|
||||
mod: torch.nn.Module = self
|
||||
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no attribute `" + item + "`"
|
||||
)
|
||||
|
||||
mod = getattr(mod, item)
|
||||
|
||||
if type(mod) is not torch.nn.Module:
|
||||
raise AttributeError("`" + item + "` is not an nn.Module")
|
||||
setattr(mod, name, module)
|
||||
|
||||
def get_parameter(self, target: str) -> "Parameter":
|
||||
"""Return the parameter given by ``target`` if it exists, otherwise throw an error.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue