Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
This PR creates a prototype of training convolutional neural networks based on DTensor. - Register required ops and implement operator dispatch - Add unit tests and example Basically, we shard the activations and replicate the model weights in this prototype. We can scale out to multiple GPUs and reduce the per-GPU memory footprint with this approach, and achieve weak scaling in terms of training performance (i.e., time per iteration). Reference log (on 2xA100 GPU): Unit Test ```bash root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py /opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.) return F.conv2d(input, weight, bias, self.stride, /opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.) return F.conv2d(input, weight, bias, self.stride, .. ---------------------------------------------------------------------- Ran 2 tests in 30.354s OK root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py [rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator()) [rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator()) [rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. 
(function operator()) [rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator()) ... ---------------------------------------------------------------------- Ran 3 tests in 16.343s OK ``` ConvNeXt Example ```bash root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py rank 3, 20 iterations, latency 584.80 ms, forward 102.84 ms, backward 297.80 ms, max reserved 16.34 GiB, max allocated 14.75 GiB rank 1, 20 iterations, latency 584.64 ms, forward 104.85 ms, backward 297.60 ms, max reserved 16.40 GiB, max allocated 14.74 GiB rank 0, 20 iterations, latency 584.48 ms, forward 104.64 ms, backward 297.90 ms, max reserved 16.39 GiB, max allocated 14.75 GiB rank 2, 20 iterations, latency 584.96 ms, forward 93.21 ms, backward 297.95 ms, max reserved 16.40 GiB, max allocated 14.74 GiB ``` @wanchaol @fduwjj FYI Pull Request resolved: https://github.com/pytorch/pytorch/pull/113123 Approved by: https://github.com/wanchaol
194 lines · 7.7 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
# Owner(s): ["oncall: distributed"]
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
DTensorTestBase,
|
|
with_comms,
|
|
)
|
|
|
|
|
|
# Number of random trials each test in this module runs.
ITER_TIME: int = 10
# Learning-rate constant; NOTE(review): not referenced by any test in this file.
LR: float = 1e-3
|
|
|
|
|
|
class DistOtherOpsTest(DTensorTestBase):
    """Numerical checks for slice, bernoulli and cross-entropy on replicated DTensors.

    Every test runs ``ITER_TIME`` random trials on a device mesh spanning all
    ranks with a ``Replicate`` placement. Slice and cross-entropy results are
    compared against the same op on a plain tensor. Bernoulli output is random,
    so instead each rank checks that its result matches a peer rank's result.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _assert_mse_close(self, actual, expected, name):
        """Assert that ``actual`` matches ``expected`` in absolute and relative MSE.

        Args:
            actual: local tensor produced by the DTensor computation.
            expected: reference tensor to compare against.
            name: label used in the failure messages (e.g. ``"output"``).

        Both the absolute and the relative mean-squared error must be <= 1e-6.
        The relative error divides by ``|expected| + 1e-8`` so exact zeros in the
        reference do not cause a division by zero.
        """
        diff_abs = actual - expected
        diff_rel = diff_abs / (torch.abs(expected) + 1e-8)
        mse_abs = torch.mean(diff_abs * diff_abs).item()
        mse_rel = torch.mean(diff_rel * diff_rel).item()

        self.assertLessEqual(
            mse_abs,
            1e-6,
            f"Too large absolute mse for {name}, expected less equal 1e-6, got {mse_abs}",
        )
        self.assertLessEqual(
            mse_rel,
            1e-6,
            f"Too large relative mse for {name}, expected less equal 1e-6, got {mse_rel}",
        )

    @with_comms
    def test_slice(self):
        """Slicing a replicated DTensor matches plain-tensor slicing, forward and backward."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 5) * 1e-3

        for inp_cpu, grad_output_cpu in zip(input_list, grad_output_list):
            inp = inp_cpu.to(self.device_type).requires_grad_()
            grad_output = grad_output_cpu.to(self.device_type)

            # slice with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = inp_dtensor[:, :5]
            output.backward(grad_output_dtensor)

            # slice with plain tensor
            output_gt = inp[:, :5]
            output_gt.backward(grad_output)

            self._assert_mse_close(output.to_local(), output_gt, "output")
            self._assert_mse_close(inp_dtensor.grad.to_local(), inp.grad, "gradient")

    @with_comms
    def test_bernoulli(self):
        """Bernoulli sampling on a replicated DTensor gives identical results on all ranks.

        The output and input gradient of each rank are exchanged with peers over
        point-to-point ops and compared against the local values.
        """
        rank = dist.get_rank()
        # Exchange in a ring: send to the next rank, receive from the previous
        # one. With the hard-coded world size of 2 both peers are the other rank.
        send_peer = (rank + 1) % self.world_size
        recv_peer = (rank - 1) % self.world_size

        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 10) * 1e-3

        for inp_cpu, grad_output_cpu in zip(input_list, grad_output_list):
            inp = inp_cpu.to(self.device_type).requires_grad_()
            grad_output = grad_output_cpu.to(self.device_type)

            # bernoulli with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = torch.bernoulli(inp_dtensor)
            output.backward(grad_output_dtensor)

            send_output_tensor = output.to_local()
            recv_output_tensor = torch.zeros_like(send_output_tensor)

            send_grad_tensor = inp_dtensor.grad.to_local()
            recv_grad_tensor = torch.zeros_like(send_grad_tensor)

            p2p_ops = [
                dist.P2POp(dist.isend, send_output_tensor, send_peer),
                dist.P2POp(dist.isend, send_grad_tensor, send_peer),
                dist.P2POp(dist.irecv, recv_output_tensor, recv_peer),
                dist.P2POp(dist.irecv, recv_grad_tensor, recv_peer),
            ]
            for req in dist.batch_isend_irecv(p2p_ops):
                req.wait()

            # Replicas must agree, so the local result must match the peer's.
            self._assert_mse_close(send_output_tensor, recv_output_tensor, "output")
            self._assert_mse_close(send_grad_tensor, recv_grad_tensor, "gradient")

    @with_comms
    def test_nll(self):
        """CrossEntropyLoss on replicated DTensors matches the plain-tensor loss and gradient."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        pred_list = torch.rand(ITER_TIME, 1024, 10)
        target_list = torch.randint(0, 10, (ITER_TIME, 1024), dtype=torch.long)

        criterion = torch.nn.CrossEntropyLoss()

        for pred_cpu, target_cpu in zip(pred_list, target_list):
            pred = pred_cpu.to(self.device_type).requires_grad_()
            target = target_cpu.to(self.device_type)

            # nll with dtensor
            pred_dtensor = distribute_tensor(pred, device_mesh, shard_spec)
            target_dtensor = distribute_tensor(target, device_mesh, shard_spec)
            loss = criterion(pred_dtensor, target_dtensor)
            loss.backward()

            # nll with plain tensor
            loss_gt = criterion(pred, target)
            loss_gt.backward()

            self._assert_mse_close(loss.to_local(), loss_gt, "loss")
            self._assert_mse_close(pred_dtensor.grad.to_local(), pred.grad, "gradient")
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to the
    # run_tests helper imported from torch.testing._internal.common_utils.
    run_tests()
|