[BE]: Simplify set add with set update (#145152)

Simplifies the set update slightly to be more readable and efficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
This commit is contained in:
Aaron Gokaslan 2025-01-23 20:18:13 +00:00 committed by PyTorch MergeBot
parent d7b6746470
commit 5ebca3015d
2 changed files with 6 additions and 5 deletions

View file

@ -1782,8 +1782,7 @@ def min_cut_rematerialization_partition(
required_bw_nodes.add(node)
if node in required_bw_nodes:
for user in node.users:
required_bw_nodes.add(user)
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(

View file

@ -849,10 +849,12 @@ def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
"""
Returns set of non-persistent buffers in a module and its submodules.
"""
result = set()
result: set[str] = set()
for name, m in mod.named_modules(remove_duplicate=False):
for b in m._non_persistent_buffers_set:
result.add(f"{name}.{b}" if name else b)
if name:
result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set)
else:
result.update(m._non_persistent_buffers_set)
return result