# Owner(s): ["oncall: distributed"]
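
# This test trains a plain local module and a copy managed by the composable
# ``replicate()`` API in lockstep, checking that their parameters stay in sync.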

import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.replicate import mark_root_module, replicate
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import run_tests


class Net(nn.Module):
    """Small feed-forward model used to compare local and replicated training."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=False)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


class ReplicateTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Remove the temporary file backing the FileStore; it may not exist
        # if the test failed before the store was created.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _compare_module(self, mod, replicate_mod):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

        local_batch_size = 1
        global_batch_size = self.world_size * local_batch_size
        input = torch.randn(global_batch_size, 2)
        target = torch.randn(global_batch_size, 4)

        def step_model(model, input, target):
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()
            # Manual SGD with lr=1 so both models apply the identical update.
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                param.grad = None

        for iteration in range(2):
            # The local module trains on the full global batch, while the
            # replicated module trains on this rank's shard of it.
            step_model(mod, input, target)
            step_model(
                replicate_mod,
                input[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
                target[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
            )

            self.assertEqual(
                len(list(mod.parameters())),
                len(list(replicate_mod.parameters())),
            )
            for i, j in zip(mod.parameters(), replicate_mod.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(iteration)
            input = input[torch.randperm(global_batch_size)]

    def test_replicate_single_module(self):
        model = Net()
        # Apply replicate() to the whole module at once.
        replicate_model = mark_root_module(replicate(deepcopy(model)))
        self._compare_module(model, replicate_model)

    def test_replicate_multi_module(self):
        model = Net()
        # Apply replicate() to each submodule separately; mark_root_module()
        # marks the root of the replicated module tree.
        replicate_model = mark_root_module(deepcopy(model))
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)
        self._compare_module(model, replicate_model)


if __name__ == "__main__":
    run_tests()