Add `buffers()`, `named_buffers()` methods. (#10554)
Summary: This commit adds the ``buffers()`` and ``named_buffers()`` methods as analogues of ``parameters()`` and ``named_parameters()``.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10554
Reviewed By: SsnL
Differential Revision: D9367762
Pulled By: jma127
fbshipit-source-id: f2042e46a7e833dce40cb41681dbd80d7885c74e
parent: 342517e6e7
commit: afd7477eaa
6 changed files with 149 additions and 66 deletions
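The new API mirrors ``parameters()``/``named_parameters()``. A minimal usage sketch (assumes a build that includes this commit; the ``BatchNorm1d`` model below is just an arbitrary module that registers buffers, not part of the diff)::

    import torch
    import torch.nn as nn

    # BatchNorm1d registers buffers ('running_mean', 'running_var', ...)
    # alongside its parameters.
    model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

    # New in this commit: iterate over buffers, mirroring parameters().
    for buf in model.buffers():
        print(type(buf), buf.size())

    # named_buffers() yields (name, buffer) pairs with dotted submodule
    # paths, just as named_parameters() does, e.g. '1.running_mean'.
    for name, buf in model.named_buffers():
        print(name)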
@@ -525,6 +525,26 @@ class TestNN(NNTestCase):
             d_params.append(p.grad)
         return params, d_params
 
+    def _create_basic_net(self):
+        class Layer(nn.Module):
+            def __init__(self):
+                super(Layer, self).__init__()
+                self.layer_dummy_param = Parameter(torch.Tensor(3, 5))
+                self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
+
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l1 = Layer()
+                self.dummy_param = Parameter(torch.Tensor(3, 5))
+                self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
+
+        l = Layer()
+        n = Net()
+        s = nn.Sequential(n, n)
+
+        return l, n, s
+
     def test_module_backcompat(self):
         from torch.serialization import SourceChangeWarning
         path = download_file('https://download.pytorch.org/test_data/linear.pt')
@@ -769,51 +789,57 @@ class TestNN(NNTestCase):
         self.assertLess(abs(output.data.std() - std), 0.1)
         output.backward(input)
 
-    def test_parameters(self):
-        def num_params(module):
-            return len(list(module.parameters()))
+    def test_parameters_and_named_parameters(self):
+        def names(named_parameters):
+            return [k for k, _ in named_parameters]
 
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l1 = l
-                self.l2 = l
-                self.param = Parameter(torch.Tensor(3, 5))
+        l, n, s = self._create_basic_net()
 
-        l = nn.Linear(10, 20)
-        n = Net()
-        s = nn.Sequential(n, n, n, n)
-        self.assertEqual(num_params(l), 2)
-        self.assertEqual(num_params(n), 3)
-        self.assertEqual(num_params(s), 3)
+        self.assertEqual(len(list(l.parameters())), 1)
+        self.assertEqual(
+            names(l.named_parameters()),
+            ['layer_dummy_param'])
 
-    def test_named_parameters(self):
-        def num_params(module):
-            return len(dict(module.named_parameters()))
+        self.assertEqual(len(list(n.parameters())), 2)
+        self.assertEqual(
+            names(n.named_parameters()),
+            ['dummy_param', 'l1.layer_dummy_param'])
 
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l1 = l
-                self.l2 = l
-                self.param = Parameter(torch.Tensor(3, 5))
+        self.assertEqual(len(list(n.parameters(recurse=False))), 1)
+        self.assertEqual(
+            names(n.named_parameters(recurse=False)),
+            ['dummy_param'])
 
-        l = nn.Linear(10, 20)
-        n = Net()
-        s = nn.Sequential(n, n, n, n)
+        self.assertEqual(len(list(s.parameters())), 2)
+        self.assertEqual(
+            names(s.named_parameters()),
+            ['0.dummy_param', '0.l1.layer_dummy_param'])
 
-        for name in dict(l.named_parameters()).keys():
-            self.assertTrue(name in ['bias', 'weight'])
+    def test_buffers_and_named_buffers(self):
+        def names(named_buffers):
+            return [k for k, _ in named_buffers]
 
-        for name in dict(n.named_parameters()).keys():
-            self.assertTrue(name in ['l1.bias', 'l1.weight', 'param'])
+        l, n, s = self._create_basic_net()
 
-        for name in dict(s.named_parameters()).keys():
-            self.assertTrue(name in ['0.l1.bias', '0.l1.weight', '0.param'])
+        self.assertEqual(len(list(l.buffers())), 1)
+        self.assertEqual(
+            names(l.named_buffers()),
+            ['layer_dummy_buf'])
 
-        self.assertEqual(num_params(l), 2)
-        self.assertEqual(num_params(n), 3)
-        self.assertEqual(num_params(s), 3)
+        self.assertEqual(len(list(n.buffers())), 2)
+        self.assertEqual(
+            names(n.named_buffers()),
+            ['dummy_buf', 'l1.layer_dummy_buf'])
+
+        self.assertEqual(len(list(n.buffers(recurse=False))), 1)
+        self.assertEqual(
+            names(n.named_buffers(recurse=False)),
+            ['dummy_buf'])
+
+        self.assertEqual(len(list(s.buffers())), 2)
+        self.assertEqual(
+            names(s.named_buffers()),
+            ['0.dummy_buf', '0.l1.layer_dummy_buf'])
 
     def test_call_supports_python_dict_output(self):
         class Net(nn.Module):
@@ -656,8 +656,9 @@ def _get_methods(cls):
 _compiled_methods_whitelist = {
     'forward', 'register_buffer', 'register_parameter', 'add_module',
     '_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
-    'state_dict', 'load_state_dict', '_load_from_state_dict', 'parameters',
-    'named_parameters', '_all_buffers', 'children', 'named_children', 'modules',
+    'state_dict', 'load_state_dict', '_load_from_state_dict',
+    '_named_members', 'parameters', 'named_parameters',
+    'buffers', 'named_buffers', 'children', 'named_children', 'modules',
     'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
     '_slow_forward', '_tracing_name'
 }
@@ -722,11 +722,29 @@ class Module(object):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                self.__class__.__name__, "\n\t".join(error_msgs)))
 
-    def parameters(self):
+    def _named_members(self, get_members_fn, prefix='', recurse=True):
+        r"""Helper method for yielding various names + members of modules."""
+        memo = set()
+        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+        for module_prefix, module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                name = module_prefix + ('.' if module_prefix else '') + k
+                yield name, v
+
+    def parameters(self, recurse=True):
         r"""Returns an iterator over module parameters.
 
         This is typically passed to an optimizer.
 
+        Args:
+            recurse (bool): if True, then yields parameters of this module
+                and all submodules. Otherwise, yields only parameters that
+                are direct members of this module.
+
         Yields:
             Parameter: module parameter
 
@@ -738,12 +756,18 @@ class Module(object):
             <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
 
         """
-        for name, param in self.named_parameters():
+        for name, param in self.named_parameters(recurse=recurse):
             yield param
 
-    def named_parameters(self, memo=None, prefix=''):
+    def named_parameters(self, prefix='', recurse=True):
         r"""Returns an iterator over module parameters, yielding both the
-        name of the parameter as well as the parameter itself
+        name of the parameter as well as the parameter itself.
+
+        Args:
+            prefix (str): prefix to prepend to all parameter names.
+            recurse (bool): if True, then yields parameters of this module
+                and all submodules. Otherwise, yields only parameters that
+                are direct members of this module.
 
         Yields:
             (string, Parameter): Tuple containing the name and parameter
@@ -755,27 +779,59 @@ class Module(object):
             >>> print(param.size())
 
         """
-        if memo is None:
-            memo = set()
-        for name, p in self._parameters.items():
-            if p is not None and p not in memo:
-                memo.add(p)
-                yield prefix + ('.' if prefix else '') + name, p
-        for mname, module in self.named_children():
-            submodule_prefix = prefix + ('.' if prefix else '') + mname
-            for name, p in module.named_parameters(memo, submodule_prefix):
-                yield name, p
+        gen = self._named_members(
+            lambda module: module._parameters.items(),
+            prefix=prefix, recurse=recurse)
+        for elem in gen:
+            yield elem
 
-    def _all_buffers(self, memo=None):
-        if memo is None:
-            memo = set()
-        for name, b in self._buffers.items():
-            if b is not None and b not in memo:
-                memo.add(b)
-                yield b
-        for module in self.children():
-            for b in module._all_buffers(memo):
-                yield b
+    def buffers(self, recurse=True):
+        r"""Returns an iterator over module buffers.
+
+        Args:
+            recurse (bool): if True, then yields buffers of this module
+                and all submodules. Otherwise, yields only buffers that
+                are direct members of this module.
+
+        Yields:
+            torch.Tensor: module buffer
+
+        Example::
+
+            >>> for buf in model.buffers():
+            >>>     print(type(buf.data), buf.size())
+            <class 'torch.FloatTensor'> (20L,)
+            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
+
+        """
+        for name, buf in self.named_buffers(recurse=recurse):
+            yield buf
+
+    def named_buffers(self, prefix='', recurse=True):
+        r"""Returns an iterator over module buffers, yielding both the
+        name of the buffer as well as the buffer itself.
+
+        Args:
+            prefix (str): prefix to prepend to all buffer names.
+            recurse (bool): if True, then yields buffers of this module
+                and all submodules. Otherwise, yields only buffers that
+                are direct members of this module.
+
+        Yields:
+            (string, torch.Tensor): Tuple containing the name and buffer
+
+        Example::
+
+            >>> for name, buf in self.named_buffers():
+            >>>     if name in ['running_var']:
+            >>>         print(buf.size())
+
+        """
+        gen = self._named_members(
+            lambda module: module._buffers.items(),
+            prefix=prefix, recurse=recurse)
+        for elem in gen:
+            yield elem
 
     def children(self):
         r"""Returns an iterator over immediate children modules.
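Note that ``_named_members`` memoizes on object identity, so a member reachable through more than one path is yielded once, and ``recurse=False`` restricts iteration to direct members. A small sketch of that behavior (the ``Inner``/``Outer`` modules here are hypothetical, not part of the diff)::

    import torch
    import torch.nn as nn

    class Inner(nn.Module):
        def __init__(self):
            super(Inner, self).__init__()
            self.register_buffer('inner_buf', torch.zeros(2))

    class Outer(nn.Module):
        def __init__(self):
            super(Outer, self).__init__()
            self.child = Inner()
            self.register_buffer('outer_buf', torch.ones(3))

    m = Outer()
    print([n for n, _ in m.named_buffers()])               # ['outer_buf', 'child.inner_buf']
    print([n for n, _ in m.named_buffers(recurse=False)])  # ['outer_buf']

    # Identity-based memoization: reusing the same module twice does not
    # double-count its buffers.
    s = nn.Sequential(m, m)
    print(len(list(s.buffers())))  # 2, not 4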
@@ -264,7 +264,7 @@ class DistributedDataParallel(Module):
 
         # module buffer sync
         if self.broadcast_buffers:
-            buffers = [b.data for b in self.module._all_buffers()]
+            buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                # cross-node buffer sync
                self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)
@@ -273,7 +273,7 @@ class DistributedDataParallel(Module):
                 # intra-node buffer sync
                 result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
                 for tensors, module in zip(result[1:], self._module_copies[1:]):
-                    for tensor, buf in zip(tensors, module._all_buffers()):
+                    for tensor, buf in zip(tensors, module.buffers()):
                         buf.data.set_(tensor)
 
     def _register_grad_hooks(self):
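Downstream code that iterated the removed private ``_all_buffers()`` migrates exactly as these hunks do: substitute the public ``buffers()`` iterator. A hedged sketch of the pattern (``sync_buffers`` is a hypothetical helper, not part of the PR)::

    import torch.nn as nn

    def sync_buffers(src_module, dst_module):
        # Pair buffers positionally between two replicas of the same
        # architecture and copy values across, as the DDP hunks above do.
        for src_buf, dst_buf in zip(src_module.buffers(), dst_module.buffers()):
            dst_buf.data.set_(src_buf.data)

    a, b = nn.BatchNorm1d(8), nn.BatchNorm1d(8)
    sync_buffers(a, b)  # b's running stats now mirror a's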
@@ -162,7 +162,7 @@ class _DistributedDataParallelC10d(Module):
 
         for dev_idx, module in enumerate(self._module_copies):
             self.modules_params_data[dev_idx] = [p.data for p in module.parameters()]
-            self.modules_buffers_data[dev_idx] = [b.data for b in module._all_buffers()]
+            self.modules_buffers_data[dev_idx] = [b.data for b in module.buffers()]
 
         bucket_bytes_cap = bucket_cap_mb * MB
 
@@ -14,7 +14,7 @@ def replicate(network, devices, detach=False):
     param_copies = [param_copies[i:i + len(params)]
                     for i in range(0, len(param_copies), len(params))]
 
-    buffers = list(network._all_buffers())
+    buffers = list(network.buffers())
     buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
     buffer_copies = comm.broadcast_coalesced(buffers, devices)
 