diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp
index d37028901e4..7053ad561e8 100644
--- a/torch/csrc/cuda/nccl.cpp
+++ b/torch/csrc/cuda/nccl.cpp
@@ -957,48 +957,6 @@ void all2all(
   using namespace torch::cuda::nccl::detail;
   auto comm = to_nccl_comm(_comm);
 
-#ifdef NCCL_ALLTOALLV_SUPPORTED
-  // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv
-  // operations issued as a part of the collective (e.g. alltoallv) vs those
-  // inside traditional p2p operations.
-  TORCH_INTERNAL_ASSERT(
-      outputTensors.size() == inputTensors.size(),
-      "number of input tensors is not equal to number of output tensors");
-  std::vector<size_t> sendCounts(inputTensors.size());
-  std::vector<size_t> sendDisps(inputTensors.size());
-  std::vector<size_t> recvCounts(outputTensors.size());
-  std::vector<size_t> recvDisps(outputTensors.size());
-  uintptr_t sendBase = reinterpret_cast<uintptr_t>(inputTensors[0].data_ptr());
-  uintptr_t recvBase = reinterpret_cast<uintptr_t>(outputTensors[0].data_ptr());
-  size_t dtypeSize = inputTensors.front().element_size();
-
-  for (const int r : c10::irange(outputTensors.size())) {
-    sendCounts[r] = inputTensors[r].numel();
-    auto sendOffset =
-        reinterpret_cast<uintptr_t>(inputTensors[r].data_ptr()) - sendBase;
-    TORCH_INTERNAL_ASSERT(
-        sendOffset % dtypeSize == 0,
-        "sendOffset is not divisible by dtypeSize");
-    sendDisps[r] = sendOffset / dtypeSize;
-    recvCounts[r] = outputTensors[r].numel();
-    auto recvOffset =
-        reinterpret_cast<uintptr_t>(outputTensors[r].data_ptr()) - recvBase;
-    TORCH_INTERNAL_ASSERT(
-        recvOffset % dtypeSize == 0,
-        "recvOffset is not divisible by dtypeSize");
-    recvDisps[r] = recvOffset / dtypeSize;
-  }
-  NCCL_CHECK(ncclAllToAllv(
-      inputTensors[0].data_ptr(),
-      sendCounts.data(),
-      sendDisps.data(),
-      outputTensors[0].data_ptr(),
-      recvCounts.data(),
-      recvDisps.data(),
-      to_nccl_data_type(inputTensors.front()),
-      comm,
-      stream.stream()));
-#else
   NCCL_CHECK(ncclGroupStart());
   for (const int r : c10::irange(static_cast<int>(outputTensors.size()))) {
     at::Tensor& input = inputTensors[r];
@@ -1028,7 +986,6 @@ void all2all(
 #else
   NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
 #endif
-#endif
 #else
   TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
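
For reference, below is a minimal standalone sketch of the grouped ncclSend/ncclRecv pattern that the surviving branch implements. This is not the PyTorch source: all2allSketch, CHECK_NCCL, and the chunk-array parameters are illustrative names, and it assumes NCCL >= 2.7 (ncclSend/ncclRecv point-to-point support).

// Minimal sketch of the grouped send/recv all-to-all kept by this diff.
// NOT the PyTorch implementation: all2allSketch, CHECK_NCCL, and the
// chunk-array parameters are illustrative names. Assumes NCCL >= 2.7.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <nccl.h>

#define CHECK_NCCL(cmd)                                              \
  do {                                                               \
    ncclResult_t res = (cmd);                                        \
    if (res != ncclSuccess) {                                        \
      std::fprintf(stderr, "NCCL error %s:%d: %s\n", __FILE__,       \
                   __LINE__, ncclGetErrorString(res));               \
      std::exit(EXIT_FAILURE);                                       \
    }                                                                \
  } while (0)

// sendChunks[r]/sendCounts[r] describe the data this rank sends to rank r;
// recvChunks[r]/recvCounts[r] describe what it expects back. Pairing each
// ncclSend with the matching ncclRecv between ncclGroupStart() and
// ncclGroupEnd() lets NCCL fuse all the point-to-point operations into a
// single exchange instead of launching them one by one.
void all2allSketch(
    float* const sendChunks[],
    const size_t sendCounts[],
    float* const recvChunks[],
    const size_t recvCounts[],
    int numRanks,
    ncclComm_t comm,
    cudaStream_t stream) {
  CHECK_NCCL(ncclGroupStart());
  for (int r = 0; r < numRanks; ++r) {
    // Like the surviving loop body's numel() != 0 checks, skip zero-sized
    // chunks so no empty point-to-point operation is posted.
    if (sendCounts[r] != 0) {
      CHECK_NCCL(
          ncclSend(sendChunks[r], sendCounts[r], ncclFloat, r, comm, stream));
    }
    if (recvCounts[r] != 0) {
      CHECK_NCCL(
          ncclRecv(recvChunks[r], recvCounts[r], ncclFloat, r, comm, stream));
    }
  }
  CHECK_NCCL(ncclGroupEnd());
}

The removed fast path expressed the same exchange as one ncclAllToAllv call, deriving per-rank counts and element displacements from each tensor's offset relative to the first tensor's base pointer; the grouped form instead issues one send/recv pair per peer and relies on the group to fuse them.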