- Pass `group_name` to `CUDASymmetricMemory::alloc()` instead of `CUDASymmetricMemory::rendezvous()`. We can only move the argument to rendezvous() once all the underlying operators do the same.
- Added `float` to the allowlist for intra-node all-reduces.
- Added a warning when `IntraNodeComm::rendezvous()` is performed with overlapping devices among participants.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141200
Approved by: https://github.com/weifengpy, https://github.com/kwen2501
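A minimal sketch of what the first bullet looks like from a caller's side (the allocator handle, variable names, and signatures below are assumptions for illustration, not taken from this change):

    // Illustrative sketch only; names and signatures are assumed.
    // The group name now travels with the allocation...
    void* buf = allocator->alloc(bufferSize, deviceIdx, groupName);
    // ...while rendezvous() stays keyed on the buffer pointer, until the
    // underlying symm_mem operators can accept the group name there instead.
    auto symmMem = allocator->rendezvous(buf);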

#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>

namespace c10d {
namespace intra_node_comm {

static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
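
// Validates an all-reduce input: the tensor must be float or bf16,
// non-overlapping and dense, resident on CUDA, and placed on the device
// this IntraNodeComm instance was created for.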
static void checkInput(const at::Tensor& input, int deviceIdx) {
  TORCH_CHECK(
      input.dtype() == at::kBFloat16 || input.dtype() == at::kFloat,
      "oneShotAllReduce only supports float and bf16 for now");
  TORCH_CHECK(input.is_non_overlapping_and_dense());
  TORCH_CHECK(input.device().is_cuda());
  TORCH_CHECK(
      input.get_device() == deviceIdx,
      "IntraNodeComm: expect input to be on device ",
      deviceIdx,
      ", got device ",
      input.get_device());
}
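
// Returns false on ROCm builds and when compiled for a CUDA device
// architecture below sm_80; otherwise intra-node comm is supported.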
bool isIntraNodeCommSupported() {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  return false;
#else
  return true;
#endif
}
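
// One-shot all-reduce: stage the input into this rank's symmetric-memory
// buffer, then invoke symm_mem::one_shot_all_reduce_out with the original
// `input` as the out tensor, so no copy-back is needed. The string arguments
// are the reduce op and, presumably, the group name.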
at::Tensor IntraNodeComm::oneShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("symm_mem::one_shot_all_reduce_out", "")
                .typed<at::Tensor(
                    const at::Tensor&, std::string, std::string, at::Tensor)>();

  auto symmMemTensor = at::from_blob(
      symmetricMemoryPtr_,
      input.sizes(),
      at::TensorOptions().dtype(input.dtype()).device(input.device()));

  symmMemTensor.copy_(input);
  op.call(symmMemTensor, "sum", "", input);
  return input;
}
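
// Two-shot all-reduce: stage the input into the symmetric-memory buffer,
// run the in-place symm_mem::two_shot_all_reduce_ on it, then copy the
// reduced buffer back into `input` before returning it.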
at::Tensor IntraNodeComm::twoShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("symm_mem::two_shot_all_reduce_", "")
                .typed<at::Tensor(at::Tensor, std::string, std::string)>();

  auto symmMemTensor = at::from_blob(
      symmetricMemoryPtr_,
      input.sizes(),
      at::TensorOptions().dtype(input.dtype()).device(input.device()));

  symmMemTensor.copy_(input);
  op.call(symmMemTensor, "sum", "");
  input.copy_(symmMemTensor);
  return input;
}
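
// Chooses an all-reduce algorithm for `input`. Requires a float/bf16 dtype,
// a fully connected topology, at least 4-byte pointer and size alignment,
// and a payload that fits in the symmetric-memory buffer. Payloads up to
// 256 KiB (kOneShotThreshBytes) use ONE_SHOT, up to 10 MiB
// (kTwoShotThreshBytes) use TWO_SHOT; everything else falls back to NONE.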
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
  // Only support float and bf16 for now
  if (input.dtype() != at::kBFloat16 && input.dtype() != at::kFloat) {
    return AllReduceAlgo::NONE;
  }
  const auto inputSize =
      static_cast<size_t>(input.numel() * input.element_size());
  const size_t ptrAlignment = get_alignment(
      static_cast<size_t>(input.storage_offset() * input.element_size()));
  const size_t sizeAlignment = get_alignment(inputSize);
  const size_t alignment = std::min(ptrAlignment, sizeAlignment);

  if (topology_ == Topology::FULLY_CONNECTED) {
    // Both symm_mem::one_shot_all_reduce and symm_mem::two_shot_all_reduce_
    // currently require the input to be at least 4-byte aligned.
    if (alignment >= 4 && inputSize <= kOneShotThreshBytes &&
        inputSize <= bufferSize_) {
      return AllReduceAlgo::ONE_SHOT;
    }
    if (alignment >= 4 && inputSize <= kTwoShotThreshBytes &&
        inputSize <= bufferSize_) {
      return AllReduceAlgo::TWO_SHOT;
    }
  }
  return AllReduceAlgo::NONE;
}
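
// Counts allReduce() calls so tests can verify the intra-node path was
// actually taken (exposed via getIntraNodeCommUsageCounter() below).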
static int64_t usageCounter = 0;

at::Tensor IntraNodeComm::allReduce(
    const at::Tensor& input,
    AllReduceAlgo algo) {
  // Report usage for testing purposes.
  // We don't care about overflowing.
  ++usageCounter;
  auto stream = at::cuda::getCurrentCUDAStream();
  switch (algo) {
    case AllReduceAlgo::ONE_SHOT:
      return oneShotAllReduce(input, stream);
    case AllReduceAlgo::TWO_SHOT:
      return twoShotAllReduce(input, stream);
    default:
      C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
  }
}

int64_t getIntraNodeCommUsageCounter() {
  return usageCounter;
}

} // namespace intra_node_comm
} // namespace c10d