Replace MPI Send and Recv with NCCL Send and Recv (#5054)

* Prototype NCCL P2P

* Clean code

* Fix NCCL path and some minor bugs

* Add path

* Fix path

* Try fix path

* Add missed files

* Address some comments

* Clean code

* Rename files

* Add MPI path back and fix a path

* Put MPI path under USE_NCCL flag

* not to build Send and Recv when MPI is not installed
This commit is contained in:
Wei-Sheng Chin 2020-09-09 09:39:56 -07:00 committed by GitHub
parent dbf4e7019d
commit 4ccca20def
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 651 additions and 23 deletions

View file

@ -94,6 +94,10 @@ if (onnxruntime_ENABLE_TRAINING)
if (onnxruntime_USE_HOROVOD)
target_include_directories(onnxruntime_graph PRIVATE ${HOROVOD_INCLUDE_DIRS})
endif()
if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_graph PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
endif()
set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")

View file

@ -191,11 +191,15 @@ if (onnxruntime_USE_CUDA)
file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc"
)
file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cu"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cuh"
)
if (NOT onnxruntime_USE_HOROVOD)
@ -255,6 +259,10 @@ if (onnxruntime_USE_CUDA)
if (onnxruntime_USE_HOROVOD)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${HOROVOD_INCLUDE_DIRS})
endif()
if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
endif()
if (WIN32)

View file

@ -39,6 +39,7 @@ struct PipelineTask {
bool IsForward() const { return pass == Pass::Forward; }
bool IsBackward() const { return pass == Pass::Backward; }
bool IsCompute() const { return type == Type::Compute; }
bool IsCommute() const { return type == Type::Send || type == Type::Recv; }
bool IsSendTo(const int dst_rank) const {
if (type != Type::Send) {
return false;
@ -86,30 +87,33 @@ class PipelineSlot {
bool IsEmpty() const { return tasks_.empty(); };
size_t NumActions() const { return tasks_.size(); }
bool HasCompute() const {
for (auto& task : tasks_) {
if (task.IsCompute())
return true;
}
return false;
return std::any_of(
tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) {
return task.IsCompute();
});
}
bool HasCommute() const {
return std::any_of(
tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) {
return task.IsCommute();
});
}
bool HasRendTo(const int stage) const {
for (auto& task : tasks_) {
if (task.IsSendTo(stage)) {
return true;
}
}
return false;
return std::any_of(
tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) {
return task.IsSendTo(stage);
});
}
bool HasRecvFrom(const int stage) const {
for (auto& task : tasks_) {
if (task.IsRecvFrom(stage)) {
return true;
}
}
return false;
return std::any_of(
tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) {
return task.IsRecvFrom(stage);
});
}
PipelineTask& operator[](int index);
@ -125,6 +129,8 @@ class PipelineSlot {
void SetRecordedEvent(const std::vector<int> event);
std::vector<int> GetRecordedEvent() const;
std::vector<PipelineTask> GetTasks() { return tasks_; }
private:
// Actions which can be executed in parallel in this time slot.
std::vector<PipelineTask> tasks_;
@ -141,6 +147,18 @@ class PipelineScheduler {
size_t GetScheduleSize() const { return compute_commute_table_.size(); }
// Number of stages.
size_t GetStageSize() const { return num_stages_; }
std::vector<PipelineSlot> GetSchedule(const int stage_id) const {
std::vector<PipelineSlot> commute_slots;
for (int t = 0; static_cast<size_t>(t) < GetScheduleSize(); ++t) {
auto& slot = compute_commute_table_[t][stage_id];
if (!slot.HasCommute()) {
continue;
}
commute_slots.push_back(slot);
}
return commute_slots;
}
// APIs to get events for the following pattern.
// Wait -> Recv -> Record -> Wait -> Compute -> Record -> Wait -> Send -> Record.
// If no event exists, -1 may be returned.

View file

@ -20,6 +20,7 @@
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/graph/optimizer_graph_builder.h"
#include "orttraining/models/runner/training_util.h"
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
#include "single_include/nlohmann/json.hpp"
#include "test/perftest/utils.h"
@ -705,6 +706,30 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
auto end_to_end_start = std::chrono::high_resolution_clock::now();
bool end_to_end_measurement_started = false;
#ifdef USE_NCCL
// NCCL-P2P
auto& nccl_service = cuda::NcclService::GetInstance();
nccl_service.PlanStart();
for (auto& slot : pipeline_schedule_.GetSchedule(pipeline_context_.pipeline_stage_id)) {
if (!slot.HasCommute()) {
continue;
}
nccl_service.PlanNewGroupStart();
for (auto& task : slot.GetTasks()) {
if (task.type == pipeline::PipelineTask::Type::Send) {
nccl_service.PlanSend(task.peer_rank);
} else if (task.type == pipeline::PipelineTask::Type::Recv) {
nccl_service.PlanRecv(task.peer_rank);
}
}
nccl_service.PlanNewGroupEnd();
}
nccl_service.PlanEnd();
nccl_service.Launch();
#endif
auto all_steps_time_start = std::chrono::high_resolution_clock::now();
while (step_ < params_.num_train_steps) {
for (size_t shard_it = 0; shard_it < num_shards_to_visit; ++shard_it) {
@ -754,6 +779,9 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
fetch_names,
fetches));
RunWithUpdate(feed_names, fetch_names, feeds, fetches);
#ifdef USE_NCCL
nccl_service.Reset();
#endif
} else {
ORT_RETURN_IF_ERROR(PrepareFeedNamesAndFeeds(GradientAccumulateStep,
training_data_loader,
@ -876,6 +904,9 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
<< "Average Step Time: " << all_steps_duration_seconds.count() / (step_ - step_start) << " Second\n"
<< "Average Step Throughput: " << params_.batch_size * (step_ - step_start) / (all_steps_duration_seconds.count()) << " Examples / Second\n";
#ifdef USE_NCCL
nccl_service.Terminate();
#endif
return Status::OK();
}

View file

@ -0,0 +1,346 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_NCCL
#include "core/common/common.h"
#include "core/profile/context.h"
#include "core/providers/cuda/cuda_common.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
#include <iostream>
#include <nccl.h>
namespace onnxruntime {
namespace cuda {
bool NcclTask::Compare(const NcclTask& other) const {
if (type != other.type) {
return false;
}
if (peers.size() != other.peers.size()) {
return false;
}
for (size_t i = 0; i < peers.size(); ++i) {
if (peers[i] != other.peers[i]) {
return false;
}
}
return true;
}
void NcclTask::ResetTask() {
ptr = nullptr;
size = 0;
is_enqueued = false;
is_finished = false;
}
void NcclTaskGroup::PlanTask(
const NcclTask::Type type,
const std::vector<int> peers) {
batch.push_back({type, peers, nullptr, 0, false, false, ""});
};
const NcclTask* NcclTaskGroup::EqueueTask(
const NcclTask::Type type,
const std::vector<int> peers,
void* ptr,
const size_t size,
const std::string info) {
NcclTask scheduled_task;
scheduled_task.type = type;
scheduled_task.peers = peers;
for (auto& task : batch) {
if (!task.Compare(scheduled_task)) {
continue;
}
// We cannot enqueue the same task.
ORT_ENFORCE(!task.is_finished);
// We cannot enqueue the same task.
ORT_ENFORCE(!task.is_enqueued);
task.ptr = ptr;
task.size = size;
task.is_enqueued = true;
task.info = info;
return &task;
}
return nullptr;
};
bool NcclTaskGroup::IsAllTasksEqueued() const {
return std::all_of(
batch.begin(), batch.end(), [&](const NcclTask& task) {
return task.is_enqueued;
});
};
bool NcclTaskGroup::IsAllTasksFinished() const {
return std::all_of(
batch.begin(), batch.end(), [&](const NcclTask& task) {
return task.is_finished;
});
};
void NcclTaskGroup::ResetAllTasks() {
for (auto& task : batch) {
task.ptr = nullptr;
task.size = 0;
task.is_enqueued = false;
task.is_finished = false;
}
};
void NcclService::PlanStart() {
ORT_ENFORCE(!is_planned_, "Communication plan cannot be changed after calling PlanEnd.");
};
void NcclService::PlanEnd() {
is_planned_ = true;
};
void NcclService::PlanNewGroupStart() {
group_status_.push_back(true);
schedule_.push_back(NcclTaskGroup());
};
void NcclService::PlanNewGroupEnd() {
group_status_.back() = false;
};
void NcclService::PlanSend(const int dst) {
ORT_ENFORCE(group_status_.back(), "Last communication group can not be changed after call PlanEndNewGroup.");
schedule_.back().PlanTask(NcclTask::Type::SEND, {dst});
};
void NcclService::PlanRecv(const int src) {
ORT_ENFORCE(group_status_.back(), "Last communication group can not be changed after call PlanEndNewGroup.");
schedule_.back().PlanTask(NcclTask::Type::RECV, {src});
};
void NcclService::WaitForLaunch() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return is_running_; });
}
std::ostream& operator<<(std::ostream& stream, const NcclTaskGroup& task_group) {
for (int i = 0; static_cast<size_t>(i) < task_group.batch.size(); ++i) {
std::string line = " ";
auto& task = task_group.batch[i];
if (task.type == NcclTask::Type::SEND) {
line += "Send, ";
} else if (task.type == NcclTask::Type::RECV) {
line += "Recv, ";
}
for (int j = 0; static_cast<size_t>(j) < task.peers.size(); ++j) {
line += std::to_string(task.peers[j]);
if (static_cast<size_t>(j) != task.peers.size() - 1) {
line += ", ";
} else {
line += "\n";
}
}
stream << line;
}
return stream;
}
std::ostream& operator<<(std::ostream& stream, const NcclService& service) {
for (int i = 0; static_cast<size_t>(i) < service.schedule_.size(); ++i) {
stream << "NCCL operations at time " << i << std::endl;
stream << service.schedule_[i];
}
return stream;
}
int NcclService::FindNextCommunicationTime() const {
for (int i = 0; static_cast<size_t>(i) < schedule_.size(); ++i) {
if (schedule_[i].IsAllTasksEqueued() && !schedule_[i].IsAllTasksFinished()) {
return i;
}
}
return -1;
};
void NcclService::SubmitSendAndWait(void* ptr, size_t size, int peer) {
// Wait until NCCL service is launched.
WaitForLaunch();
auto& profile_context = profile::Context::GetInstance();
const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id());
// Pointer to enqueued task.
const NcclTask* task;
// Submit task.
{
std::lock_guard<std::mutex> guard(mutex_);
auto& profile_context = profile::Context::GetInstance();
const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id());
task = schedule_[time_].EqueueTask(NcclTask::Type::SEND, std::vector<int>{peer}, ptr, size, tag);
}
// Wait for task to be finished.
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return task->is_finished; });
}
};
void NcclService::SubmitRecvAndWait(void* ptr, size_t size, int peer) {
// Wait until NCCL service is launched.
WaitForLaunch();
// Pointer to euqueued task.
const NcclTask* task;
{
std::lock_guard<std::mutex> guard(mutex_);
auto& profile_context = profile::Context::GetInstance();
const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id());
task = schedule_[time_].EqueueTask(NcclTask::Type::RECV, std::vector<int>{peer}, ptr, size, tag);
}
// Wait for task to be finished.
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return task->is_finished; });
}
};
void NcclService::Initialize() {
// Here we assume GPU i is assigned to local process i.
// TODO: Create a general class to describe for computation topology and unify all similar uses.
// Hardware a process can own:
// GPUs
// CPUs
// Other devices
int mpi_rank;
int mpi_size;
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_size));
// Set device this NCCL communicator runs on.
CUDA_CALL(cudaSetDevice(mpi_rank));
// Get NCCL unique ID at rank 0 and broadcast it to all others.
ncclUniqueId id;
if (mpi_rank == 0) NCCL_CALL(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
NCCL_CALL(ncclCommInitRank(&comm_, mpi_size, id, mpi_rank));
}
void NcclService::Launch() {
worker_ = std::thread([this]() {
{
std::lock_guard<std::mutex> guard(mutex_);
ORT_ENFORCE(is_planned_, "NCCL service must know its communication plan before launching.");
// The NCCL service object can only be launched once because it's a
// singlton class.
ORT_ENFORCE(!is_running_, "NCCL service cannot be repeatedly launched.");
// Set this flag so that others will not call this again.
is_running_ = true;
cv_.notify_all();
}
Initialize();
while (is_running_) {
// Enter critical region.
// The state of this class cannot be modified by other threads.
{
std::lock_guard<std::mutex> guard(mutex_);
// All tasks must be ready with a valid time.
if (time_ > schedule_.size() - 1 ||
!schedule_[time_].IsAllTasksEqueued() ||
schedule_[time_].IsAllTasksFinished()) {
continue;
}
// Start NCCL parallel communication.
NCCL_CALL(ncclGroupStart());
for (auto& task : schedule_[time_].batch) {
ORT_ENFORCE(task.is_enqueued, "Unscheduled task cannot be run. Use SubmitSendAndWait or SubmitRecvAndWait to schedule tasks.");
switch (task.type) {
case NcclTask::Type::SEND:
ORT_ENFORCE(task.peers.size() == 1, "Send can only send data to one rank.");
NCCL_CALL(ncclSend(task.ptr, task.size, ncclChar, task.peers.front(), comm_, nullptr));
break;
case NcclTask::Type::RECV:
ORT_ENFORCE(task.peers.size() == 1, "Recv can only send data to one rank.");
NCCL_CALL(ncclRecv(task.ptr, task.size, ncclChar, task.peers.front(), comm_, nullptr));
break;
default:
ORT_NOT_IMPLEMENTED("NCCL service currently only support ncclSend and ncclRecv.");
}
task.is_finished = true;
}
NCCL_CALL(ncclGroupEnd());
// This round of communication is done.
// We can start waiting for the tasks to be fully scheduled.
++time_;
++total_time_;
}
cv_.notify_all();
}
});
}
void NcclService::Reset() {
WaitForLaunch();
{
std::unique_lock<std::mutex> lock(mutex_);
// We can only reset after all planned tasks are done,
// so wait for unfinished tasks here.
cv_.wait(lock, [this] {
bool is_all_tasks_finished = true;
for (auto& task_group : schedule_) {
if (task_group.IsAllTasksFinished()) {
continue;
}
is_all_tasks_finished = false;
};
return is_all_tasks_finished;
});
}
{
std::lock_guard<std::mutex> guard(mutex_);
time_ = 0;
// All scheduled communication tasks are done for finishing
// gradient accumulation steps + one model update step.
// To start next round of gradient accumulation and model update,
// we need to reset the "done" status of all tasks in the schedule.
for (auto& task_group : schedule_) {
task_group.ResetAllTasks();
}
cv_.notify_all();
}
}
void NcclService::Terminate() {
WaitForLaunch();
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return total_time_ > 0 && time_ == 0; });
}
is_running_ = false;
worker_.join();
NCCL_CALL(ncclCommDestroy(comm_));
}
} // namespace cuda
} // namespace onnxruntime
#endif

View file

@ -0,0 +1,185 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_NCCL
#pragma once
#include <condition_variable>
#include <list>
#include <mutex>
#include <map>
#include <vector>
#include <iostream>
#include <string>
#include <thread>
#include <nccl.h>
#include <mpi.h>
namespace onnxruntime {
namespace cuda {
struct NcclTask final {
// Attributes for communication operator.
enum class Type { SEND,
RECV,
ALLREDUCE };
// Operator to perform.
Type type;
// For Send, this field is destination device's ID.
// For Recv, this field is source device's ID.
std::vector<int> peers;
// Attributes for memory location.
// GPU memory pointer.
void* ptr;
// Number of bytes to send/recv.
size_t size;
// Scheduler flag.
bool is_enqueued;
bool is_finished;
// Debug information
std::string info;
// Return true if the two operations are the same.
bool Compare(const NcclTask& other) const;
// Clear runtime information.
void ResetTask();
};
// A collection of independent communication operations.
struct NcclTaskGroup final {
// Schedule a communication operation in this group.
// We don't know the pointer to the actual data and other runtime information yet;
// runtime information is filled by calling EqunueTask(...).
void PlanTask(const NcclTask::Type type, const std::vector<int> peers);
// Fill in task's details.
const NcclTask* EqueueTask(
const NcclTask::Type type,
const std::vector<int> peers,
void* ptr,
const size_t size,
const std::string info);
bool IsAllTasksEqueued() const;
bool IsAllTasksFinished() const;
void ResetAllTasks();
friend std::ostream& operator<<(std::ostream& stream, const NcclTaskGroup& task_group);
std::vector<NcclTask> batch;
};
// The use of this class has two stages. First, the user needs to plan the communication operators.
// Second, when running a model, the user should submit tasks following the communication plan.
// Function names begin with "Plan" are used for creating communication plan. Function names begin
// with "Submit" asks this class to run the submitted task. Communication usually does not happen
// immediately after submitting a task. The actual communication time is decided by this class based on
// the communication plan.
//
// Below is an example of planning tasks. Notice that the communication operations in the same group are
// called in random order, so those operations cannot have mutual dependency.
//
// auto& nccl_service = cuda::NcclService::GetInstance();
//
// nccl_service.PlanStart(); // Signal the begin of communication planning.
//
// nccl_service.PlanStartNewGroup(); // Create new time slot.
// nccl_service.PlanSend(0);
// nccl_service.PlanRecv(1);
// nccl_service.PlanEndNewGroup(); // Mark the end of the first time slot.
//
// nccl_service.PlanStartNewGroup(); // Create the second time slot.
// nccl_service.PlanSend(1);
// nccl_service.PlanRecv(0);
// nccl_service.PlanEndNewGroup(); // Mark the end of the second time slot.
//
// nccl_service.EndPlan(); // Signal the end of communication planning.
class NcclService final {
public:
// Get the singleton of this class.
static NcclService& GetInstance() {
static NcclService instance_;
return instance_;
};
// Planning APIs. They are not thread-safe.
// Mark the start of entire plan.
void PlanStart();
// Mark the end of entire plan.
void PlanEnd();
// Mark the begin of a new communication group. It uses the latest time slot.
// Operations in a group can happen in random order.
void PlanNewGroupStart();
// Mark the end of the current communication group.
void PlanNewGroupEnd();
// Add Send to the current communication group.
void PlanSend(const int dst);
// Add Recv to the current communication group.
void PlanRecv(const int src);
// Runtime APIs. They are thread-safe.
// Launch NCCL service. It's an infinite loop which repeatedly calls corresponding NCCL
// when planned operators (e.g., Send and Recv) arrive.
void Launch();
// Submit a Send request with needed information such as tensor's address and number bytes to send.
void SubmitSendAndWait(void* buffer, size_t count, int peer);
// Submit a Recv request with needed information such as tensor's address and number bytes to recv.
void SubmitRecvAndWait(void* buffer, size_t count, int peer);
// Reset communication plan's status so that we can reuse the same communication plan for multiple
// model update steps.
void Reset();
// Terminate NCCL service.
void Terminate();
// Print debug string.
friend std::ostream& operator<<(std::ostream& stream, const NcclService& service);
private:
NcclService() = default;
~NcclService() = default;
NcclService(const NcclService&) = delete;
NcclService& operator=(const NcclService&) = delete;
// Initialization for running NCCL service.
void Initialize();
// Most member functions should start with a call to this function because
// they are valid only after NCCL service is launched.
void WaitForLaunch();
// Search the next unfinished communication group to work on.
int FindNextCommunicationTime() const;
// Mutex to gurantee thread-safe access to this class.
std::mutex mutex_;
// Conditional variable used to wait for the mutex.
std::condition_variable cv_;
// Stream for running NCCL.
ncclComm_t comm_;
// Indicates if NCCL service launched.
bool is_running_;
// Indicates if NCCL service has a plan, which must be true when calling Launch(...).
bool is_planned_;
// Pipeline stage.
size_t rank_;
size_t time_;
size_t total_time_;
// group_status_[t] indicates if the t-th group's plan is done. Once group_status_[t] is
// set to false, we can add communication operations to that group.
std::vector<bool> group_status_;
// schedule_[t] communication group at time t. Communication group at time t-1 must be
// finished before working on the group at time t. In other words, communication groups
// are stored in their actual time order.
std::vector<NcclTaskGroup> schedule_;
// Thread to asynchronously run Launc(...).
std::thread worker_;
};
} // namespace cuda
} // namespace onnxruntime
#endif

View file

@ -5,6 +5,7 @@
#include "orttraining/training_ops/cuda/communication/recv.h"
#include "orttraining/training_ops/cuda/communication/common.h"
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
#include "core/profile/profile.h"
#include "core/profile/context.h"
#include "core/providers/cuda/cuda_common.h"
@ -72,15 +73,28 @@ void Recv::ReceiveData(
// count waiting time before setting up the actual communication.
recvRange.Begin();
#endif
#ifdef USE_NCCL
buffer = GetScratchBuffer<char>(aggregated_aligned_tensor_bytes);
#else
buffer = AllocateBufferOnCPUPinned<char>(static_cast<size_t>(aggregated_aligned_tensor_bytes));
#endif
CommInfo_t info_data{buffer.get(),
static_cast<int>(aggregated_aligned_tensor_bytes),
src,
static_cast<int>(tag_)};
// The following NCCL call is equivalent to the following MPI call. User can
// uncomment the MPI call to debug.
#ifdef USE_NCCL
auto& nccl_service = cuda::NcclService::GetInstance();
nccl_service.SubmitRecvAndWait(info_data.buffer, info_data.size, info_data.rank);
#else
MPI_CHECK(MPI_Recv(
info_data.buffer, info_data.size, MPI_CHAR,
info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
info_data.buffer, info_data.size, MPI_CHAR,
info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
#endif
#ifdef ENABLE_NVTX_PROFILE
// End of actual communication.
@ -103,13 +117,20 @@ void Recv::ReceiveData(
// Find the next aligned offset in the tensor buffer to meet alignment requirement
tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes);
// Keep the sync copy in the previous design
// TODO they can be moved to async call after global stream becoming accessible
// Copy data out from buffer.
#ifdef USE_NCCL
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
tensor->SizeInBytes(), cudaMemcpyHostToDevice));
#else
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
tensor->SizeInBytes(), cudaMemcpyDeviceToDevice));
#endif
tensor_offset_in_bytes += tensor->SizeInBytes();
}
#ifndef USE_NCCL
AddDeferredReleaseCPUPtr(buffer.release());
#endif
#ifdef ENABLE_NVTX_PROFILE
// End of host-to-device copy.
@ -266,4 +287,4 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const {
} // namespace cuda
} // namespace onnxruntime
#endif
#endif

View file

@ -5,6 +5,7 @@
#include "orttraining/training_ops/cuda/communication/send.h"
#include "orttraining/training_ops/cuda/communication/common.h"
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
#include "core/profile/profile.h"
#include "core/profile/context.h"
#include "core/providers/cuda/cuda_common.h"
@ -99,13 +100,22 @@ void Send::SendData(
memcpyRange.Begin();
#endif
#ifdef USE_NCCL
IAllocatorUniquePtr<char> buffer = GetScratchBuffer<char>(aggregated_aligned_tensor_bytes);
#else
IAllocatorUniquePtr<char> buffer = AllocateBufferOnCPUPinned<char>(
aggregated_aligned_tensor_bytes);
#endif
for (int i = 0; i < num_tensors; ++i) {
const Tensor* tensor = ctx->Input<Tensor>(i + 2);
#ifdef USE_NCCL
CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice));
#else
CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost));
#endif
}
#ifdef ENABLE_NVTX_PROFILE
@ -127,9 +137,14 @@ void Send::SendData(
dst,
static_cast<int>(tag_)};
#ifdef USE_NCCL
auto& nccl_service = cuda::NcclService::GetInstance();
nccl_service.SubmitSendAndWait(info_data.buffer, info_data.size, info_data.rank);
#else
MPI_CHECK(MPI_Send(
info_data.buffer, info_data.size, MPI_CHAR,
info_data.rank, info_data.tag, MPI_COMM_WORLD));
#endif
#ifdef ENABLE_NVTX_PROFILE
// End of major communication tasks.
@ -240,4 +255,4 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const {
} // namespace cuda
} // namespace onnxruntime
#endif
#endif