Revert "Revert "[distributed] Handle object collectives and NCCL. (#79034)""

This reverts commit 279634f384.
This commit is contained in:
Nikita Shulga 2022-06-15 10:04:37 -07:00
parent d79f99c4b4
commit 09df27fe45
2 changed files with 154 additions and 50 deletions

View file

@ -0,0 +1,122 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
from functools import wraps, partial
import torch
import torch.distributed as dist
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
if TEST_WITH_DEV_DBG_ASAN:
print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
sys.exit(0)
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
def with_comms(func=None):
if func is None:
return partial(
with_comms,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.dist_init()
func(self)
self.destroy_comms()
return wrapper
class TestObjectCollectives(MultiProcessTestCase):
def setUp(self):
super(TestObjectCollectives, self).setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
@property
def device(self):
return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
else torch.device("cpu")
@property
def world_size(self):
return WORLD_SIZE
@property
def process_group(self):
return dist.group.WORLD
def destroy_comms(self):
# Wait for all ranks to reach here before starting shutdown.
dist.barrier()
dist.destroy_process_group()
def dist_init(self):
dist.init_process_group(
backend=BACKEND,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if BACKEND == "nccl":
torch.cuda.set_device(self.rank)
@with_comms()
def test_all_gather_object(self):
output = [None] * dist.get_world_size()
dist.all_gather_object(
object_list=output,
obj=self.rank)
for i, v in enumerate(output):
self.assertEqual(i, v, f"rank: {self.rank}")
@with_comms()
def test_gather_object(self):
output = [None] * dist.get_world_size() if self.rank == 0 else None
dist.gather_object(
obj=self.rank,
object_gather_list=output)
if self.rank == 0:
for i, v in enumerate(output):
self.assertEqual(i, v, f"rank: {self.rank}")
@with_comms()
def test_broadcast_object_list(self):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
# TODO test with broadcast_object_list's device argument
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
@with_comms()
def test_scatter_object_list(self):
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
output_list = [None]
dist.scatter_object_list(
scatter_object_output_list=output_list,
scatter_object_input_list=input_list)
self.assertEqual(self.rank, output_list[0])
if __name__ == "__main__":
run_tests()

View file

@ -217,6 +217,16 @@ _group_count = 0
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _get_pg_device(group: ProcessGroup):
"""
Returns the device to use with ``group``.
This is cuda for NCCL and CPU for everything else
"""
if _check_for_nccl_backend(group):
return torch.device("cuda", torch.cuda.current_device())
return torch.device("cpu")
def _store_based_barrier(rank, store, timeout):
"""
Barrier based on store which is used for synchronizing processes after
@ -1554,19 +1564,20 @@ def all_gather_multigpu(
work.wait()
def _object_to_tensor(obj):
def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will casue 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
byte_tensor = torch.ByteTensor(byte_storage).to(device)
local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
return byte_tensor, local_size
def _tensor_to_object(tensor, tensor_size):
tensor = tensor.cpu()
buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()
@ -1634,16 +1645,9 @@ def all_gather_object(object_list, obj, group=None):
_warn_not_in_group("all_gather_object")
return
input_tensor, local_size = _object_to_tensor(obj)
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
current_device = _get_pg_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@ -1735,14 +1739,9 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# Ensure object_gather_list is specified appopriately.
my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
input_tensor, local_size = _object_to_tensor(obj)
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
current_device = _get_pg_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@ -1780,8 +1779,6 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
@ -1843,35 +1840,20 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
_warn_not_in_group("broadcast_object_list")
return
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
# Current device selection.
# To preserve backwards compatibility, ``device`` is default to ``None``
# in which case we run current logic of device selection, i.e.
# ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
# case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
is_nccl_backend = _check_for_nccl_backend(group)
current_device = None
if device is not None:
if is_nccl_backend and device.type != "cuda":
raise ValueError("device type must be cuda for nccl backend")
current_device = device
current_device = device or _get_pg_device(group)
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
current_device = torch.device("cpu")
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
# not necessarily true.
current_device = torch.device("cuda", torch.cuda.current_device())
if is_nccl_backend:
object_sizes_tensor = object_sizes_tensor.to(current_device)
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device)
# Broadcast object sizes
broadcast(object_sizes_tensor, src=src, group=group)
@ -1883,10 +1865,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
object_tensor = torch.empty( # type: ignore[call-overload]
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=current_device
)
if is_nccl_backend:
object_tensor = object_tensor.to(current_device)
broadcast(object_tensor, src=src, group=group)
# Deserialize objects using their stored sizes.
offset = 0
@ -1966,9 +1947,10 @@ def scatter_object_list(
)
my_rank = get_rank(group)
pg_device = _get_pg_device(group)
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[_object_to_tensor(obj) for obj in scatter_object_input_list]
*[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
)
tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
@ -1979,11 +1961,11 @@ def scatter_object_list(
for tensor in tensor_list:
tensor.resize_(max_tensor_size)
else:
max_tensor_size = torch.tensor([0], dtype=torch.long)
max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
broadcast(max_tensor_size, src=src, group=group)
# Scatter actual serialized objects
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8)
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
scatter(
output_tensor,
scatter_list=None if my_rank != src else tensor_list,
@ -1992,7 +1974,7 @@ def scatter_object_list(
)
# Scatter per-object sizes to trim tensors when deserializing back to object
obj_tensor_size = torch.tensor([0], dtype=torch.long)
obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
scatter(
obj_tensor_size,
scatter_list=None if my_rank != src else tensor_sizes,