[pytorch/ncclx] Remove Alltoallv specialization for PTD all_to_all (#145045)
Summary: PTD all_to_all uses a list of tensors, while ncclAllToAllv (provided by NCCLX and RCCL) assumes a single contiguous buffer. These are fundamentally mismatched: the list of tensors might not be contiguous, or even ordered (buffer addresses might not be in increasing order). This patch removes the ncclAllToAllv specialization for PTD all_to_all and instead lets it call ncclSend/ncclRecv directly.

Co-authored-by: @pavanbalaji

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145045
Approved by: https://github.com/pavanbalaji, https://github.com/d4l3k, https://github.com/fduwjj, https://github.com/ezyang
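For context, the path all_to_all now always takes is the grouped send/recv loop already present in the #else branch of the diff below. A minimal standalone sketch of that pattern, with an invented helper name and raw NCCL calls (error handling elided; not the exact PyTorch source):

#include <nccl.h>
#include <vector>
#include <ATen/core/Tensor.h>

// Hypothetical helper mirroring the retained fallback: each rank posts one
// send and one recv per peer, grouped so NCCL can schedule the point-to-point
// operations together and avoid cross-rank ordering deadlocks.
void allToAllSendRecv(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclDataType_t dtype,
    ncclComm_t comm,
    cudaStream_t stream) {
  ncclGroupStart();
  for (int r = 0; r < static_cast<int>(outputTensors.size()); ++r) {
    // Each tensor is addressed through its own data_ptr(), so the list may be
    // non-contiguous and its buffers in any address order.
    ncclSend(inputTensors[r].data_ptr(), inputTensors[r].numel(), dtype, r, comm, stream);
    ncclRecv(outputTensors[r].data_ptr(), outputTensors[r].numel(), dtype, r, comm, stream);
  }
  ncclGroupEnd();
}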
This commit is contained in:
parent 07669ed960
commit 2859b11bdb
1 changed file with 0 additions and 43 deletions
@@ -957,48 +957,6 @@ void all2all(
   using namespace torch::cuda::nccl::detail;
   auto comm = to_nccl_comm(_comm);
 
-#ifdef NCCL_ALLTOALLV_SUPPORTED
-  // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv
-  // operations issued as a part of the collective (e.g. alltoallv) vs those
-  // inside traditional p2p operations.
-  TORCH_INTERNAL_ASSERT(
-      outputTensors.size() == inputTensors.size(),
-      "number of input tensors is not equal to number of output tensors");
-  std::vector<size_t> sendCounts(inputTensors.size());
-  std::vector<size_t> sendDisps(inputTensors.size());
-  std::vector<size_t> recvCounts(outputTensors.size());
-  std::vector<size_t> recvDisps(outputTensors.size());
-  uintptr_t sendBase = reinterpret_cast<uintptr_t>(inputTensors[0].data_ptr());
-  uintptr_t recvBase = reinterpret_cast<uintptr_t>(outputTensors[0].data_ptr());
-  size_t dtypeSize = inputTensors.front().element_size();
-
-  for (const int r : c10::irange(outputTensors.size())) {
-    sendCounts[r] = inputTensors[r].numel();
-    auto sendOffset =
-        reinterpret_cast<uintptr_t>(inputTensors[r].data_ptr()) - sendBase;
-    TORCH_INTERNAL_ASSERT(
-        sendOffset % dtypeSize == 0,
-        "sendOffset is not divisible by dtypeSize");
-    sendDisps[r] = sendOffset / dtypeSize;
-    recvCounts[r] = outputTensors[r].numel();
-    auto recvOffset =
-        reinterpret_cast<uintptr_t>(outputTensors[r].data_ptr()) - recvBase;
-    TORCH_INTERNAL_ASSERT(
-        recvOffset % dtypeSize == 0,
-        "recvOffset is not divisible by dtypeSize");
-    recvDisps[r] = recvOffset / dtypeSize;
-  }
-  NCCL_CHECK(ncclAllToAllv(
-      inputTensors[0].data_ptr(),
-      sendCounts.data(),
-      sendDisps.data(),
-      outputTensors[0].data_ptr(),
-      recvCounts.data(),
-      recvDisps.data(),
-      to_nccl_data_type(inputTensors.front()),
-      comm,
-      stream.stream()));
-#else
   NCCL_CHECK(ncclGroupStart());
   for (const int r : c10::irange(static_cast<int>(outputTensors.size()))) {
     at::Tensor& input = inputTensors[r];
@@ -1028,7 +986,6 @@ void all2all(
 #else
   NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
 #endif
-#endif
 #else
   TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
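A note on why the removed displacement math was unsound: each displacement was computed as an unsigned difference from tensor 0's base address, but as the summary says, buffer addresses need not be in increasing order. A host-only illustration of the resulting wraparound, with invented example addresses:

#include <cstdint>
#include <cstdio>

int main() {
  // Invented buffer addresses: tensor 1's buffer happens to sit BELOW
  // tensor 0's, which a PTD tensor list permits.
  uintptr_t sendBase = 0x7f0000002000; // inputTensors[0].data_ptr()
  uintptr_t ptr1     = 0x7f0000001000; // inputTensors[1].data_ptr()

  // The removed code computed displacements exactly this way:
  uintptr_t sendOffset = ptr1 - sendBase; // unsigned subtraction wraps around
  size_t dtypeSize = 4;                   // e.g. float32
  size_t disp = sendOffset / dtypeSize;   // enormous bogus displacement

  // The divisibility assert would still pass, so the bad displacement
  // would reach ncclAllToAllv unchecked.
  std::printf("sendOffset = %ju, disp = %zu\n",
              static_cast<uintmax_t>(sendOffset), disp);
  return 0;
}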