pytorch/test/distributed/_tensor/test_convolution_ops.py
KingsleyLiu-NV cd2798943d [dtensor] support convolution ops (#113123)
This PR creates a prototype for training convolutional neural networks with DTensor.

- Register required ops and implement operator dispatch
- Add unit tests and example

Basically, we shard the activations and replicate the model weights in this prototype. We can scale out to multiple GPUs and reduce the per-GPU memory footprint with this approach, and achieve weak scaling in terms of training performance (i.e., time per iteration).

Reference log (on 2xA100 GPU):

Unit Test
```bash
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
..
----------------------------------------------------------------------
Ran 2 tests in 30.354s

OK
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
...
----------------------------------------------------------------------
Ran 3 tests in 16.343s

OK
```
ConvNeXt Example
```bash
root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py
rank 3, 20 iterations, latency     584.80 ms, forward     102.84 ms, backward     297.80 ms, max reserved    16.34 GiB, max allocated    14.75 GiB
rank 1, 20 iterations, latency     584.64 ms, forward     104.85 ms, backward     297.60 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
rank 0, 20 iterations, latency     584.48 ms, forward     104.64 ms, backward     297.90 ms, max reserved    16.39 GiB, max allocated    14.75 GiB
rank 2, 20 iterations, latency     584.96 ms, forward      93.21 ms, backward     297.95 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
```

@wanchaol @fduwjj FYI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113123
Approved by: https://github.com/wanchaol
2023-11-20 21:01:28 +00:00

181 lines
6.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
# Number of training iterations each test runs; also the leading dimension of
# the random input and grad-output batches the tests generate.
ITER_TIME = 10
# Learning rate given to torch.optim.SGD for both the DTensor model and the
# plain-tensor reference model.
LR = 0.001
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """Re-register every parameter of ``module`` as a replicated DTensor.

    The tests in this file pass this function to ``distribute_module`` as its
    partition function.

    Args:
        name: Name of the module being partitioned. Not used by this function.
        module: Module whose parameters are replaced in place.
        device_mesh: Device mesh across which each parameter is replicated.
    """
    # Snapshot the parameters before touching the module: the loop body
    # re-registers parameters on ``module``, so it should not be driven by the
    # live ``named_parameters()`` generator while the module is being mutated.
    # A distinct loop name also avoids shadowing the ``name`` argument.
    for param_name, param in list(module.named_parameters()):
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Replicate()])
        )
        # Dotted (nested) parameter names are flattened with "_" before
        # registration on ``module``.
        flat_name = "_".join(param_name.split("."))
        module.register_parameter(flat_name, dist_param)
class DistConvolutionOpsTest(DTensorTestBase):
    """Compare convolution training on sharded DTensor inputs with plain tensors.

    Each test trains the same ``nn.Conv2d`` twice with SGD on identical data:
    once as a distributed module fed with DTensor inputs sharded along tensor
    dimension 3, and once as an ordinary module fed with plain tensors. The
    resulting weights and biases must agree within a small mean squared error.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _train_with_dtensor(
        self,
        model: nn.Module,
        device_mesh: DeviceMesh,
        shard_spec: list,
        input_list: torch.Tensor,
        grad_output_list: torch.Tensor,
    ) -> None:
        """Run ``ITER_TIME`` SGD steps on ``model`` using DTensor inputs/grads."""
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

    def _train_with_plain_tensor(
        self,
        model_gt: nn.Module,
        input_list: torch.Tensor,
        grad_output_list: torch.Tensor,
    ) -> None:
        """Run ``ITER_TIME`` SGD steps on ``model_gt`` using plain tensors."""
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

    def _assert_params_match(self, model: nn.Module, model_gt: nn.Module) -> None:
        """Assert weight/bias of ``model`` match ``model_gt`` in abs and rel MSE."""
        mse = {}
        for label, dist_param, ref_param in (
            ("weight", model.weight, model_gt.weight),
            ("bias", model.bias, model_gt.bias),
        ):
            diff_abs = dist_param.to_local() - ref_param
            # The small epsilon keeps the relative error finite where the
            # reference value is zero.
            diff_rel = diff_abs / (torch.abs(ref_param) + 1e-8)
            mse["absolute", label] = torch.mean(diff_abs * diff_abs).item()
            mse["relative", label] = torch.mean(diff_rel * diff_rel).item()

        # Check order: weight abs, bias abs, weight rel, bias rel.
        for kind in ("absolute", "relative"):
            for label in ("weight", "bias"):
                value = mse[kind, label]
                self.assertLessEqual(
                    value,
                    1e-6,
                    f"Too large {kind} mse for {label} tensor, expected less equal 1e-6, got {value}",
                )

    def _assert_conv_training_matches(
        self,
        model: nn.Module,
        device_mesh: DeviceMesh,
        shard_spec: list,
        input_list: torch.Tensor,
        grad_output_list: torch.Tensor,
    ) -> None:
        """Train ``model`` with DTensor and a copy without, then compare them.

        Args:
            model: Convolution module already placed on ``self.device_type``.
            device_mesh: Mesh used to distribute the module and the tensors.
            shard_spec: Placements used for the inputs and output gradients.
            input_list: Per-iteration inputs, indexed along dimension 0.
            grad_output_list: Per-iteration output gradients, indexed along
                dimension 0.
        """
        # Deterministic initial parameters so both runs start identically.
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        # The reference copy must be taken before the module is distributed.
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        self._train_with_dtensor(
            model, device_mesh, shard_spec, input_list, grad_output_list
        )

        # training with plain tensor
        self._train_with_plain_tensor(model_gt, input_list, grad_output_list)

        self._assert_params_match(model, model_gt)

    @with_comms
    def test_downsampling_convolution(self):
        """Strided 4x4 convolution (3 -> 256 channels) with sharded inputs."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        self._assert_conv_training_matches(
            model, device_mesh, shard_spec, input_list, grad_output_list
        )

    @with_comms
    def test_depthwise_convolution(self):
        """Depthwise 7x7 convolution (groups=256, padding=3) with sharded inputs."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 256, 128, 256)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        self._assert_conv_training_matches(
            model, device_mesh, shard_spec, input_list, grad_output_list
        )
if __name__ == "__main__":
    # Invoked directly as a script: hand control to the PyTorch test runner.
    run_tests()