Add `buffers()` and `named_buffers()` methods. (#10554)

Summary:
This commit adds the ``buffers()`` and ``named_buffers()`` methods as
analogues of ``parameters()`` and ``named_parameters()``.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10554

Reviewed By: SsnL

Differential Revision: D9367762

Pulled By: jma127

fbshipit-source-id: f2042e46a7e833dce40cb41681dbd80d7885c74e
Author: Jerry Ma
Date:   2018-08-16 16:16:22 -07:00
Committed by: Facebook Github Bot
Commit: afd7477eaa (parent 342517e6e7)

6 changed files with 149 additions and 66 deletions
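
Before the diff, a quick orientation: the new methods mirror the existing parameter iterators one-for-one. A minimal usage sketch (nn.BatchNorm1d is just an illustration here -- any module with registered buffers works):

import torch.nn as nn

model = nn.BatchNorm1d(4)  # registers its running statistics as buffers

# buffers() mirrors parameters(): a flat iterator over tensors.
for buf in model.buffers():
    print(buf.size())

# named_buffers() mirrors named_parameters(): yields (dotted_name, tensor) pairs.
for name, buf in model.named_buffers():
    print(name, buf.size())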

test/test_nn.py

@@ -525,6 +525,26 @@ class TestNN(NNTestCase):
             d_params.append(p.grad)
         return params, d_params
 
+    def _create_basic_net(self):
+        class Layer(nn.Module):
+            def __init__(self):
+                super(Layer, self).__init__()
+                self.layer_dummy_param = Parameter(torch.Tensor(3, 5))
+                self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
+
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l1 = Layer()
+                self.dummy_param = Parameter(torch.Tensor(3, 5))
+                self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
+
+        l = Layer()
+        n = Net()
+        s = nn.Sequential(n, n)
+
+        return l, n, s
+
     def test_module_backcompat(self):
         from torch.serialization import SourceChangeWarning
         path = download_file('https://download.pytorch.org/test_data/linear.pt')
@@ -769,51 +789,57 @@ class TestNN(NNTestCase):
         self.assertLess(abs(output.data.std() - std), 0.1)
         output.backward(input)
 
-    def test_parameters(self):
-        def num_params(module):
-            return len(list(module.parameters()))
+    def test_parameters_and_named_parameters(self):
+        def names(named_parameters):
+            return [k for k, _ in named_parameters]
 
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l1 = l
-                self.l2 = l
-                self.param = Parameter(torch.Tensor(3, 5))
+        l, n, s = self._create_basic_net()
 
-        l = nn.Linear(10, 20)
-        n = Net()
-        s = nn.Sequential(n, n, n, n)
-        self.assertEqual(num_params(l), 2)
-        self.assertEqual(num_params(n), 3)
-        self.assertEqual(num_params(s), 3)
+        self.assertEqual(len(list(l.parameters())), 1)
+        self.assertEqual(
+            names(l.named_parameters()),
+            ['layer_dummy_param'])
 
-    def test_named_parameters(self):
-        def num_params(module):
-            return len(dict(module.named_parameters()))
+        self.assertEqual(len(list(n.parameters())), 2)
+        self.assertEqual(
+            names(n.named_parameters()),
+            ['dummy_param', 'l1.layer_dummy_param'])
 
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l1 = l
-                self.l2 = l
-                self.param = Parameter(torch.Tensor(3, 5))
+        self.assertEqual(len(list(n.parameters(recurse=False))), 1)
+        self.assertEqual(
+            names(n.named_parameters(recurse=False)),
+            ['dummy_param'])
 
-        l = nn.Linear(10, 20)
-        n = Net()
-        s = nn.Sequential(n, n, n, n)
+        self.assertEqual(len(list(s.parameters())), 2)
+        self.assertEqual(
+            names(s.named_parameters()),
+            ['0.dummy_param', '0.l1.layer_dummy_param'])
 
-        for name in dict(l.named_parameters()).keys():
-            self.assertTrue(name in ['bias', 'weight'])
+    def test_buffers_and_named_buffers(self):
+        def names(named_buffers):
+            return [k for k, _ in named_buffers]
 
-        for name in dict(n.named_parameters()).keys():
-            self.assertTrue(name in ['l1.bias', 'l1.weight', 'param'])
+        l, n, s = self._create_basic_net()
 
-        for name in dict(s.named_parameters()).keys():
-            self.assertTrue(name in ['0.l1.bias', '0.l1.weight', '0.param'])
+        self.assertEqual(len(list(l.buffers())), 1)
+        self.assertEqual(
+            names(l.named_buffers()),
+            ['layer_dummy_buf'])
 
-        self.assertEqual(num_params(l), 2)
-        self.assertEqual(num_params(n), 3)
-        self.assertEqual(num_params(s), 3)
+        self.assertEqual(len(list(n.buffers())), 2)
+        self.assertEqual(
+            names(n.named_buffers()),
+            ['dummy_buf', 'l1.layer_dummy_buf'])
+
+        self.assertEqual(len(list(n.buffers(recurse=False))), 1)
+        self.assertEqual(
+            names(n.named_buffers(recurse=False)),
+            ['dummy_buf'])
+
+        self.assertEqual(len(list(s.buffers())), 2)
+        self.assertEqual(
+            names(s.named_buffers()),
+            ['0.dummy_buf', '0.l1.layer_dummy_buf'])
 
     def test_call_supports_python_dict_output(self):
         class Net(nn.Module):
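
The recurse=False cases exercised above are easiest to see on a small standalone module; a sketch under the same naming rules (the Outer class is hypothetical, not from the patch):

import torch
import torch.nn as nn

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.inner = nn.BatchNorm1d(4)                   # nested buffers
        self.register_buffer('own_buf', torch.zeros(2))  # direct member

m = Outer()
print([n for n, _ in m.named_buffers()])
# e.g. ['own_buf', 'inner.running_mean', 'inner.running_var', ...]
print([n for n, _ in m.named_buffers(recurse=False)])
# ['own_buf'] -- only buffers registered directly on m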

torch/jit/__init__.py

@@ -656,8 +656,9 @@ def _get_methods(cls):
 _compiled_methods_whitelist = {
     'forward', 'register_buffer', 'register_parameter', 'add_module',
     '_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
-    'state_dict', 'load_state_dict', '_load_from_state_dict', 'parameters',
-    'named_parameters', '_all_buffers', 'children', 'named_children', 'modules',
+    'state_dict', 'load_state_dict', '_load_from_state_dict',
+    '_named_members', 'parameters', 'named_parameters',
+    'buffers', 'named_buffers', 'children', 'named_children', 'modules',
     'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
     '_slow_forward', '_tracing_name'
 }

torch/nn/modules/module.py

@@ -722,11 +722,29 @@ class Module(object):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                self.__class__.__name__, "\n\t".join(error_msgs)))
 
-    def parameters(self):
+    def _named_members(self, get_members_fn, prefix='', recurse=True):
+        r"""Helper method for yielding various names + members of modules."""
+        memo = set()
+        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+        for module_prefix, module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                name = module_prefix + ('.' if module_prefix else '') + k
+                yield name, v
+
+    def parameters(self, recurse=True):
         r"""Returns an iterator over module parameters.
 
         This is typically passed to an optimizer.
 
+        Args:
+            recurse (bool): if True, then yields parameters of this module
+                and all submodules. Otherwise, yields only parameters that
+                are direct members of this module.
+
         Yields:
             Parameter: module parameter
 
@@ -738,12 +756,18 @@ class Module(object):
                <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
 
         """
-        for name, param in self.named_parameters():
+        for name, param in self.named_parameters(recurse=recurse):
             yield param
 
-    def named_parameters(self, memo=None, prefix=''):
+    def named_parameters(self, prefix='', recurse=True):
         r"""Returns an iterator over module parameters, yielding both the
-        name of the parameter as well as the parameter itself
+        name of the parameter as well as the parameter itself.
+
+        Args:
+            prefix (str): prefix to prepend to all parameter names.
+            recurse (bool): if True, then yields parameters of this module
+                and all submodules. Otherwise, yields only parameters that
+                are direct members of this module.
 
         Yields:
             (string, Parameter): Tuple containing the name and parameter
@@ -755,27 +779,59 @@ class Module(object):
             >>> print(param.size())
 
         """
-        if memo is None:
-            memo = set()
-        for name, p in self._parameters.items():
-            if p is not None and p not in memo:
-                memo.add(p)
-                yield prefix + ('.' if prefix else '') + name, p
-        for mname, module in self.named_children():
-            submodule_prefix = prefix + ('.' if prefix else '') + mname
-            for name, p in module.named_parameters(memo, submodule_prefix):
-                yield name, p
+        gen = self._named_members(
+            lambda module: module._parameters.items(),
+            prefix=prefix, recurse=recurse)
+        for elem in gen:
+            yield elem
 
-    def _all_buffers(self, memo=None):
-        if memo is None:
-            memo = set()
-        for name, b in self._buffers.items():
-            if b is not None and b not in memo:
-                memo.add(b)
-                yield b
-        for module in self.children():
-            for b in module._all_buffers(memo):
-                yield b
+    def buffers(self, recurse=True):
+        r"""Returns an iterator over module buffers.
+
+        Args:
+            recurse (bool): if True, then yields buffers of this module
+                and all submodules. Otherwise, yields only buffers that
+                are direct members of this module.
+
+        Yields:
+            torch.Tensor: module buffer
+
+        Example::
+
+            >>> for buf in model.buffers():
+            >>>     print(type(buf.data), buf.size())
+            <class 'torch.FloatTensor'> (20L,)
+            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
+
+        """
+        for name, buf in self.named_buffers(recurse=recurse):
+            yield buf
+
+    def named_buffers(self, prefix='', recurse=True):
+        r"""Returns an iterator over module buffers, yielding both the
+        name of the buffer as well as the buffer itself.
+
+        Args:
+            prefix (str): prefix to prepend to all buffer names.
+            recurse (bool): if True, then yields buffers of this module
+                and all submodules. Otherwise, yields only buffers that
+                are direct members of this module.
+
+        Yields:
+            (string, torch.Tensor): Tuple containing the name and buffer
+
+        Example::
+
+            >>> for name, buf in self.named_buffers():
+            >>>    if name in ['running_var']:
+            >>>        print(buf.size())
+
+        """
+        gen = self._named_members(
+            lambda module: module._buffers.items(),
+            prefix=prefix, recurse=recurse)
+        for elem in gen:
+            yield elem
 
     def children(self):
         r"""Returns an iterator over immediate children modules.

torch/nn/parallel/distributed.py

@@ -264,7 +264,7 @@ class DistributedDataParallel(Module):
 
         # module buffer sync
         if self.broadcast_buffers:
-            buffers = [b.data for b in self.module._all_buffers()]
+            buffers = [b.data for b in self.module.buffers()]
             if len(buffers) > 0:
                 # cross-node buffer sync
                 self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)
@@ -273,7 +273,7 @@ class DistributedDataParallel(Module):
                 # intra-node buffer sync
                 result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
                 for tensors, module in zip(result[1:], self._module_copies[1:]):
-                    for tensor, buf in zip(tensors, module._all_buffers()):
+                    for tensor, buf in zip(tensors, module.buffers()):
                         buf.data.set_(tensor)
 
     def _register_grad_hooks(self):
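
The zips above rely on buffers() yielding tensors in a deterministic traversal order, so structurally identical replicas pair up element for element. A minimal single-process sketch of the same sync pattern (sync_buffers is a hypothetical helper, not part of the patch):

import torch.nn as nn

def sync_buffers(src, replica):
    # Identical module structure means buffers() pairs up 1:1.
    for src_buf, dst_buf in zip(src.buffers(), replica.buffers()):
        dst_buf.data.copy_(src_buf.data)

src = nn.BatchNorm1d(4)
replica = nn.BatchNorm1d(4)
sync_buffers(src, replica)  # replica's running stats now match src's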

torch/nn/parallel/distributed_c10d.py

@@ -162,7 +162,7 @@ class _DistributedDataParallelC10d(Module):
 
         for dev_idx, module in enumerate(self._module_copies):
             self.modules_params_data[dev_idx] = [p.data for p in module.parameters()]
-            self.modules_buffers_data[dev_idx] = [b.data for b in module._all_buffers()]
+            self.modules_buffers_data[dev_idx] = [b.data for b in module.buffers()]
 
         bucket_bytes_cap = bucket_cap_mb * MB

torch/nn/parallel/replicate.py

@@ -14,7 +14,7 @@ def replicate(network, devices, detach=False):
     param_copies = [param_copies[i:i + len(params)]
                     for i in range(0, len(param_copies), len(params))]
 
-    buffers = list(network._all_buffers())
+    buffers = list(network.buffers())
     buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
     buffer_copies = comm.broadcast_coalesced(buffers, devices)