mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152 Approved by: https://github.com/XuehaiPan, https://github.com/albanD Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
This commit is contained in:
parent
d7b6746470
commit
5ebca3015d
2 changed files with 6 additions and 5 deletions
|
|
@ -1782,8 +1782,7 @@ def min_cut_rematerialization_partition(
|
|||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
for user in node.users:
|
||||
required_bw_nodes.add(user)
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(
|
||||
|
|
|
|||
|
|
@ -849,10 +849,12 @@ def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
|
|||
"""
|
||||
Returns set of non-persistent buffers in a module and its submodules.
|
||||
"""
|
||||
result = set()
|
||||
result: set[str] = set()
|
||||
for name, m in mod.named_modules(remove_duplicate=False):
|
||||
for b in m._non_persistent_buffers_set:
|
||||
result.add(f"{name}.{b}" if name else b)
|
||||
if name:
|
||||
result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set)
|
||||
else:
|
||||
result.update(m._non_persistent_buffers_set)
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue