Add tests for replicate multiple modules (#89099)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099
Approved by: https://github.com/zhaojuanmao
This commit is contained in:
Charlie Yan 2022-11-17 19:05:44 +00:00 committed by PyTorch MergeBot
parent 767f6aa49f
commit a695fcf201

View file

@ -39,13 +39,7 @@ class ReplicateTest(MultiProcessTestCase):
except OSError:
pass
def _prepare_module(self, global_batch_size):
    """Build a fresh ``Net`` together with one random batch of data.

    Args:
        global_batch_size: number of rows in the generated input and target.

    Returns:
        A ``(model, input, target)`` tuple where ``input`` has shape
        ``(global_batch_size, 2)`` and ``target`` has shape
        ``(global_batch_size, 4)``.
    """
    net = Net()
    # Draw the features before the targets, after the model is built, so the
    # sequence of random draws stays the same as before.
    features = torch.randn(global_batch_size, 2)
    labels = torch.randn(global_batch_size, 4)
    return net, features, labels
def test_replicate(self):
def _compare_module(self, mod, replicate_mod):
dist.init_process_group(
backend="gloo",
rank=self.rank,
@ -55,8 +49,8 @@ class ReplicateTest(MultiProcessTestCase):
local_batch_size = 1
global_batch_size = self.world_size * local_batch_size
model, input, target = self._prepare_module(global_batch_size)
replicate_model = mark_root_module(replicate(deepcopy(model)))
input = torch.randn(global_batch_size, 2)
target = torch.randn(global_batch_size, 4)
def step_model(model, input, target):
model.train()
@ -69,9 +63,9 @@ class ReplicateTest(MultiProcessTestCase):
param.grad = None
for iteration in range(2):
step_model(model, input, target)
step_model(mod, input, target)
step_model(
replicate_model,
replicate_mod,
input[
self.rank
* local_batch_size : (self.rank + 1)
@ -85,16 +79,29 @@ class ReplicateTest(MultiProcessTestCase):
)
self.assertEqual(
len(list(model.parameters())),
len(list(replicate_model.parameters())),
len(list(mod.parameters())),
len(list(replicate_mod.parameters())),
)
for i, j in zip(model.parameters(), replicate_model.parameters()):
for i, j in zip(mod.parameters(), replicate_mod.parameters()):
self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
# Shuffle the input so that DDP input is different
torch.manual_seed(iteration)
input = input[torch.randperm(global_batch_size)]
def test_replicate_single_module(self):
    """Replicate the whole network as a single unit and compare it to a plain copy.

    A deep copy of the baseline ``Net`` is passed through ``replicate`` and
    then ``mark_root_module``; ``_compare_module`` receives the untouched
    baseline and the marked result.
    """
    baseline = Net()
    replicated = replicate(deepcopy(baseline))
    self._compare_module(baseline, mark_root_module(replicated))
def test_replicate_multi_module(self):
    """Replicate each submodule separately and compare against a plain copy.

    The deep copy is marked as the root first; ``replicate`` is then applied
    to ``fc1``, ``fc2`` and ``fc3`` one at a time, in that order, before the
    result is handed to ``_compare_module`` alongside the untouched baseline.
    """
    baseline = Net()
    replicated_root = mark_root_module(deepcopy(baseline))
    # Look each layer up right before replicating it, in declaration order.
    for layer_name in ("fc1", "fc2", "fc3"):
        replicate(getattr(replicated_root, layer_name))
    self._compare_module(baseline, replicated_root)
# Script entry point: call run_tests() when this file is executed directly.
if __name__ == "__main__":
run_tests()