pytorch/test/distributed/checkpoint/test_fsdp_model_state.py
hippocookie 5f57be7571 [Distributed] Change function call in test to non-deprecated to eliminate warning (#134938)
Migrate the function calls in this test to their non-deprecated replacements, eliminating the warning message shown below and reducing the chance of test failures once the deprecated methods are removed.

- Replace the deprecated `save_state_dict` with `save`
- Replace the deprecated `load_state_dict` with `load`

Warning message:
```bash
pytorch/test/distributed/checkpoint/test_fsdp_model_state.py:37: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134938
Approved by: https://github.com/wz337, https://github.com/fegin
2024-09-06 03:25:09 +00:00

99 lines
3.3 KiB
Python

# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class FsdpModelStateCheckpoint(DTensorTestBase):
    """Round-trip FSDP sharded model state through ``torch.distributed.checkpoint``.

    Each test saves the ``SHARDED_STATE_DICT`` of one FSDP-wrapped ``Linear``,
    loads it into a second FSDP-wrapped ``Linear`` and checks that the full
    (unsharded) parameters of the two models match afterwards.
    """

    @property
    def backend(self):
        # Gloo for CPU tensors, NCCL for CUDA tensors.
        return "cpu:gloo,cuda:nccl"

    def _test_fsdp_model_state(self, process_group) -> None:
        """Save ``source``'s sharded state, load it into ``target`` and compare.

        Args:
            process_group: Process group handed to the FSDP wrapper of the
                model that loads the checkpoint. ``None`` leaves FSDP on its
                default group, the same one the saving model uses; a subgroup
                forces the checkpoint to be resharded on load.
        """
        checkpoint_dir = self.temp_dir
        param_names = ("weight", "bias")

        # The Linear is built on the meta device and FSDP is relied on to
        # materialize it. NOTE(review): the forward/backward pass presumably
        # exists to trigger FSDP's lazy initialization before state_dict();
        # the input device assumes rank == local CUDA device index (one node).
        source = FSDP(torch.nn.Linear(8, 8, device="meta"))
        source(torch.rand(8, 8, device=dist.get_rank())).sum().backward()

        with FSDP.state_dict_type(source, StateDictType.SHARDED_STATE_DICT):
            to_save = {"model": source.state_dict()}
            dist_cp.save(
                state_dict=to_save,
                storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
                planner=DefaultSavePlanner(),
            )

        target = FSDP(
            torch.nn.Linear(8, 8, device="meta"), process_group=process_group
        )

        # Sanity check: before loading, the two models must not already agree,
        # otherwise the final equality check would prove nothing.
        with FSDP.summon_full_params(source), FSDP.summon_full_params(target):
            for name in param_names:
                self.assertNotEqual(getattr(source, name), getattr(target, name))

        # dist_cp.load fills the tensors of ``to_load`` in place; the result
        # is then pushed back into ``target`` while still in sharded mode.
        with FSDP.state_dict_type(target, StateDictType.SHARDED_STATE_DICT):
            to_load = {"model": target.state_dict()}
            dist_cp.load(
                state_dict=to_load,
                storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
                planner=DefaultLoadPlanner(),
            )
            target.load_state_dict(to_load["model"])

        # After loading, the full parameters must be identical.
        with FSDP.summon_full_params(source), FSDP.summon_full_params(target):
            for name in param_names:
                self.assertEqual(getattr(source, name), getattr(target, name))

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_fsdp_model_state_no_resharding(self):
        # Both models shard over FSDP's default process group.
        self._test_fsdp_model_state(process_group=None)

    def _create_new_dist_group(self):
        """Split the world into even and odd ranks; return this rank's group.

        Both ``dist.new_group`` calls run on every rank, in the same order,
        because every process must enter group creation even for groups it
        does not belong to.
        """
        world_size = dist.get_world_size()
        even_ranks = list(range(0, world_size, 2))
        odd_ranks = list(range(1, world_size, 2))
        # create new fsdp group for resharding
        even_group = dist.new_group(ranks=even_ranks)
        odd_group = dist.new_group(ranks=odd_ranks)
        return even_group if dist.get_rank() % 2 == 0 else odd_group

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_fsdp_model_state_with_resharding(self):
        # The loading model shards over half of the world, so a checkpoint
        # written with whole-world sharding must be resharded on load.
        half_world_group = self._create_new_dist_group()
        self._test_fsdp_model_state(process_group=half_world_group)
if __name__ == "__main__":
    # When executed directly, hand control to PyTorch's shared test runner.
    run_tests()