mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add tests for replicate multiple modules (#89099)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099 Approved by: https://github.com/zhaojuanmao
This commit is contained in:
parent
767f6aa49f
commit
a695fcf201
1 changed files with 21 additions and 14 deletions
|
|
@ -39,13 +39,7 @@ class ReplicateTest(MultiProcessTestCase):
|
|||
except OSError:
|
||||
pass
|
||||
|
||||
def _prepare_module(self, global_batch_size):
    """Build a fresh ``Net`` plus one random batch sized for all ranks.

    Args:
        global_batch_size: total batch size across the whole world; each
            rank later slices out its own ``local_batch_size`` share.

    Returns:
        Tuple ``(model, input, target)`` — input is ``(global_batch_size, 2)``
        and target is ``(global_batch_size, 4)``; presumably these match
        Net's in/out feature dims (confirm against the ``Net`` definition).
    """
    model = Net()
    # `inp` rather than `input`: avoid shadowing the builtin. Callers
    # unpack the returned tuple positionally, so the rename is invisible.
    inp = torch.randn(global_batch_size, 2)
    target = torch.randn(global_batch_size, 4)
    return model, inp, target
|
||||
|
||||
def test_replicate(self):
|
||||
def _compare_module(self, mod, replicate_mod):
|
||||
dist.init_process_group(
|
||||
backend="gloo",
|
||||
rank=self.rank,
|
||||
|
|
@ -55,8 +49,8 @@ class ReplicateTest(MultiProcessTestCase):
|
|||
|
||||
local_batch_size = 1
|
||||
global_batch_size = self.world_size * local_batch_size
|
||||
model, input, target = self._prepare_module(global_batch_size)
|
||||
replicate_model = mark_root_module(replicate(deepcopy(model)))
|
||||
input = torch.randn(global_batch_size, 2)
|
||||
target = torch.randn(global_batch_size, 4)
|
||||
|
||||
def step_model(model, input, target):
|
||||
model.train()
|
||||
|
|
@ -69,9 +63,9 @@ class ReplicateTest(MultiProcessTestCase):
|
|||
param.grad = None
|
||||
|
||||
for iteration in range(2):
|
||||
step_model(model, input, target)
|
||||
step_model(mod, input, target)
|
||||
step_model(
|
||||
replicate_model,
|
||||
replicate_mod,
|
||||
input[
|
||||
self.rank
|
||||
* local_batch_size : (self.rank + 1)
|
||||
|
|
@ -85,16 +79,29 @@ class ReplicateTest(MultiProcessTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(list(model.parameters())),
|
||||
len(list(replicate_model.parameters())),
|
||||
len(list(mod.parameters())),
|
||||
len(list(replicate_mod.parameters())),
|
||||
)
|
||||
for i, j in zip(model.parameters(), replicate_model.parameters()):
|
||||
for i, j in zip(mod.parameters(), replicate_mod.parameters()):
|
||||
self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
|
||||
|
||||
# Shuffle the input so that DDP input is different
|
||||
torch.manual_seed(iteration)
|
||||
input = input[torch.randperm(global_batch_size)]
|
||||
|
||||
def test_replicate_single_module(self):
    """Replicate the whole module at once and check parity with a local copy."""
    local_model = Net()
    replicated = mark_root_module(replicate(deepcopy(local_model)))
    self._compare_module(local_model, replicated)
|
||||
|
||||
def test_replicate_multi_module(self):
    """Replicate each submodule separately and check parity with a local copy."""
    local_model = Net()
    replicated = mark_root_module(deepcopy(local_model))
    # Apply replicate() per layer instead of on the root module —
    # same three calls as before, just expressed as a loop.
    for layer in (replicated.fc1, replicated.fc2, replicated.fc3):
        replicate(layer)
    self._compare_module(local_model, replicated)
|
||||
|
||||
|
||||
# Standard PyTorch test-suite entry point.
if __name__ == "__main__":
    run_tests()
|
||||
|
|
|
|||
Loading…
Reference in a new issue