From 9039fbb47ecfc93df74a014a209e5929d10fd2a3 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Mon, 4 Nov 2024 10:59:50 -0800
Subject: [PATCH] [FSDP2] Make module-to-state mapping use weakrefs (#139650)

Without this, `del model` does not free the memory of a module with FSDP2
applied: the global module-to-state mapping keeps strong references to each
module and its FSDP state, so both stay alive indefinitely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139650
Approved by: https://github.com/yf225
---
 .../fsdp/test_fully_shard_memory.py     | 31 +++++++++++++++++++
 torch/distributed/_composable_state.py  | 15 ++++++---
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py
index 7dba4ce7350..88e00e66c5e 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import functools
+import gc
 
 import torch
 from torch.distributed._composable.fsdp import (
@@ -197,6 +198,36 @@ class TestFullyShardMemory(FSDPTest):
             expected_mem_mb += (2 * model_sharded_numel) * 4 / 1e6 + buffer_mb
         self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
 
+    @skip_if_lt_x_gpu(2)
+    def test_fully_shard_del_memory(self):
+        base_mem_mb = self._get_peak_active_memory_mb()
+        vocab_size = 32
+        model_args = ModelArgs(
+            vocab_size=vocab_size, n_layers=3, dim=768, n_heads=12, weight_tying=False
+        )
+        model = Transformer(model_args)
+        # Initializing the model on CPU should not change the GPU memory usage
+        post_model_init_mem_mb = self._get_peak_active_memory_mb()
+        self.assertEqual(base_mem_mb, post_model_init_mem_mb)
+
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module)
+        fully_shard(model)
+        unsharded_numel = sum(p.numel() for p in model.parameters())
+        sharded_numel = unsharded_numel // self.world_size
+        buffer_mb = 4
+        mem_mb = self._get_curr_active_memory_mb()
+        expected_mb = sharded_numel * 4 / 1e6 + buffer_mb
+        self.assertLessEqual(mem_mb - base_mem_mb, expected_mb)
+
+        # Deleting the model should free all of the FSDP-managed GPU memory
+        del model
+        # Manually call garbage collection since there are ref cycles in FSDP
+        gc.collect()
+        mem_mb = self._get_curr_active_memory_mb()
+        self.assertEqual(mem_mb, base_mem_mb)
+
     def _get_peak_active_memory_mb(self) -> int:
         mem_stats = torch.cuda.memory_stats()
         return round(mem_stats["active_bytes.all.peak"] / 1e6)
diff --git a/torch/distributed/_composable_state.py b/torch/distributed/_composable_state.py
index f50da98f8c6..6d2b8baed76 100644
--- a/torch/distributed/_composable_state.py
+++ b/torch/distributed/_composable_state.py
@@ -1,4 +1,5 @@
-from typing import cast, Dict, Optional
+import weakref
+from typing import cast, Optional
 
 import torch.nn as nn
 
@@ -7,13 +8,15 @@ class _State:
     pass
 
 
-_module_state_mapping: Dict[nn.Module, _State] = {}
+_module_state_mapping: weakref.WeakKeyDictionary[
+    nn.Module, weakref.ReferenceType[_State]
+] = weakref.WeakKeyDictionary()
 
 
 def _insert_module_state(module: nn.Module, state: _State) -> None:
     global _module_state_mapping
     assert module not in _module_state_mapping, f"Inserting {module} more than once."
-    _module_state_mapping[module] = state
+    _module_state_mapping[module] = weakref.ref(state)
 
 
 def _get_module_state(module: nn.Module) -> Optional[_State]:
@@ -32,6 +35,10 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
     else:
         # https://github.com/pytorch/pytorch/issues/107054
         if module in _module_state_mapping:
-            return _module_state_mapping[module]
+            state_ref = _module_state_mapping[module]
+            state = state_ref()
+            if state is None:
+                raise AssertionError("State has already been garbage collected")
+            return state
         else:
             return None
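
The following standalone sketch (not part of the patch) illustrates the pattern
the change adopts: a weakref.WeakKeyDictionary drops an entry once its key is
collected, and storing weakref.ref(state) as the value means the mapping does
not keep the state alive either. Module and State below are hypothetical
stand-ins for nn.Module and _State; only the standard library is used.

    import gc
    import weakref


    class Module:
        """Hypothetical stand-in for nn.Module."""


    class State:
        """Hypothetical stand-in for _State."""


    # Weak keys: the mapping does not keep modules alive.
    # weakref.ref values: the mapping does not keep states alive either.
    mapping = weakref.WeakKeyDictionary()  # Module -> weakref.ReferenceType[State]

    module, state = Module(), State()
    mapping[module] = weakref.ref(state)
    assert mapping[module]() is state  # call the weakref to dereference it

    del module, state
    gc.collect()  # mirrors the test, which collects FSDP's ref cycles after `del model`
    assert len(mapping) == 0  # the entry vanished along with the module

Because a dereferenced weakref can return None, callers must handle that case;
this is why the updated _get_module_state raises an AssertionError when the
state has already been garbage collected.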