From a695fcf20103bb08ae660788d128cd924e6ec05b Mon Sep 17 00:00:00 2001
From: Charlie Yan
Date: Thu, 17 Nov 2022 19:05:44 +0000
Subject: [PATCH] Add tests for replicate multiple modules (#89099)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099
Approved by: https://github.com/zhaojuanmao
---
 .../distributed/_composable/test_replicate.py | 35 +++++++++++--------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py
index 831ccc3376a..3e8bf44a1fd 100644
--- a/test/distributed/_composable/test_replicate.py
+++ b/test/distributed/_composable/test_replicate.py
@@ -39,13 +39,7 @@ class ReplicateTest(MultiProcessTestCase):
         except OSError:
             pass
 
-    def _prepare_module(self, global_batch_size):
-        model = Net()
-        input = torch.randn(global_batch_size, 2)
-        target = torch.randn(global_batch_size, 4)
-        return model, input, target
-
-    def test_replicate(self):
+    def _compare_module(self, mod, replicate_mod):
         dist.init_process_group(
             backend="gloo",
             rank=self.rank,
@@ -55,8 +49,8 @@ class ReplicateTest(MultiProcessTestCase):
         local_batch_size = 1
         global_batch_size = self.world_size * local_batch_size
 
-        model, input, target = self._prepare_module(global_batch_size)
-        replicate_model = mark_root_module(replicate(deepcopy(model)))
+        input = torch.randn(global_batch_size, 2)
+        target = torch.randn(global_batch_size, 4)
 
         def step_model(model, input, target):
             model.train()
@@ -69,9 +63,9 @@ class ReplicateTest(MultiProcessTestCase):
                 param.grad = None
 
         for iteration in range(2):
-            step_model(model, input, target)
+            step_model(mod, input, target)
             step_model(
-                replicate_model,
+                replicate_mod,
                 input[
                     self.rank * local_batch_size : (self.rank + 1)
@@ -85,16 +79,29 @@ class ReplicateTest(MultiProcessTestCase):
             )
 
             self.assertEqual(
-                len(list(model.parameters())),
-                len(list(replicate_model.parameters())),
+                len(list(mod.parameters())),
+                len(list(replicate_mod.parameters())),
             )
-            for i, j in zip(model.parameters(), replicate_model.parameters()):
+            for i, j in zip(mod.parameters(), replicate_mod.parameters()):
                 self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
 
             # Shuffle the input so that DDP input is different
             torch.manual_seed(iteration)
             input = input[torch.randperm(global_batch_size)]
 
+    def test_replicate_single_module(self):
+        model = Net()
+        replicate_model = mark_root_module(replicate(deepcopy(model)))
+        self._compare_module(model, replicate_model)
+
+    def test_replicate_multi_module(self):
+        model = Net()
+        replicate_model = mark_root_module(deepcopy(model))
+        replicate(replicate_model.fc1)
+        replicate(replicate_model.fc2)
+        replicate(replicate_model.fc3)
+        self._compare_module(model, replicate_model)
+
 
 if __name__ == "__main__":
     run_tests()