diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 9cbb7c5928..fec53eb749 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -94,6 +94,10 @@ if (onnxruntime_ENABLE_TRAINING) if (onnxruntime_USE_HOROVOD) target_include_directories(onnxruntime_graph PRIVATE ${HOROVOD_INCLUDE_DIRS}) endif() + + if (onnxruntime_USE_NCCL) + target_include_directories(onnxruntime_graph PRIVATE ${NCCL_INCLUDE_DIRS}) + endif() endif() set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 87d5a54be8..193f7bf1f3 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -191,11 +191,15 @@ if (onnxruntime_USE_CUDA) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc" ) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cu" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cuh" ) if (NOT onnxruntime_USE_HOROVOD) @@ -255,6 +259,10 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_USE_HOROVOD) target_include_directories(onnxruntime_providers_cuda PRIVATE ${HOROVOD_INCLUDE_DIRS}) endif() + + if (onnxruntime_USE_NCCL) + target_include_directories(onnxruntime_providers_cuda PRIVATE ${NCCL_INCLUDE_DIRS}) + endif() endif() if (WIN32) diff --git a/orttraining/orttraining/models/runner/pipeline.h b/orttraining/orttraining/models/runner/pipeline.h index 18eb9ed5da..fcf6ba61c8 100644 --- a/orttraining/orttraining/models/runner/pipeline.h +++ b/orttraining/orttraining/models/runner/pipeline.h @@ -39,6 +39,7 @@ struct PipelineTask { bool IsForward() const { return pass == Pass::Forward; } bool IsBackward() const { return pass == Pass::Backward; } bool IsCompute() const { return type == Type::Compute; } + bool IsCommute() const { return type == Type::Send || type == Type::Recv; } bool IsSendTo(const int dst_rank) const { if (type != Type::Send) { return false; @@ -86,30 +87,33 @@ class PipelineSlot { bool IsEmpty() const { return tasks_.empty(); }; size_t NumActions() const { return tasks_.size(); } + bool HasCompute() const { - for (auto& task : tasks_) { - if (task.IsCompute()) - return true; - } - return false; + return std::any_of( + tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) { + return task.IsCompute(); + }); + } + + bool HasCommute() const { + return std::any_of( + tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) { + return task.IsCommute(); + }); } bool HasRendTo(const int stage) const { - for (auto& task : tasks_) { - if (task.IsSendTo(stage)) { - return true; - } - } - return false; + return std::any_of( + tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) { + return task.IsSendTo(stage); + }); } bool HasRecvFrom(const int stage) const { - for (auto& task : tasks_) { - if (task.IsRecvFrom(stage)) { - return true; - } - } - return false; + return std::any_of( + tasks_.begin(), tasks_.end(), [&](const PipelineTask& task) { + return task.IsRecvFrom(stage); + }); } PipelineTask& operator[](int index); @@ -125,6 +129,8 @@ class PipelineSlot { void SetRecordedEvent(const std::vector event); std::vector GetRecordedEvent() const; + std::vector GetTasks() { return tasks_; } + private: // Actions which can be executed in parallel in this time slot. std::vector tasks_; @@ -141,6 +147,18 @@ class PipelineScheduler { size_t GetScheduleSize() const { return compute_commute_table_.size(); } // Number of stages. size_t GetStageSize() const { return num_stages_; } + std::vector GetSchedule(const int stage_id) const { + std::vector commute_slots; + for (int t = 0; static_cast(t) < GetScheduleSize(); ++t) { + auto& slot = compute_commute_table_[t][stage_id]; + if (!slot.HasCommute()) { + continue; + } + commute_slots.push_back(slot); + } + return commute_slots; + } + // APIs to get events for the following pattern. // Wait -> Recv -> Record -> Wait -> Compute -> Record -> Wait -> Send -> Record. // If no event exists, -1 may be returned. diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index c623e5081c..084d74858b 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -20,6 +20,7 @@ #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/graph/optimizer_graph_builder.h" #include "orttraining/models/runner/training_util.h" +#include "orttraining/training_ops/cuda/communication/nccl_service.h" #include "single_include/nlohmann/json.hpp" #include "test/perftest/utils.h" @@ -705,6 +706,30 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad auto end_to_end_start = std::chrono::high_resolution_clock::now(); bool end_to_end_measurement_started = false; +#ifdef USE_NCCL + // NCCL-P2P + auto& nccl_service = cuda::NcclService::GetInstance(); + + nccl_service.PlanStart(); + for (auto& slot : pipeline_schedule_.GetSchedule(pipeline_context_.pipeline_stage_id)) { + if (!slot.HasCommute()) { + continue; + } + nccl_service.PlanNewGroupStart(); + for (auto& task : slot.GetTasks()) { + if (task.type == pipeline::PipelineTask::Type::Send) { + nccl_service.PlanSend(task.peer_rank); + } else if (task.type == pipeline::PipelineTask::Type::Recv) { + nccl_service.PlanRecv(task.peer_rank); + } + } + nccl_service.PlanNewGroupEnd(); + } + nccl_service.PlanEnd(); + + nccl_service.Launch(); +#endif + auto all_steps_time_start = std::chrono::high_resolution_clock::now(); while (step_ < params_.num_train_steps) { for (size_t shard_it = 0; shard_it < num_shards_to_visit; ++shard_it) { @@ -754,6 +779,9 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad fetch_names, fetches)); RunWithUpdate(feed_names, fetch_names, feeds, fetches); +#ifdef USE_NCCL + nccl_service.Reset(); +#endif } else { ORT_RETURN_IF_ERROR(PrepareFeedNamesAndFeeds(GradientAccumulateStep, training_data_loader, @@ -876,6 +904,9 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad << "Average Step Time: " << all_steps_duration_seconds.count() / (step_ - step_start) << " Second\n" << "Average Step Throughput: " << params_.batch_size * (step_ - step_start) / (all_steps_duration_seconds.count()) << " Examples / Second\n"; +#ifdef USE_NCCL + nccl_service.Terminate(); +#endif return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc new file mode 100644 index 0000000000..4a881c006d --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc @@ -0,0 +1,346 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_NCCL + +#include "core/common/common.h" +#include "core/profile/context.h" +#include "core/providers/cuda/cuda_common.h" +#include "orttraining/core/framework/mpi_context.h" +#include "orttraining/training_ops/cuda/communication/nccl_service.h" +#include +#include + +namespace onnxruntime { +namespace cuda { + +bool NcclTask::Compare(const NcclTask& other) const { + if (type != other.type) { + return false; + } + if (peers.size() != other.peers.size()) { + return false; + } + for (size_t i = 0; i < peers.size(); ++i) { + if (peers[i] != other.peers[i]) { + return false; + } + } + + return true; +} + +void NcclTask::ResetTask() { + ptr = nullptr; + size = 0; + is_enqueued = false; + is_finished = false; +} + +void NcclTaskGroup::PlanTask( + const NcclTask::Type type, + const std::vector peers) { + batch.push_back({type, peers, nullptr, 0, false, false, ""}); +}; + +const NcclTask* NcclTaskGroup::EqueueTask( + const NcclTask::Type type, + const std::vector peers, + void* ptr, + const size_t size, + const std::string info) { + NcclTask scheduled_task; + scheduled_task.type = type; + scheduled_task.peers = peers; + + for (auto& task : batch) { + if (!task.Compare(scheduled_task)) { + continue; + } + + // We cannot enqueue the same task. + ORT_ENFORCE(!task.is_finished); + // We cannot enqueue the same task. + ORT_ENFORCE(!task.is_enqueued); + + task.ptr = ptr; + task.size = size; + task.is_enqueued = true; + task.info = info; + return &task; + } + + return nullptr; +}; + +bool NcclTaskGroup::IsAllTasksEqueued() const { + return std::all_of( + batch.begin(), batch.end(), [&](const NcclTask& task) { + return task.is_enqueued; + }); +}; + +bool NcclTaskGroup::IsAllTasksFinished() const { + return std::all_of( + batch.begin(), batch.end(), [&](const NcclTask& task) { + return task.is_finished; + }); +}; + +void NcclTaskGroup::ResetAllTasks() { + for (auto& task : batch) { + task.ptr = nullptr; + task.size = 0; + task.is_enqueued = false; + task.is_finished = false; + } +}; + +void NcclService::PlanStart() { + ORT_ENFORCE(!is_planned_, "Communication plan cannot be changed after calling PlanEnd."); +}; + +void NcclService::PlanEnd() { + is_planned_ = true; +}; + +void NcclService::PlanNewGroupStart() { + group_status_.push_back(true); + schedule_.push_back(NcclTaskGroup()); +}; + +void NcclService::PlanNewGroupEnd() { + group_status_.back() = false; +}; + +void NcclService::PlanSend(const int dst) { + ORT_ENFORCE(group_status_.back(), "Last communication group can not be changed after call PlanEndNewGroup."); + + schedule_.back().PlanTask(NcclTask::Type::SEND, {dst}); +}; + +void NcclService::PlanRecv(const int src) { + ORT_ENFORCE(group_status_.back(), "Last communication group can not be changed after call PlanEndNewGroup."); + schedule_.back().PlanTask(NcclTask::Type::RECV, {src}); +}; + +void NcclService::WaitForLaunch() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return is_running_; }); +} + +std::ostream& operator<<(std::ostream& stream, const NcclTaskGroup& task_group) { + for (int i = 0; static_cast(i) < task_group.batch.size(); ++i) { + std::string line = " "; + auto& task = task_group.batch[i]; + if (task.type == NcclTask::Type::SEND) { + line += "Send, "; + } else if (task.type == NcclTask::Type::RECV) { + line += "Recv, "; + } + + for (int j = 0; static_cast(j) < task.peers.size(); ++j) { + line += std::to_string(task.peers[j]); + if (static_cast(j) != task.peers.size() - 1) { + line += ", "; + } else { + line += "\n"; + } + } + stream << line; + } + return stream; +} + +std::ostream& operator<<(std::ostream& stream, const NcclService& service) { + for (int i = 0; static_cast(i) < service.schedule_.size(); ++i) { + stream << "NCCL operations at time " << i << std::endl; + stream << service.schedule_[i]; + } + return stream; +} + +int NcclService::FindNextCommunicationTime() const { + for (int i = 0; static_cast(i) < schedule_.size(); ++i) { + if (schedule_[i].IsAllTasksEqueued() && !schedule_[i].IsAllTasksFinished()) { + return i; + } + } + return -1; +}; + +void NcclService::SubmitSendAndWait(void* ptr, size_t size, int peer) { + // Wait until NCCL service is launched. + WaitForLaunch(); + auto& profile_context = profile::Context::GetInstance(); + const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id()); + + // Pointer to enqueued task. + const NcclTask* task; + + // Submit task. + { + std::lock_guard guard(mutex_); + auto& profile_context = profile::Context::GetInstance(); + const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id()); + task = schedule_[time_].EqueueTask(NcclTask::Type::SEND, std::vector{peer}, ptr, size, tag); + } + + // Wait for task to be finished. + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return task->is_finished; }); + } +}; + +void NcclService::SubmitRecvAndWait(void* ptr, size_t size, int peer) { + // Wait until NCCL service is launched. + WaitForLaunch(); + + // Pointer to euqueued task. + const NcclTask* task; + { + std::lock_guard guard(mutex_); + auto& profile_context = profile::Context::GetInstance(); + const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id()); + task = schedule_[time_].EqueueTask(NcclTask::Type::RECV, std::vector{peer}, ptr, size, tag); + } + + // Wait for task to be finished. + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return task->is_finished; }); + } +}; + +void NcclService::Initialize() { + // Here we assume GPU i is assigned to local process i. + // TODO: Create a general class to describe for computation topology and unify all similar uses. + // Hardware a process can own: + // GPUs + // CPUs + // Other devices + int mpi_rank; + int mpi_size; + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank)); + MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_size)); + + // Set device this NCCL communicator runs on. + CUDA_CALL(cudaSetDevice(mpi_rank)); + + // Get NCCL unique ID at rank 0 and broadcast it to all others. + ncclUniqueId id; + if (mpi_rank == 0) NCCL_CALL(ncclGetUniqueId(&id)); + MPI_CHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + NCCL_CALL(ncclCommInitRank(&comm_, mpi_size, id, mpi_rank)); +} + +void NcclService::Launch() { + worker_ = std::thread([this]() { + { + std::lock_guard guard(mutex_); + ORT_ENFORCE(is_planned_, "NCCL service must know its communication plan before launching."); + // The NCCL service object can only be launched once because it's a + // singlton class. + ORT_ENFORCE(!is_running_, "NCCL service cannot be repeatedly launched."); + + // Set this flag so that others will not call this again. + is_running_ = true; + cv_.notify_all(); + } + + Initialize(); + + while (is_running_) { + // Enter critical region. + // The state of this class cannot be modified by other threads. + { + std::lock_guard guard(mutex_); + // All tasks must be ready with a valid time. + if (time_ > schedule_.size() - 1 || + !schedule_[time_].IsAllTasksEqueued() || + schedule_[time_].IsAllTasksFinished()) { + continue; + } + + // Start NCCL parallel communication. + NCCL_CALL(ncclGroupStart()); + for (auto& task : schedule_[time_].batch) { + ORT_ENFORCE(task.is_enqueued, "Unscheduled task cannot be run. Use SubmitSendAndWait or SubmitRecvAndWait to schedule tasks."); + switch (task.type) { + case NcclTask::Type::SEND: + ORT_ENFORCE(task.peers.size() == 1, "Send can only send data to one rank."); + NCCL_CALL(ncclSend(task.ptr, task.size, ncclChar, task.peers.front(), comm_, nullptr)); + break; + case NcclTask::Type::RECV: + ORT_ENFORCE(task.peers.size() == 1, "Recv can only send data to one rank."); + NCCL_CALL(ncclRecv(task.ptr, task.size, ncclChar, task.peers.front(), comm_, nullptr)); + break; + default: + ORT_NOT_IMPLEMENTED("NCCL service currently only support ncclSend and ncclRecv."); + } + task.is_finished = true; + } + NCCL_CALL(ncclGroupEnd()); + + // This round of communication is done. + // We can start waiting for the tasks to be fully scheduled. + ++time_; + ++total_time_; + } + cv_.notify_all(); + } + }); +} + +void NcclService::Reset() { + WaitForLaunch(); + { + std::unique_lock lock(mutex_); + + // We can only reset after all planned tasks are done, + // so wait for unfinished tasks here. + cv_.wait(lock, [this] { + bool is_all_tasks_finished = true; + for (auto& task_group : schedule_) { + if (task_group.IsAllTasksFinished()) { + continue; + } + is_all_tasks_finished = false; + }; + return is_all_tasks_finished; + }); + } + + { + std::lock_guard guard(mutex_); + time_ = 0; + + // All scheduled communication tasks are done for finishing + // gradient accumulation steps + one model update step. + // To start next round of gradient accumulation and model update, + // we need to reset the "done" status of all tasks in the schedule. + for (auto& task_group : schedule_) { + task_group.ResetAllTasks(); + } + + cv_.notify_all(); + } +} + +void NcclService::Terminate() { + WaitForLaunch(); + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return total_time_ > 0 && time_ == 0; }); + } + + is_running_ = false; + worker_.join(); + NCCL_CALL(ncclCommDestroy(comm_)); +} + +} // namespace cuda +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h new file mode 100644 index 0000000000..915197bf4f --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.h @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_NCCL + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace onnxruntime { +namespace cuda { + +struct NcclTask final { + // Attributes for communication operator. + enum class Type { SEND, + RECV, + ALLREDUCE }; + // Operator to perform. + Type type; + // For Send, this field is destination device's ID. + // For Recv, this field is source device's ID. + std::vector peers; + + // Attributes for memory location. + // GPU memory pointer. + void* ptr; + // Number of bytes to send/recv. + size_t size; + + // Scheduler flag. + bool is_enqueued; + bool is_finished; + + // Debug information + std::string info; + + // Return true if the two operations are the same. + bool Compare(const NcclTask& other) const; + + // Clear runtime information. + void ResetTask(); +}; + +// A collection of independent communication operations. +struct NcclTaskGroup final { + // Schedule a communication operation in this group. + // We don't know the pointer to the actual data and other runtime information yet; + // runtime information is filled by calling EqunueTask(...). + void PlanTask(const NcclTask::Type type, const std::vector peers); + // Fill in task's details. + const NcclTask* EqueueTask( + const NcclTask::Type type, + const std::vector peers, + void* ptr, + const size_t size, + const std::string info); + bool IsAllTasksEqueued() const; + bool IsAllTasksFinished() const; + void ResetAllTasks(); + friend std::ostream& operator<<(std::ostream& stream, const NcclTaskGroup& task_group); + std::vector batch; +}; + +// The use of this class has two stages. First, the user needs to plan the communication operators. +// Second, when running a model, the user should submit tasks following the communication plan. +// Function names begin with "Plan" are used for creating communication plan. Function names begin +// with "Submit" asks this class to run the submitted task. Communication usually does not happen +// immediately after submitting a task. The actual communication time is decided by this class based on +// the communication plan. +// +// Below is an example of planning tasks. Notice that the communication operations in the same group are +// called in random order, so those operations cannot have mutual dependency. +// +// auto& nccl_service = cuda::NcclService::GetInstance(); +// +// nccl_service.PlanStart(); // Signal the begin of communication planning. +// +// nccl_service.PlanStartNewGroup(); // Create new time slot. +// nccl_service.PlanSend(0); +// nccl_service.PlanRecv(1); +// nccl_service.PlanEndNewGroup(); // Mark the end of the first time slot. +// +// nccl_service.PlanStartNewGroup(); // Create the second time slot. +// nccl_service.PlanSend(1); +// nccl_service.PlanRecv(0); +// nccl_service.PlanEndNewGroup(); // Mark the end of the second time slot. +// +// nccl_service.EndPlan(); // Signal the end of communication planning. +class NcclService final { + public: + // Get the singleton of this class. + static NcclService& GetInstance() { + static NcclService instance_; + return instance_; + }; + + // Planning APIs. They are not thread-safe. + + // Mark the start of entire plan. + void PlanStart(); + // Mark the end of entire plan. + void PlanEnd(); + // Mark the begin of a new communication group. It uses the latest time slot. + // Operations in a group can happen in random order. + void PlanNewGroupStart(); + // Mark the end of the current communication group. + void PlanNewGroupEnd(); + // Add Send to the current communication group. + void PlanSend(const int dst); + // Add Recv to the current communication group. + void PlanRecv(const int src); + + // Runtime APIs. They are thread-safe. + + // Launch NCCL service. It's an infinite loop which repeatedly calls corresponding NCCL + // when planned operators (e.g., Send and Recv) arrive. + void Launch(); + // Submit a Send request with needed information such as tensor's address and number bytes to send. + void SubmitSendAndWait(void* buffer, size_t count, int peer); + // Submit a Recv request with needed information such as tensor's address and number bytes to recv. + void SubmitRecvAndWait(void* buffer, size_t count, int peer); + // Reset communication plan's status so that we can reuse the same communication plan for multiple + // model update steps. + void Reset(); + // Terminate NCCL service. + void Terminate(); + + // Print debug string. + friend std::ostream& operator<<(std::ostream& stream, const NcclService& service); + + private: + NcclService() = default; + ~NcclService() = default; + NcclService(const NcclService&) = delete; + NcclService& operator=(const NcclService&) = delete; + // Initialization for running NCCL service. + void Initialize(); + // Most member functions should start with a call to this function because + // they are valid only after NCCL service is launched. + void WaitForLaunch(); + // Search the next unfinished communication group to work on. + int FindNextCommunicationTime() const; + + // Mutex to gurantee thread-safe access to this class. + std::mutex mutex_; + // Conditional variable used to wait for the mutex. + std::condition_variable cv_; + + // Stream for running NCCL. + ncclComm_t comm_; + + // Indicates if NCCL service launched. + bool is_running_; + // Indicates if NCCL service has a plan, which must be true when calling Launch(...). + bool is_planned_; + // Pipeline stage. + size_t rank_; + + size_t time_; + size_t total_time_; + + // group_status_[t] indicates if the t-th group's plan is done. Once group_status_[t] is + // set to false, we can add communication operations to that group. + std::vector group_status_; + // schedule_[t] communication group at time t. Communication group at time t-1 must be + // finished before working on the group at time t. In other words, communication groups + // are stored in their actual time order. + std::vector schedule_; + // Thread to asynchronously run Launc(...). + std::thread worker_; +}; + +} // namespace cuda +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.cc b/orttraining/orttraining/training_ops/cuda/communication/recv.cc index 9b3b5685c6..540eac031d 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.cc @@ -5,6 +5,7 @@ #include "orttraining/training_ops/cuda/communication/recv.h" #include "orttraining/training_ops/cuda/communication/common.h" +#include "orttraining/training_ops/cuda/communication/nccl_service.h" #include "core/profile/profile.h" #include "core/profile/context.h" #include "core/providers/cuda/cuda_common.h" @@ -72,15 +73,28 @@ void Recv::ReceiveData( // count waiting time before setting up the actual communication. recvRange.Begin(); #endif + +#ifdef USE_NCCL + buffer = GetScratchBuffer(aggregated_aligned_tensor_bytes); +#else buffer = AllocateBufferOnCPUPinned(static_cast(aggregated_aligned_tensor_bytes)); +#endif + CommInfo_t info_data{buffer.get(), static_cast(aggregated_aligned_tensor_bytes), src, static_cast(tag_)}; + // The following NCCL call is equivalent to the following MPI call. User can + // uncomment the MPI call to debug. +#ifdef USE_NCCL + auto& nccl_service = cuda::NcclService::GetInstance(); + nccl_service.SubmitRecvAndWait(info_data.buffer, info_data.size, info_data.rank); +#else MPI_CHECK(MPI_Recv( - info_data.buffer, info_data.size, MPI_CHAR, - info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + info_data.buffer, info_data.size, MPI_CHAR, + info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); +#endif #ifdef ENABLE_NVTX_PROFILE // End of actual communication. @@ -103,13 +117,20 @@ void Recv::ReceiveData( // Find the next aligned offset in the tensor buffer to meet alignment requirement tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes); - // Keep the sync copy in the previous design - // TODO they can be moved to async call after global stream becoming accessible + // Copy data out from buffer. +#ifdef USE_NCCL CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, tensor->SizeInBytes(), cudaMemcpyHostToDevice)); +#else + CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, + tensor->SizeInBytes(), cudaMemcpyDeviceToDevice)); +#endif tensor_offset_in_bytes += tensor->SizeInBytes(); } + +#ifndef USE_NCCL AddDeferredReleaseCPUPtr(buffer.release()); +#endif #ifdef ENABLE_NVTX_PROFILE // End of host-to-device copy. @@ -266,4 +287,4 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { } // namespace cuda } // namespace onnxruntime -#endif +#endif \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.cc b/orttraining/orttraining/training_ops/cuda/communication/send.cc index ea90017590..25bc01460e 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/send.cc @@ -5,6 +5,7 @@ #include "orttraining/training_ops/cuda/communication/send.h" #include "orttraining/training_ops/cuda/communication/common.h" +#include "orttraining/training_ops/cuda/communication/nccl_service.h" #include "core/profile/profile.h" #include "core/profile/context.h" #include "core/providers/cuda/cuda_common.h" @@ -99,13 +100,22 @@ void Send::SendData( memcpyRange.Begin(); #endif +#ifdef USE_NCCL + IAllocatorUniquePtr buffer = GetScratchBuffer(aggregated_aligned_tensor_bytes); +#else IAllocatorUniquePtr buffer = AllocateBufferOnCPUPinned( aggregated_aligned_tensor_bytes); +#endif for (int i = 0; i < num_tensors; ++i) { const Tensor* tensor = ctx->Input(i + 2); +#ifdef USE_NCCL + CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), + tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice)); +#else CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost)); +#endif } #ifdef ENABLE_NVTX_PROFILE @@ -127,9 +137,14 @@ void Send::SendData( dst, static_cast(tag_)}; +#ifdef USE_NCCL + auto& nccl_service = cuda::NcclService::GetInstance(); + nccl_service.SubmitSendAndWait(info_data.buffer, info_data.size, info_data.rank); +#else MPI_CHECK(MPI_Send( info_data.buffer, info_data.size, MPI_CHAR, info_data.rank, info_data.tag, MPI_COMM_WORLD)); +#endif #ifdef ENABLE_NVTX_PROFILE // End of major communication tasks. @@ -240,4 +255,4 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const { } // namespace cuda } // namespace onnxruntime -#endif +#endif \ No newline at end of file