mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[c10d] Introduce ProcessGroupWrapper (#58224)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58224 Adds C++ implementation of ProcessGroupWrapper. It wraps an underlying ProcessGroup and does debug checks before dispatching the collective to the underlying pg. The design mostly follows https://github.com/pytorch/pytorch/issues/22071. Concretely, on each collective, we: 1. Verify op type consistency. This can help catch mismatched ops in the user application (i.e. allreduce on one rank and allgather on another) 2. Verify tensor shapes. This can help catch bugs where the tensor inputs are malformed, whereas normally in NCCL this would just lead to a hang. The shapes verification for allgather/allreduce_coalesced is omitted because they actually accept different shape tensors and don't error out. This is done through an abstraction called `CollectiveFingerPrint` which uses a helper process group to do the above verification. Concretely, we gather the data we need for each of the above checks into tensors, and allgather them, and verify their equivalence. Once all of this passes we simply dispatch the collective to the underlying pg. Added `ProcessGroupWrapperTest` in python to comprehensively test these changes. ghstack-source-id: 129735687 Test Plan: ci Reviewed By: zhaojuanmao Differential Revision: D28023981 fbshipit-source-id: 1defc203c5efa72ca0476ade0d1d8d05aacd4e64
This commit is contained in:
parent
c00eefb6c7
commit
cf395c0718
14 changed files with 767 additions and 12 deletions
|
|
@ -597,6 +597,124 @@ class SparseGradientModule(nn.Module):
|
|||
return F.softmax(self.embedding(x), dim=1)
|
||||
|
||||
|
||||
class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super(AbstractProcessGroupWrapperTest, self).setUp()
|
||||
# For Windows platform, Python does not support fork, change it to spawn here.
|
||||
if sys.platform == "win32":
|
||||
self._spawn_processes()
|
||||
else:
|
||||
self._fork_processes()
|
||||
|
||||
def _test_collective_hang(self, wrapper_pg, use_cuda=False):
|
||||
# All ranks besides 1 call allreduce and wrapper_pg should detect a hang
|
||||
# and report an issue with rank 1.
|
||||
faulty_rank = 1
|
||||
if self.rank != faulty_rank:
|
||||
tensor = torch.randn(20, 10)
|
||||
if use_cuda:
|
||||
tensor = tensor.to(self.rank)
|
||||
|
||||
if self.rank == 0:
|
||||
# Rank 0 reports faulty ranks
|
||||
err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
|
||||
else:
|
||||
err = "Please check rank 0 logs for faulty rank"
|
||||
with self.assertRaisesRegex(RuntimeError, err):
|
||||
wrapper_pg.allreduce([tensor])
|
||||
|
||||
def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
|
||||
tensor = torch.randn(20, 10)
|
||||
if use_cuda:
|
||||
tensor = tensor.to(self.rank)
|
||||
works = []
|
||||
# Run a few successful collectives
|
||||
for _ in range(10):
|
||||
work = wrapper_pg.allreduce([tensor])
|
||||
works.append(work)
|
||||
|
||||
for w in works:
|
||||
w.wait()
|
||||
|
||||
# Simulate mismatch: allreduce vs reduce.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mismatch between collective operation types"
|
||||
):
|
||||
if self.rank == 0:
|
||||
wrapper_pg.allreduce([tensor])
|
||||
else:
|
||||
wrapper_pg.reduce([tensor])
|
||||
|
||||
# Check additional mismatches
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mismatch between collective operation types"
|
||||
):
|
||||
if self.rank == 0:
|
||||
wrapper_pg.reduce([tensor])
|
||||
else:
|
||||
wrapper_pg.barrier()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mismatch between collective operation types"
|
||||
):
|
||||
scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
|
||||
scattered_tensor = torch.empty(4)
|
||||
if self.rank == 0:
|
||||
wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
|
||||
else:
|
||||
wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Mismatch between collective operation types"
|
||||
):
|
||||
if self.rank == 0:
|
||||
wrapper_pg.broadcast(tensor, 0)
|
||||
else:
|
||||
output_tensors = [
|
||||
torch.zeros_like(tensor) for _ in range(self.world_size)
|
||||
]
|
||||
wrapper_pg.allgather([output_tensors], [tensor])
|
||||
|
||||
def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
|
||||
wrapper_pg.barrier()
|
||||
dim = 2 if self.rank == 0 else 10
|
||||
tensor = torch.randn(20, dim)
|
||||
if use_cuda:
|
||||
tensor = tensor.to(self.rank)
|
||||
with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
|
||||
wrapper_pg.allreduce([tensor])
|
||||
# Check errors are raised when dimensionality of shapes is different
|
||||
tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
|
||||
if use_cuda:
|
||||
tensor = tensor.to(self.rank)
|
||||
with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
|
||||
wrapper_pg.allreduce([tensor])
|
||||
|
||||
# Check shape errors with scatter
|
||||
input = [
|
||||
torch.tensor(
|
||||
[self.rank] if self.rank == 0 else [self.rank, self.rank],
|
||||
device=self.rank if use_cuda else "cpu",
|
||||
)
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
outputs = [
|
||||
torch.tensor(
|
||||
[-1] if self.rank == 0 else [-1, -1],
|
||||
device=self.rank if use_cuda else "cpu",
|
||||
)
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
root_rank = 0
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = root_rank
|
||||
with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
|
||||
if self.rank == root_rank:
|
||||
wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
|
||||
else:
|
||||
wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
|
||||
|
||||
class AbstractDistributedDataParallelTest(object):
|
||||
def tearDown(self):
|
||||
# DistributedDataParallel test doesn't seem to call FileStore destructor
|
||||
|
|
|
|||
|
|
@ -37,7 +37,14 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_TSAN,
|
||||
)
|
||||
import test_c10d_common
|
||||
from test_c10d_common import LOOPBACK, gpus_for_rank, Task, ModuleForDdpCommHook, SparseGradientModule
|
||||
from test_c10d_common import (
|
||||
LOOPBACK,
|
||||
gpus_for_rank,
|
||||
Task,
|
||||
ModuleForDdpCommHook,
|
||||
SparseGradientModule,
|
||||
AbstractProcessGroupWrapperTest,
|
||||
)
|
||||
|
||||
|
||||
def simple_reduce_tests(rank, world_size):
|
||||
|
|
@ -194,6 +201,59 @@ class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
|
|||
def test_default_store_timeout_gloo(self):
|
||||
self._test_default_store_timeout("gloo")
|
||||
|
||||
@requires_gloo()
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_TSAN,
|
||||
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
|
||||
)
|
||||
class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
|
||||
def setUp(self):
|
||||
super(ProcessGroupGlooWrapperTest, self).setUp()
|
||||
|
||||
def opts(self, threads=2, timeout=10.0):
|
||||
opts = c10d.ProcessGroupGloo._Options()
|
||||
opts._timeout = timeout
|
||||
opts._devices = [create_device(interface=LOOPBACK)]
|
||||
opts._threads = threads
|
||||
return opts
|
||||
|
||||
def _create_wrapper_pg(self, timeout=10.0):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="gloo", rank=self.rank, world_size=self.world_size, store=store
|
||||
)
|
||||
_pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(timeout=timeout))
|
||||
pg = c10d._create_process_group_wrapper(
|
||||
_pg,
|
||||
"unused",
|
||||
store,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
return pg
|
||||
|
||||
def test_collective_hang(self):
|
||||
pg = self._create_wrapper_pg(timeout=2.0)
|
||||
self._test_collective_hang(pg)
|
||||
|
||||
def test_collectives_op_mismatch(self):
|
||||
pg = self._create_wrapper_pg()
|
||||
self._test_collectives_op_mismatch(pg)
|
||||
|
||||
def test_collective_shape_mismatch(self):
|
||||
pg = self._create_wrapper_pg()
|
||||
self._test_collective_shape_mismatch(pg)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_collectives_op_mismatch_cuda(self):
|
||||
pg = self._create_wrapper_pg()
|
||||
self._test_collectives_op_mismatch(pg, use_cuda=True)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_collective_shape_mismatch_cuda(self):
|
||||
pg = self._create_wrapper_pg()
|
||||
self._test_collective_shape_mismatch(pg, use_cuda=True)
|
||||
|
||||
@requires_gloo()
|
||||
@unittest.skipIf(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from torch.nn.parallel import DistributedDataParallel
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
requires_gloo,
|
||||
requires_nccl,
|
||||
requires_nccl_version,
|
||||
skip_if_lt_x_gpu,
|
||||
|
|
@ -45,7 +46,7 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_TSAN,
|
||||
)
|
||||
import test_c10d_common
|
||||
from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
|
||||
from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook, AbstractProcessGroupWrapperTest
|
||||
|
||||
|
||||
class RendezvousEnvTest(TestCase):
|
||||
|
|
@ -158,6 +159,65 @@ class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
|
|||
raise unittest.SkipTest("No GPUs available, skipping test")
|
||||
self._test_default_store_timeout("nccl")
|
||||
|
||||
@requires_gloo()
|
||||
@requires_nccl()
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_TSAN,
|
||||
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
|
||||
)
|
||||
class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
|
||||
def setUp(self):
|
||||
self.num_gpus = torch.cuda.device_count()
|
||||
if self.num_gpus < 2:
|
||||
raise unittest.SkipTest("NCCL test requires 2+ GPUs")
|
||||
super(AbstractProcessGroupWrapperTest, self).setUp()
|
||||
self._spawn_processes()
|
||||
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
# that use NCCL_BLOCKING_WAIT will test it as expected.
|
||||
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
|
||||
def _create_wrapper_pg(self, timeout=10.0):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="nccl",
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
store=store,
|
||||
timeout=timedelta(seconds=timeout)
|
||||
)
|
||||
_pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=timeout))
|
||||
pg = c10d._create_process_group_wrapper(
|
||||
_pg,
|
||||
"unused",
|
||||
store,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
return pg
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_collective_hang(self):
|
||||
pg = self._create_wrapper_pg(timeout=2.0)
|
||||
self._test_collective_hang(pg)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_collectives_op_mismatch(self):
|
||||
wrapper_pg = self._create_wrapper_pg()
|
||||
self._test_collectives_op_mismatch(wrapper_pg, use_cuda=True)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_collective_shape_mismatch(self):
|
||||
wrapper_pg = self._create_wrapper_pg()
|
||||
self._test_collective_shape_mismatch(wrapper_pg, use_cuda=True)
|
||||
|
||||
|
||||
class ProcessGroupNCCLNoGPUTest(TestCase):
|
||||
MAIN_PROCESS_RANK = 0
|
||||
|
|
|
|||
|
|
@ -346,6 +346,13 @@ class ProcessGroupGloo(ProcessGroup):
|
|||
def create_default_device() -> Device: ...
|
||||
...
|
||||
|
||||
class _ProcessGroupWrapper(ProcessGroup):
|
||||
def __init__(
|
||||
self,
|
||||
pg: ProcessGroup,
|
||||
gloo_pg: ProcessGroupGloo
|
||||
): ...
|
||||
|
||||
class ProcessGroupNCCL(ProcessGroup):
|
||||
class Options: ...
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#ifdef USE_C10D_GLOO
|
||||
#include <c10d/ProcessGroupGloo.hpp>
|
||||
#include <c10d/ProcessGroupWrapper.hpp>
|
||||
#endif
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
|
@ -1305,6 +1306,23 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
|
|||
py::arg("timeout") = kProcessGroupDefaultTimeout,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions);
|
||||
|
||||
// ProcessGroupWrapper is a wrapper pg that includes a helper gloo process
|
||||
// group. It can be used to validate collective calls across processes by
|
||||
// checking the op type and input tensor shapes.
|
||||
auto processGroupWrapper =
|
||||
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>(
|
||||
module, "_ProcessGroupWrapper", processGroup)
|
||||
.def(
|
||||
py::init([](const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroupGloo>&
|
||||
gloo_pg) {
|
||||
return c10::make_intrusive<::c10d::ProcessGroupWrapper>(
|
||||
pg, gloo_pg);
|
||||
}),
|
||||
py::arg("pg"),
|
||||
py::arg("gloo_pg"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
#endif
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
|
|
|||
|
|
@ -53,4 +53,8 @@ if is_available():
|
|||
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
|
||||
# this.
|
||||
|
||||
from .distributed_c10d import _backend, _all_gather_base
|
||||
from .distributed_c10d import (
|
||||
_backend,
|
||||
_all_gather_base,
|
||||
_create_process_group_wrapper
|
||||
)
|
||||
|
|
|
|||
|
|
@ -50,12 +50,15 @@ except ImportError:
|
|||
|
||||
try:
|
||||
from torch._C._distributed_c10d import ProcessGroupGloo
|
||||
from torch._C._distributed_c10d import _ProcessGroupWrapper
|
||||
except ImportError:
|
||||
_GLOO_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
|
||||
|
||||
|
||||
# Some reduce ops are not supported by complex numbers and will result in an error.
|
||||
# We currently provide complex support to the distributed API by viewing
|
||||
|
|
@ -2599,6 +2602,26 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
|
|||
group_to_use = _get_default_group() if group is None else group
|
||||
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
|
||||
|
||||
def _create_process_group_wrapper(
|
||||
wrapped_pg: ProcessGroup,
|
||||
store_prefix: str,
|
||||
store: Store,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
timeout: timedelta = default_pg_timeout
|
||||
):
|
||||
# Create a separate prefix store for the helper process group.
|
||||
prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
|
||||
store = PrefixStore(prefix, store)
|
||||
helper_pg = ProcessGroupGloo(
|
||||
store,
|
||||
rank,
|
||||
world_size,
|
||||
timeout=timeout
|
||||
)
|
||||
# Wrap the underlying pg with ProcessGroupWrapper.
|
||||
wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
|
||||
return wrapped_pg
|
||||
|
||||
def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ if(USE_C10D_MPI)
|
|||
endif()
|
||||
|
||||
if(USE_C10D_GLOO)
|
||||
list(APPEND C10D_SRCS ProcessGroupGloo.cpp GlooDeviceFactory.cpp)
|
||||
list(APPEND C10D_SRCS ProcessGroupGloo.cpp GlooDeviceFactory.cpp ProcessGroupWrapper.cpp)
|
||||
list(APPEND C10D_LIBS gloo)
|
||||
if(USE_CUDA)
|
||||
list(APPEND C10D_LIBS gloo_cuda)
|
||||
|
|
@ -137,6 +137,7 @@ copy_header(sequence_num.hpp)
|
|||
if(USE_GLOO)
|
||||
copy_header(ProcessGroupGloo.hpp)
|
||||
copy_header(GlooDeviceFactory.hpp)
|
||||
copy_header(ProcessGroupWrapper.hpp)
|
||||
endif()
|
||||
if(NOT WIN32)
|
||||
copy_header(HashStore.hpp)
|
||||
|
|
|
|||
|
|
@ -2892,10 +2892,9 @@ void ProcessGroupGloo::setSequenceNumberForGroup() {
|
|||
}
|
||||
|
||||
uint64_t ProcessGroupGloo::getSequenceNumberForGroup() {
|
||||
TORCH_CHECK(
|
||||
sequenceNum_ != c10::nullopt,
|
||||
"Sequence number is not set for rank ",
|
||||
rank_);
|
||||
if (sequenceNum_ == c10::nullopt) {
|
||||
return 0;
|
||||
}
|
||||
return sequenceNum_->get();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -505,10 +505,9 @@ void ProcessGroupNCCL::setSequenceNumberForGroup() {
|
|||
}
|
||||
|
||||
uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
|
||||
TORCH_CHECK(
|
||||
sequenceNum_ != c10::nullopt,
|
||||
"Sequence number is not set for rank ",
|
||||
rank_);
|
||||
if (sequenceNum_ == c10::nullopt) {
|
||||
return 0;
|
||||
}
|
||||
return sequenceNum_->get();
|
||||
}
|
||||
|
||||
|
|
|
|||
322
torch/lib/c10d/ProcessGroupWrapper.cpp
Normal file
322
torch/lib/c10d/ProcessGroupWrapper.cpp
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <c10d/ProcessGroupGloo.hpp>
|
||||
#include <c10d/ProcessGroupWrapper.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
namespace {
|
||||
// A container for information about a particular collective, including optype
|
||||
// and input tensors (if applicable.)
|
||||
struct CollectiveFingerPrint {
|
||||
// Current collective's operation type.
|
||||
OpType op_type_;
|
||||
// Ref to input tensors, if given, of the collective. If given, shapes will be
|
||||
// checked across processes to ensure valid input into the collective.
|
||||
const std::vector<at::Tensor>& input_tensors_;
|
||||
explicit CollectiveFingerPrint(
|
||||
OpType op_type,
|
||||
const std::vector<at::Tensor>& input_tensors)
|
||||
: op_type_(op_type), input_tensors_(input_tensors) {}
|
||||
|
||||
// Verifies a given int is the same across processes.
|
||||
void verify_num(
|
||||
int64_t value,
|
||||
const c10::intrusive_ptr<ProcessGroup>& pg,
|
||||
const std::string& failureMsg) {
|
||||
auto tensor = at::full({1}, value, at::TensorOptions().dtype(at::kLong));
|
||||
std::vector<at::Tensor> tensors;
|
||||
tensors.reserve(pg->getSize());
|
||||
for (int i = 0; i < pg->getSize(); ++i) {
|
||||
tensors.emplace_back(at::zeros_like(tensor));
|
||||
}
|
||||
std::vector<std::vector<at::Tensor>> out_tensors({tensors});
|
||||
std::vector<at::Tensor> inp_tensors({tensor});
|
||||
pg->allgather(out_tensors, inp_tensors)->wait();
|
||||
std::unordered_set<int64_t> gathered;
|
||||
for (const auto& t : out_tensors[0]) {
|
||||
auto n = t.item().to<int64_t>();
|
||||
gathered.insert(n);
|
||||
if (gathered.size() > 1) {
|
||||
TORCH_CHECK(false, failureMsg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that shapes are consistent across processes.
|
||||
// shape_tensors_to_report should be specified as the tensors to report when a
|
||||
// shape inconsistency is found. This is not necessarily shape_tensors such as
|
||||
// in the case we are checking shape dimensionality.
|
||||
void verify_shapes(
|
||||
std::vector<at::Tensor> shape_tensors,
|
||||
std::vector<at::Tensor> shape_tensors_to_report,
|
||||
c10::intrusive_ptr<ProcessGroup>& pg) {
|
||||
std::vector<std::vector<at::Tensor>> output_tensors;
|
||||
output_tensors.reserve(shape_tensors.size());
|
||||
for (auto & tensor_shape : shape_tensors) {
|
||||
std::vector<at::Tensor> outputs;
|
||||
outputs.reserve(pg->getSize());
|
||||
for (int i = 0; i < pg->getSize(); ++i) {
|
||||
outputs.emplace_back(at::zeros_like(tensor_shape));
|
||||
}
|
||||
output_tensors.emplace_back(outputs);
|
||||
}
|
||||
// Allgather tensor shapes.
|
||||
pg->allgather(output_tensors, shape_tensors)->wait();
|
||||
// Verify equivalence
|
||||
for (int i = 0; i < output_tensors.size(); ++i) {
|
||||
auto world_tensor_shapes = output_tensors[i];
|
||||
auto reference_shape_tensor = shape_tensors[i];
|
||||
for (const auto& rank_tensor_shape : world_tensor_shapes) {
|
||||
if (!rank_tensor_shape.equal(reference_shape_tensor)) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str(
|
||||
"Error when verifying shape tensors for collective ",
|
||||
opTypeToString(op_type_),
|
||||
" on rank ",
|
||||
pg->getRank(),
|
||||
". This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: ",
|
||||
shape_tensors_to_report));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Executes and verifies the collective fingerprint.
|
||||
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
|
||||
// For collectives, all ranks should participate and call into them in the
|
||||
// same order. Verify the same operation type is being requested.
|
||||
int64_t op_type_int = static_cast<int64_t>(op_type_);
|
||||
verify_num(
|
||||
op_type_int,
|
||||
pg,
|
||||
c10::str(
|
||||
"Mismatch between collective operation types across ranks.",
|
||||
"This likely indicates an application bug where different ranks are ",
|
||||
"calling different collectives. ",
|
||||
"Rank ",
|
||||
pg->getRank(),
|
||||
" is calling collective: ",
|
||||
opTypeToString(op_type_)));
|
||||
// Retrieve input tensor shapes.
|
||||
std::vector<at::Tensor> shape_tensors =
|
||||
c10d::getTensorShapes(input_tensors_);
|
||||
// If input_tensors_ is empty we would get no shape tensors back, but still
|
||||
// do verification in case input_tensors_.empty() is
|
||||
// inconsistent across ranks. In this case, sub in a single zeros tensor and
|
||||
// ensure all ranks agree, because gloo pg does not allow collectives with
|
||||
// empty inputs.
|
||||
if (shape_tensors.size() == 0) {
|
||||
shape_tensors = {at::zeros(1)};
|
||||
}
|
||||
// Verify dimensionality of shapes. This catches errors where tensor shapes
|
||||
// have different dimensions such as torch.randn(2, 3) vs torch.randn(2, 3,
|
||||
// 4). If we did not do this step and instead proceeded directly with
|
||||
// verifying tensor shapes, we would have malformed input into allgather()
|
||||
// and crash with an unhelpful error.
|
||||
std::vector<at::Tensor> meta_shape_tensors =
|
||||
c10d::getTensorShapes(shape_tensors);
|
||||
|
||||
verify_shapes(
|
||||
meta_shape_tensors, /* shape_tensors_to_report= */ shape_tensors, pg);
|
||||
|
||||
// If all meta shapes are 0 then we can skip the below verification since
|
||||
// it is not possible that there would be a difference. This happens only
|
||||
// when the tensor wraps a single scalar.
|
||||
bool skip = true;
|
||||
for (auto & t : meta_shape_tensors) {
|
||||
if (t.item().to<int64_t>() != 0) {
|
||||
skip = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!skip) {
|
||||
verify_shapes(
|
||||
shape_tensors, /* shape_tensors_to_report= */ shape_tensors, pg);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ProcessGroupWrapper::ProcessGroupWrapper(
|
||||
c10::intrusive_ptr<ProcessGroup> pg,
|
||||
c10::intrusive_ptr<ProcessGroupGloo> glooPg)
|
||||
: ProcessGroup(pg->getRank(), pg->getSize()), pg_(pg), glooPg_(glooPg) {
|
||||
// Set the sequence number for the underlying process group.
|
||||
pg_->setSequenceNumberForGroup();
|
||||
}
|
||||
|
||||
const std::string ProcessGroupWrapper::getBackendName() const {
|
||||
return pg_->getBackendName();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
const BroadcastOptions& opts) {
|
||||
runCollectiveChecks(OpType::BARRIER, data);
|
||||
return pg_->broadcast(data, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allreduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
const AllreduceOptions& opts) {
|
||||
runCollectiveChecks(OpType::ALLREDUCE, data);
|
||||
return pg_->allreduce(data, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allreduce_coalesced(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
const AllreduceCoalescedOptions& opts) {
|
||||
// NOTE: We don't enforce shape checking for allreduce_coalesced because
|
||||
// the implementation itself does not enforce it we have tests that use
|
||||
// inconsistent shapes, see python implementation in distributed_c10d for
|
||||
// details.
|
||||
runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
|
||||
return pg_->allreduce_coalesced(tensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::reduce(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
const ReduceOptions& opts) {
|
||||
runCollectiveChecks(OpType::REDUCE, tensors);
|
||||
return pg_->reduce(tensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allgather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllgatherOptions& opts) {
|
||||
runCollectiveChecks(OpType::ALLGATHER, inputTensors);
|
||||
return pg_->allgather(outputTensors, inputTensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::_allgather_base(
|
||||
at::Tensor& outputBuffer,
|
||||
at::Tensor& inputBuffer,
|
||||
const AllgatherOptions& opts) {
|
||||
std::vector<at::Tensor> inputTensors({inputBuffer});
|
||||
runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
|
||||
return pg_->_allgather_base(outputBuffer, inputBuffer, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::allgather_coalesced(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensorLists,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllgatherOptions& opts) {
|
||||
// NOTE: We don't enforce shape checking for allgather_coalesced because
|
||||
// the implementation itself does not enforce it we have tests that use
|
||||
// inconsistent shapes, see python implementation in distributed_c10d for
|
||||
// details.
|
||||
runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
|
||||
return pg_->allgather_coalesced(outputTensorLists, inputTensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::gather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const GatherOptions& opts) {
|
||||
runCollectiveChecks(OpType::GATHER, inputTensors);
|
||||
return pg_->gather(outputTensors, inputTensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::scatter(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<std::vector<at::Tensor>>& inputTensors,
|
||||
const ScatterOptions& opts) {
|
||||
runCollectiveChecks(OpType::SCATTER, outputTensors);
|
||||
return pg_->scatter(outputTensors, inputTensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::reduce_scatter(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<std::vector<at::Tensor>>& inputTensors,
|
||||
const ReduceScatterOptions& opts) {
|
||||
runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
|
||||
return pg_->reduce_scatter(outputTensors, inputTensors, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::alltoall_base(
|
||||
at::Tensor& outputTensor,
|
||||
at::Tensor& inputTensor,
|
||||
std::vector<int64_t>& outputSplitSizes,
|
||||
std::vector<int64_t>& inputSplitSizes,
|
||||
const AllToAllOptions& opts) {
|
||||
// alltoall supports uneven split, so don't enforce shape checking.
|
||||
runCollectiveChecks(OpType::ALLTOALL_BASE, {});
|
||||
return pg_->alltoall_base(
|
||||
outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::alltoall(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllToAllOptions& opts) {
|
||||
// alltoall supports uneven split, so don't enforce shape checking.
|
||||
runCollectiveChecks(OpType::ALLTOALL, {});
|
||||
return pg_->alltoall(outputTensors, inputTensors, opts);
|
||||
}
|
||||
|
||||
void ProcessGroupWrapper::monitoredBarrier(
|
||||
const BarrierOptions& opts,
|
||||
bool waitAllRanks) {
|
||||
return pg_->monitoredBarrier(opts, waitAllRanks);
|
||||
}
|
||||
|
||||
void ProcessGroupWrapper::setSequenceNumberForGroup() {
|
||||
// Set underlying pg's sequence number if it is not set.
|
||||
if (pg_->getSequenceNumberForGroup() == 0) {
|
||||
// Set the sequence number for the underlying process group.
|
||||
pg_->setSequenceNumberForGroup();
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
|
||||
return pg_->getSequenceNumberForGroup();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::send(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int dstRank,
|
||||
int tag) {
|
||||
return pg_->send(tensors, dstRank, tag);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::recv(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int srcRank,
|
||||
int tag) {
|
||||
return pg_->recv(tensors, srcRank, tag);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::recvAnysource(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int tag) {
|
||||
return pg_->recvAnysource(tensors, tag);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupWrapper::barrier(
|
||||
const BarrierOptions& opts) {
|
||||
runCollectiveChecks(OpType::BARRIER, {});
|
||||
return pg_->barrier(opts);
|
||||
}
|
||||
|
||||
void ProcessGroupWrapper::runCollectiveChecks(
|
||||
OpType op_type,
|
||||
const std::vector<at::Tensor>& tensors) const {
|
||||
// first perform a monitored barrier to ensure all ranks can synchronize.
|
||||
c10d::BarrierOptions options;
|
||||
// TODO: we should use wrapped pg_'s timeout here, but C++ ProcessGroup API
|
||||
// does not expose timeout.
|
||||
glooPg_->monitoredBarrier(options, /* waitAllRanks */ true);
|
||||
auto finger_print = CollectiveFingerPrint(op_type, tensors);
|
||||
// Will throw if an ill-formed collective is detected.
|
||||
finger_print.verify(glooPg_);
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
126
torch/lib/c10d/ProcessGroupWrapper.hpp
Normal file
126
torch/lib/c10d/ProcessGroupWrapper.hpp
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <c10d/ProcessGroupGloo.hpp>
|
||||
#include <c10d/Types.hpp>
|
||||
#include <c10d/Utils.hpp>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
class ProcessGroupWrapper : public ProcessGroup {
|
||||
public:
|
||||
explicit ProcessGroupWrapper(
|
||||
c10::intrusive_ptr<ProcessGroup> pg,
|
||||
c10::intrusive_ptr<ProcessGroupGloo> glooPg);
|
||||
|
||||
const std::string getBackendName() const override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
const BroadcastOptions& opts = BroadcastOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
const AllreduceOptions& opts = AllreduceOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
const AllreduceCoalescedOptions& opts =
|
||||
AllreduceCoalescedOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> reduce(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
const ReduceOptions& opts = ReduceOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> allgather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllgatherOptions& opts = AllgatherOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
|
||||
at::Tensor& outputBuffer,
|
||||
at::Tensor& inputBuffer,
|
||||
const AllgatherOptions& opts = AllgatherOptions()) override;
|
||||
|
||||
// This function is deprecated and will be moved out of ProcessGroup to comms:
|
||||
// * do not add dependencies on this function,
|
||||
// * do not implement it in your ProcessGroup, implement _allgather_base
|
||||
// instead.
|
||||
c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensorLists,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllgatherOptions& opts = AllgatherOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> gather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const GatherOptions& opts = GatherOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> scatter(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<std::vector<at::Tensor>>& inputTensors,
|
||||
const ScatterOptions& opts = ScatterOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<std::vector<at::Tensor>>& inputTensors,
|
||||
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
|
||||
at::Tensor& outputTensor,
|
||||
at::Tensor& inputTensor,
|
||||
std::vector<int64_t>& outputSplitSizes,
|
||||
std::vector<int64_t>& inputSplitSizes,
|
||||
const AllToAllOptions& opts = AllToAllOptions()) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> alltoall(
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllToAllOptions& opts = AllToAllOptions()) override;
|
||||
|
||||
void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
|
||||
override;
|
||||
|
||||
// Agrees on an initial sequence number for the whole group by having rank 0
|
||||
// create it and broadcast it to other ranks using the store. Only implemented
|
||||
// for GLOO and NCCL backends currently.
|
||||
// dont implement this
|
||||
void setSequenceNumberForGroup() override;
|
||||
|
||||
// Retrieves the current sequence number for the whole group, which should be
|
||||
// in sync. If the returned number is not consistent across the group, it
|
||||
// may indicate that there is some sort of collective desynchronization.
|
||||
uint64_t getSequenceNumberForGroup() override; // just call underlying
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> send(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int dstRank,
|
||||
int tag) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> recv(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int srcRank,
|
||||
int tag) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
|
||||
std::vector<at::Tensor>& tensors,
|
||||
int tag) override;
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> barrier(
|
||||
const BarrierOptions& opts = BarrierOptions()) override;
|
||||
|
||||
private:
|
||||
// Underlying process group that actual application collectives will be
|
||||
// dispatched to
|
||||
c10::intrusive_ptr<ProcessGroup> pg_;
|
||||
// Gloo process group responsible for internal coordination such as monitored
|
||||
// barrier, sequence number checking, collective fingerprint collecting.
|
||||
c10::intrusive_ptr<ProcessGroupGloo> glooPg_;
|
||||
// Conducts several checks to ensure that the underlying collective is well
|
||||
// formed with the goal of notifying the user about incorrect collective use
|
||||
// in the application.
|
||||
void runCollectiveChecks(
|
||||
OpType op_type,
|
||||
const std::vector<at::Tensor>& tensors) const;
|
||||
};
|
||||
} // namespace c10d
|
||||
|
|
@ -71,6 +71,21 @@ namespace c10d {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> getTensorShapes(const std::vector<at::Tensor>& tensors) {
|
||||
std::vector<at::Tensor> shapeTensors;
|
||||
shapeTensors.reserve(tensors.size());
|
||||
for (const auto& tensor : tensors) {
|
||||
auto shapesVec = tensor.sizes().vec();
|
||||
int64_t shapes_size = shapesVec.size();
|
||||
// Need to clone here otherwise the shapesVec.data() memory is not copied
|
||||
// and can be released under the hood.
|
||||
at::Tensor shapesTensor = at::from_blob(
|
||||
shapesVec.data(), {shapes_size}, at::TensorOptions().dtype(at::kLong)).clone();
|
||||
shapeTensors.emplace_back(std::move(shapesTensor));
|
||||
}
|
||||
return shapeTensors;
|
||||
}
|
||||
|
||||
|
||||
namespace tcputil {
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ std::string parse_env(const char* env_var_name);
|
|||
|
||||
DistributedDebugLevel parseDistDebugLevel();
|
||||
|
||||
// Retrieve tensor shapes from a given tensor.
|
||||
std::vector<at::Tensor> getTensorShapes(const std::vector<at::Tensor>& tensors);
|
||||
|
||||
// Turns at::IntArrayRef into "(1, 2, 3, 4)".
|
||||
inline std::string toString(at::IntArrayRef l) {
|
||||
std::stringstream ss;
|
||||
|
|
|
|||
Loading…
Reference in a new issue