This PR creates a prototype of training convolutional neural networks based on DTensor.

- Register required ops and implement operator dispatch
- Add unit tests and an example

Basically, we shard the activations and replicate the model weights in this prototype. With this approach we can scale out to multiple GPUs, reduce the per-GPU memory footprint, and achieve weak scaling in terms of training performance (i.e., time per iteration).

Reference log (on 2xA100 GPU):

Unit Test

```bash
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
..
----------------------------------------------------------------------
Ran 2 tests in 30.354s

OK
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
...
----------------------------------------------------------------------
Ran 3 tests in 16.343s

OK
```

ConvNeXt Example

```bash
root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py
rank 3, 20 iterations, latency 584.80 ms, forward 102.84 ms, backward 297.80 ms, max reserved 16.34 GiB, max allocated 14.75 GiB
rank 1, 20 iterations, latency 584.64 ms, forward 104.85 ms, backward 297.60 ms, max reserved 16.40 GiB, max allocated 14.74 GiB
rank 0, 20 iterations, latency 584.48 ms, forward 104.64 ms, backward 297.90 ms, max reserved 16.39 GiB, max allocated 14.75 GiB
rank 2, 20 iterations, latency 584.96 ms, forward 93.21 ms, backward 297.95 ms, max reserved 16.40 GiB, max allocated 14.74 GiB
```

@wanchaol @fduwjj FYI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113123
Approved by: https://github.com/wanchaol
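To make the pattern concrete, here is a minimal sketch of the approach (replicated weights, width-sharded activations) written against the public DTensor API. The `replicate_params` helper, the toy shapes, and the CPU/gloo setup are illustrative assumptions, not code from this PR; it mirrors the test file below and is meant to be launched with `torchrun --nproc_per_node=2`.

```python
# Minimal sketch: replicate conv weights across the mesh, shard activations
# on dim 3 (width). Assumes a 2-rank CPU/gloo setup for illustration; the
# PR's own tests run on GPUs with NCCL.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    Replicate,
    Shard,
    distribute_module,
    distribute_tensor,
)


def replicate_params(name, module, device_mesh):
    # Re-register every parameter as a DTensor replicated across the mesh.
    for pname, param in list(module.named_parameters()):
        module.register_parameter(
            pname.replace(".", "_"),
            nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()])),
        )


dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

conv = distribute_module(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), mesh, replicate_params
)
x = distribute_tensor(torch.rand(1, 3, 32, 64), mesh, [Shard(3)])  # shard on width
y = conv(x)  # forward pass on the sharded input; boundary data is exchanged between ranks
print(f"rank {dist.get_rank()}: local output shape {tuple(y.to_local().shape)}")
```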
181 lines · 6.9 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

ITER_TIME = 10
LR = 0.001


def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    # Replicate every parameter of the module across the device mesh.
    for name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        name = "_".join(name.split("."))
        module.register_parameter(name, dist_param)


class DistConvolutionOpsTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    @with_comms
    def test_downsampling_convolution(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]  # shard activations along the width dimension

        input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

        # training with plain tensor
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

        # compare DTensor-trained parameters against the single-device ground truth
        weight_diff_abs = model.weight.to_local() - model_gt.weight
        bias_diff_abs = model.bias.to_local() - model_gt.bias
        weight_diff_rel = weight_diff_abs / (torch.abs(model_gt.weight) + 1e-8)
        bias_diff_rel = bias_diff_abs / (torch.abs(model_gt.bias) + 1e-8)
        weight_mse_abs = torch.mean(weight_diff_abs * weight_diff_abs).item()
        bias_mse_abs = torch.mean(bias_diff_abs * bias_diff_abs).item()
        weight_mse_rel = torch.mean(weight_diff_rel * weight_diff_rel).item()
        bias_mse_rel = torch.mean(bias_diff_rel * bias_diff_rel).item()
        self.assertTrue(
            weight_mse_abs <= 1e-6,
            f"Too large absolute mse for weight tensor, expected <= 1e-6, got {weight_mse_abs}",
        )
        self.assertTrue(
            bias_mse_abs <= 1e-6,
            f"Too large absolute mse for bias tensor, expected <= 1e-6, got {bias_mse_abs}",
        )
        self.assertTrue(
            weight_mse_rel <= 1e-6,
            f"Too large relative mse for weight tensor, expected <= 1e-6, got {weight_mse_rel}",
        )
        self.assertTrue(
            bias_mse_rel <= 1e-6,
            f"Too large relative mse for bias tensor, expected <= 1e-6, got {bias_mse_rel}",
        )

    @with_comms
    def test_depthwise_convolution(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]  # shard activations along the width dimension

        input_list = torch.rand(ITER_TIME, 7, 256, 128, 256)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

        # training with plain tensor
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

        # compare DTensor-trained parameters against the single-device ground truth
        weight_diff_abs = model.weight.to_local() - model_gt.weight
        bias_diff_abs = model.bias.to_local() - model_gt.bias
        weight_diff_rel = weight_diff_abs / (torch.abs(model_gt.weight) + 1e-8)
        bias_diff_rel = bias_diff_abs / (torch.abs(model_gt.bias) + 1e-8)
        weight_mse_abs = torch.mean(weight_diff_abs * weight_diff_abs).item()
        bias_mse_abs = torch.mean(bias_diff_abs * bias_diff_abs).item()
        weight_mse_rel = torch.mean(weight_diff_rel * weight_diff_rel).item()
        bias_mse_rel = torch.mean(bias_diff_rel * bias_diff_rel).item()
        self.assertTrue(
            weight_mse_abs <= 1e-6,
            f"Too large absolute mse for weight tensor, expected <= 1e-6, got {weight_mse_abs}",
        )
        self.assertTrue(
            bias_mse_abs <= 1e-6,
            f"Too large absolute mse for bias tensor, expected <= 1e-6, got {bias_mse_abs}",
        )
        self.assertTrue(
            weight_mse_rel <= 1e-6,
            f"Too large relative mse for weight tensor, expected <= 1e-6, got {weight_mse_rel}",
        )
        self.assertTrue(
            bias_mse_rel <= 1e-6,
            f"Too large relative mse for bias tensor, expected <= 1e-6, got {bias_mse_rel}",
        )


if __name__ == "__main__":
    run_tests()