[pytorch/ncclx] Remove Alltoallv specialization for PTD all_to_all (#145045)

Summary:
PTD all_to_all uses a list of tensors, while ncclAllToAllv (provided
by NCCLX and RCCL) assumes a single contiguous buffer. These
interfaces are fundamentally mismatched: the tensors in the list might
not be contiguous in memory, or even ordered (their buffer addresses
might not be increasing).
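
To illustrate the hazard, here is a minimal sketch (a hypothetical helper, not code from this patch) of the displacement math the removed specialization relied on: offsets are computed as unsigned differences from tensor 0's base address, which wrap around whenever a later tensor happens to live at a lower address.

#include <cstdint>
#include <vector>

// Hypothetical sketch: derive element displacements for ncclAllToAllv
// from per-peer buffer addresses, relative to the first buffer (as the
// removed specialization did).
std::vector<size_t> displacements(
    const std::vector<uintptr_t>& addrs, size_t dtypeSize) {
  std::vector<size_t> disps(addrs.size());
  for (size_t r = 0; r < addrs.size(); ++r) {
    // If addrs[r] < addrs[0] -- allocations need not be address-ordered --
    // this unsigned subtraction wraps around to a huge bogus displacement.
    disps[r] = (addrs[r] - addrs[0]) / dtypeSize;
  }
  return disps;
}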

This patch removes the ncclAllToAllv specialization for PTD
all_to_all and instead lets it call ncclSend/ncclRecv directly.
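
Roughly, the path the collective now always takes is one ncclSend/ncclRecv pair per peer, batched between ncclGroupStart/ncclGroupEnd so NCCL issues them as a single fused operation. A minimal sketch, assuming inputs, outputs, nRanks, dtype, comm, and stream are in scope (names are illustrative, not the patch's exact code):

NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < nRanks; ++r) {
  // Each tensor supplies its own base pointer, so the buffers need not
  // be contiguous or ordered by address.
  NCCL_CHECK(ncclSend(
      inputs[r].data_ptr(), inputs[r].numel(), dtype, r, comm, stream));
  NCCL_CHECK(ncclRecv(
      outputs[r].data_ptr(), outputs[r].numel(), dtype, r, comm, stream));
}
NCCL_CHECK(ncclGroupEnd());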

Co-authored-by: @pavanbalaji
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145045
Approved by: https://github.com/pavanbalaji, https://github.com/d4l3k, https://github.com/fduwjj, https://github.com/ezyang
Authored by Will Constable on 2025-01-16 20:03:34 -08:00; committed by PyTorch MergeBot
parent 07669ed960
commit 2859b11bdb

@@ -957,48 +957,6 @@ void all2all(
   using namespace torch::cuda::nccl::detail;
   auto comm = to_nccl_comm(_comm);
-#ifdef NCCL_ALLTOALLV_SUPPORTED
-  // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv
-  // operations issued as a part of the collective (e.g. alltoallv) vs those
-  // inside traditional p2p operations.
-  TORCH_INTERNAL_ASSERT(
-      outputTensors.size() == inputTensors.size(),
-      "number of input tensors is not equal to number of output tensors");
-  std::vector<size_t> sendCounts(inputTensors.size());
-  std::vector<size_t> sendDisps(inputTensors.size());
-  std::vector<size_t> recvCounts(outputTensors.size());
-  std::vector<size_t> recvDisps(outputTensors.size());
-  uintptr_t sendBase = reinterpret_cast<uintptr_t>(inputTensors[0].data_ptr());
-  uintptr_t recvBase = reinterpret_cast<uintptr_t>(outputTensors[0].data_ptr());
-  size_t dtypeSize = inputTensors.front().element_size();
-  for (const int r : c10::irange(outputTensors.size())) {
-    sendCounts[r] = inputTensors[r].numel();
-    auto sendOffset =
-        reinterpret_cast<uintptr_t>(inputTensors[r].data_ptr()) - sendBase;
-    TORCH_INTERNAL_ASSERT(
-        sendOffset % dtypeSize == 0,
-        "sendOffset is not divisible by dtypeSize");
-    sendDisps[r] = sendOffset / dtypeSize;
-    recvCounts[r] = outputTensors[r].numel();
-    auto recvOffset =
-        reinterpret_cast<uintptr_t>(outputTensors[r].data_ptr()) - recvBase;
-    TORCH_INTERNAL_ASSERT(
-        recvOffset % dtypeSize == 0,
-        "recvOffset is not divisible by dtypeSize");
-    recvDisps[r] = recvOffset / dtypeSize;
-  }
-  NCCL_CHECK(ncclAllToAllv(
-      inputTensors[0].data_ptr(),
-      sendCounts.data(),
-      sendDisps.data(),
-      outputTensors[0].data_ptr(),
-      recvCounts.data(),
-      recvDisps.data(),
-      to_nccl_data_type(inputTensors.front()),
-      comm,
-      stream.stream()));
-#else
   NCCL_CHECK(ncclGroupStart());
   for (const int r : c10::irange(static_cast<int>(outputTensors.size()))) {
     at::Tensor& input = inputTensors[r];
@@ -1028,7 +986,6 @@ void all2all(
 #else
   NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
 #endif
-#endif
 #else
   TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0");
 #endif