Revert "Revert "[distributed] Handle object collectives and NCCL. (#79034)""
This reverts commit 279634f384.
parent d79f99c4b4
commit 09df27fe45
2 changed files with 154 additions and 50 deletions
test/distributed/test_c10d_object_collectives.py (new file, 122 additions)
@@ -0,0 +1,122 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import wraps, partial

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()
    return wrapper

class TestObjectCollectives(MultiProcessTestCase):
    def setUp(self):
        super(TestObjectCollectives, self).setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
            else torch.device("cpu")

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    @with_comms()
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(
            object_list=output,
            obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(
            obj=self.rank,
            object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")


    @with_comms()
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

    @with_comms()
    def test_scatter_object_list(self):
        input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
        output_list = [None]
        dist.scatter_object_list(
            scatter_object_output_list=output_list,
            scatter_object_input_list=input_list)

        self.assertEqual(self.rank, output_list[0])


if __name__ == "__main__":
    run_tests()
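The remaining hunks apply the matching changes inside torch.distributed itself. For orientation, the pattern the tests above exercise looks roughly like the standalone sketch below. It is illustrative only: the env-based initialization under an external launcher such as torchrun, and the dictionary payload, are assumptions rather than part of this commit.

# Illustrative sketch, not part of the commit. Assumes a launcher such as
# torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE in the environment.
import torch
import torch.distributed as dist

def main():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    if backend == "nccl":
        # NCCL object collectives stage their byte tensors on the caller's
        # current CUDA device, so each rank picks its device first.
        torch.cuda.set_device(rank)
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, {"rank": rank})  # any picklable object works
    dist.destroy_process_group()

if __name__ == "__main__":
    main()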
torch/distributed/distributed_c10d.py
@@ -217,6 +217,16 @@ _group_count = 0
 STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"


+def _get_pg_device(group: ProcessGroup):
+    """
+    Returns the device to use with ``group``.
+    This is cuda for NCCL and CPU for everything else
+    """
+    if _check_for_nccl_backend(group):
+        return torch.device("cuda", torch.cuda.current_device())
+    return torch.device("cpu")
+
+
 def _store_based_barrier(rank, store, timeout):
     """
     Barrier based on store which is used for synchronizing processes after
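The selection rule can be restated with public API only, for readers who want the same behavior outside of distributed_c10d. This is a simplified approximation: the in-tree helper goes through _check_for_nccl_backend, which also recognizes NCCL-backed subgroups.

# Simplified approximation of _get_pg_device using only public API; the
# in-tree helper's NCCL detection (see the hunk above) is more thorough.
import torch
import torch.distributed as dist

def pg_device(group=None):
    if dist.get_backend(group) == dist.Backend.NCCL:
        # NCCL collectives expect tensors on the caller's current CUDA device.
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")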
@@ -1554,19 +1564,20 @@ all_gather_multigpu(
         work.wait()


-def _object_to_tensor(obj):
+def _object_to_tensor(obj, device):
     f = io.BytesIO()
     _pickler(f).dump(obj)
     byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
     # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
     # Otherwise, it will cause 100X slowdown.
     # See: https://github.com/pytorch/pytorch/issues/65696
-    byte_tensor = torch.ByteTensor(byte_storage)
-    local_size = torch.LongTensor([byte_tensor.numel()])
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
     return byte_tensor, local_size


 def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
     buf = tensor.numpy().tobytes()[:tensor_size]
     return _unpickler(io.BytesIO(buf)).load()
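A round-trip sketch against the signatures introduced above. Both helpers are private, so this is illustrative rather than supported API, and the payload is made up.

# Illustrative round trip through the private helpers as defined in this hunk.
import torch
from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
payload = {"step": 7, "loss": 0.25}                     # hypothetical object
byte_tensor, size = _object_to_tensor(payload, device)  # pickled bytes staged on `device`
restored = _tensor_to_object(byte_tensor, size.item())  # the .cpu() move now happens inside
assert restored == payload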
@@ -1634,16 +1645,9 @@ def all_gather_object(object_list, obj, group=None):
         _warn_not_in_group("all_gather_object")
         return

-    input_tensor, local_size = _object_to_tensor(obj)
-    current_device = torch.device("cpu")
-    is_nccl_backend = _check_for_nccl_backend(group)
-    if is_nccl_backend:
-        # See note about using torch.cuda.current_device() here in docstring.
-        # We cannot simply use my_rank since rank == device is not necessarily
-        # true.
-        current_device = torch.device("cuda", torch.cuda.current_device())
-        input_tensor = input_tensor.to(current_device)
-        local_size = local_size.to(current_device)
+    current_device = _get_pg_device(group)
+    input_tensor, local_size = _object_to_tensor(obj, current_device)
+
     # Gather all local sizes. This is so that we can find the max size, and index
     # until the correct size when deserializing the tensors.
     group_size = get_world_size(group=group)
@@ -1735,14 +1739,9 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
     # Ensure object_gather_list is specified appropriately.
     my_rank = get_rank()
     _validate_output_list_for_rank(my_rank, dst, object_gather_list)
-    input_tensor, local_size = _object_to_tensor(obj)
-    current_device = torch.device("cpu")
-    is_nccl_backend = _check_for_nccl_backend(group)
+    current_device = _get_pg_device(group)
+    input_tensor, local_size = _object_to_tensor(obj, current_device)

-    if is_nccl_backend:
-        current_device = torch.device("cuda", torch.cuda.current_device())
-        input_tensor = input_tensor.to(current_device)
-        local_size = local_size.to(current_device)
     # Gather all local sizes. This is so that we can find the max size, and index
     # until the correct size when deserializing the tensors.
     group_size = get_world_size(group=group)
@@ -1780,8 +1779,6 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
         return
     for i, tensor in enumerate(output_tensors):
         tensor = tensor.type(torch.uint8)
-        if tensor.device != torch.device("cpu"):
-            tensor = tensor.cpu()
         tensor_size = object_size_list[i]
         object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
@@ -1843,35 +1840,20 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         _warn_not_in_group("broadcast_object_list")
         return

-    my_rank = get_rank()
-    # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
-        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
-        object_sizes_tensor = torch.cat(size_list)
-    else:
-        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
-
     # Current device selection.
     # To preserve backwards compatibility, ``device`` defaults to ``None``
     # in which case we run current logic of device selection, i.e.
     # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
-    is_nccl_backend = _check_for_nccl_backend(group)
-    current_device = None
-    if device is not None:
-        if is_nccl_backend and device.type != "cuda":
-            raise ValueError("device type must be cuda for nccl backend")
-        current_device = device
+    current_device = device or _get_pg_device(group)
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
     else:
-        current_device = torch.device("cpu")
-        if is_nccl_backend:
-            # See note about using torch.cuda.current_device() here in
-            # docstring. We cannot simply use my_rank since rank == device is
-            # not necessarily true.
-            current_device = torch.device("cuda", torch.cuda.current_device())
-    if is_nccl_backend:
-        object_sizes_tensor = object_sizes_tensor.to(current_device)
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device)

     # Broadcast object sizes
     broadcast(object_sizes_tensor, src=src, group=group)
@@ -1883,10 +1865,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     object_tensor = torch.empty(  # type: ignore[call-overload]
         torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
         dtype=torch.uint8,
+        device=current_device
     )

-    if is_nccl_backend:
-        object_tensor = object_tensor.to(current_device)
     broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
     offset = 0
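A usage sketch for the device argument discussed in the comment block above. It assumes an initialized NCCL process group with torch.cuda.set_device(rank) already called on every rank; the payload values are invented.

# Usage sketch only; assumes dist.init_process_group(backend="nccl") and
# torch.cuda.set_device(rank) have already run on every rank.
import torch
import torch.distributed as dist

objects = ["config", {"lr": 0.1}, 42] if dist.get_rank() == 0 else [None, None, None]
dist.broadcast_object_list(objects, src=0,
                           device=torch.device("cuda", torch.cuda.current_device()))
# With device=None the tensors are staged on _get_pg_device(group) instead.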
@@ -1966,9 +1947,10 @@ def scatter_object_list(
         )

     my_rank = get_rank(group)
+    pg_device = _get_pg_device(group)
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
-            *[_object_to_tensor(obj) for obj in scatter_object_input_list]
+            *[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
         )
         tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
@@ -1979,11 +1961,11 @@ def scatter_object_list(
         for tensor in tensor_list:
             tensor.resize_(max_tensor_size)
     else:
-        max_tensor_size = torch.tensor([0], dtype=torch.long)
+        max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
     broadcast(max_tensor_size, src=src, group=group)

     # Scatter actual serialized objects
-    output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8)
+    output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
     scatter(
         output_tensor,
         scatter_list=None if my_rank != src else tensor_list,
@@ -1992,7 +1974,7 @@ def scatter_object_list(
     )

     # Scatter per-object sizes to trim tensors when deserializing back to object
-    obj_tensor_size = torch.tensor([0], dtype=torch.long)
+    obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
     scatter(
         obj_tensor_size,
         scatter_list=None if my_rank != src else tensor_sizes,