From 66babe4ed341d3b07dc5715afbdf2cdb11a92019 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 11 May 2021 20:16:42 -0700 Subject: [PATCH] MPI Kernels for Training --- .../core/framework/provider_bridge_ort.cc | 4 + .../providers/shared_library/provider_api.h | 19 ++++ .../shared_library/provider_interfaces.h | 17 ++++ .../core/framework/adasum/adasum_interface.h | 33 +++---- .../framework/communication/mpi/mpi_context.h | 98 ++++++++++--------- .../communication/mpi/mpi_utilities.h | 5 +- .../orttraining/core/graph/optimizer_config.h | 2 + .../orttraining/models/pipeline_poc/main.cc | 11 ++- .../training_ops/communication_common.h | 4 +- .../cuda/collective/adasum_kernels.cc | 20 ++-- .../cuda/collective/adasum_kernels.h | 24 ++--- .../training_ops/cuda/communication/recv.cc | 20 ++-- .../training_ops/cuda/communication/send.cc | 14 +-- 13 files changed, 162 insertions(+), 109 deletions(-) diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index c22ac30b6b..e828a385d5 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -631,16 +631,20 @@ struct ProviderHostImpl : ProviderHost { // OpKernelContext (wrapped) const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input(index); } const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) override { return p->RequiredInput(index); } + Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) override { return p->Output(index); } Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); } Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) override { return p->RequiredOutput(index, shape); } int OpKernelContext__InputCount(const OpKernelContext* p) override { return p->InputCount(); } int OpKernelContext__OutputCount(const OpKernelContext* p) override { return p->OutputCount(); } Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) override { return p->GetTempSpaceAllocator(output); } bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) override { return p->GetUseDeterministicCompute(); } + bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredOutputShape(index, shape); } + bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredInputShape(index, shape); } // OpKernelInfo (wrapped) std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); } void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; } + AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) override { return p->GetAllocator(device_id, mem_type); } const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) override { return p->GetExecutionProvider(); } Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); } Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 563a9cd793..e21882f5df 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -188,6 +188,8 @@ using MLDataType = const DataTypeImpl*; // be used with class MLValue using DeleteFunc = void (*)(void*); using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto; + +using NameMLValMap = std::unordered_map; } // namespace onnxruntime #include "core/platform/threadpool.h" @@ -204,6 +206,23 @@ using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto; namespace onnxruntime { +// From Tensor.h +class BufferDeleter { + public: + BufferDeleter() : alloc_(nullptr) {} + BufferDeleter(AllocatorPtr alloc) : alloc_(alloc) {} + + void operator()(void* p) const { + if (alloc_) + alloc_->Free(p); + } + + private: + AllocatorPtr alloc_; +}; + +using BufferUniquePtr = std::unique_ptr; + // The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own // Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example. void RunOnUnload(std::function function); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index e33a2ae12e..0b94225d1a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -542,16 +542,20 @@ struct ProviderHost { // OpKernelContext virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0; virtual const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) = 0; + virtual Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) = 0; virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0; virtual Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) = 0; virtual int OpKernelContext__InputCount(const OpKernelContext* p) = 0; virtual int OpKernelContext__OutputCount(const OpKernelContext* p) = 0; virtual Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) = 0; virtual bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) = 0; + virtual bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0; + virtual bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0; // OpKernelInfo virtual std::unique_ptr CopyOpKernelInfo(const OpKernelInfo& info) = 0; virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0; + virtual AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) = 0; virtual const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) = 0; virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0; virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0; @@ -1406,6 +1410,9 @@ struct OpKernelContext final { const T* Input(int index) const; int InputCount() const { return g_host->OpKernelContext__InputCount(this); } + template + T* Output(int index); + Tensor* Output(int index, const TensorShape& shape) { return g_host->OpKernelContext__Output(this, index, shape); } int OutputCount() const { return g_host->OpKernelContext__OutputCount(this); } @@ -1413,6 +1420,9 @@ struct OpKernelContext final { bool GetUseDeterministicCompute() const { return g_host->OpKernelContext__GetUseDeterministicCompute(this); } + bool TryGetInferredOutputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredOutputShape(this, index, shape); } + bool TryGetInferredInputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredInputShape(this, index, shape); } + PROVIDER_DISALLOW_ALL(OpKernelContext) }; @@ -1421,6 +1431,11 @@ inline const Tensor* OpKernelContext::Input(int index) const { return g_host->OpKernelContext__Input_Tensor(this, index); } +template <> +inline Tensor* OpKernelContext::Output(int index) { + return g_host->OpKernelContext__Output_Tensor(this, index); +} + template <> inline const Tensor& OpKernelContext::RequiredInput(int index) const { return g_host->OpKernelContext__RequiredInput_Tensor(this, index); @@ -1429,6 +1444,8 @@ inline const Tensor& OpKernelContext::RequiredInput(int index) const { struct OpKernelInfo final { static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast(p)); } + AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const { return g_host->OpKernelInfo__GetAllocator(this, device_id, mem_type); } + const IExecutionProvider* GetExecutionProvider() const noexcept { return g_host->OpKernelInfo__GetExecutionProvider(this); } template diff --git a/orttraining/orttraining/core/framework/adasum/adasum_interface.h b/orttraining/orttraining/core/framework/adasum/adasum_interface.h index 631167c0ed..0cd164b31f 100644 --- a/orttraining/orttraining/core/framework/adasum/adasum_interface.h +++ b/orttraining/orttraining/core/framework/adasum/adasum_interface.h @@ -7,11 +7,12 @@ #include #include +#ifndef SHARED_PROVIDER #include "core/framework/tensor.h" #include "core/framework/op_kernel.h" +#endif #include "orttraining/core/graph/optimizer_config.h" - #include "orttraining/core/framework/distributed_run_context.h" #ifdef ENABLE_CPU_FP16_TRAINING_OPS @@ -21,16 +22,14 @@ namespace onnxruntime { namespace training { -static inline bool IsPowerOfTwo(ulong x) -{ +static inline bool IsPowerOfTwo(unsigned x) { return (x != 0) && ((x & (x - 1)) == 0); } // Interface for Adasum algorithm template class AdasumInterface { -public: - + public: Status DispatchFusedAllreduce(void* grad_buffer, void* recv_buffer, std::vector& tensor_counts, int start_level, Communicator_type communicator, int tag, @@ -44,7 +43,7 @@ public: FusedAllreduce((float*)grad_buffer, (float*)recv_buffer, data_type, tensor_counts, start_level, communicator, tag, reduction_comms); - } else if(data_type == DataTypeImpl::GetType()) { + } else if (data_type == DataTypeImpl::GetType()) { FusedAllreduce((double*)grad_buffer, (double*)recv_buffer, data_type, tensor_counts, start_level, communicator, tag, reduction_comms); @@ -61,7 +60,7 @@ public: virtual const Communicator_type* GetReductionComms() = 0; -protected: + protected: // Communication primitives required for Adasum algorithm virtual void PointToPointSendRecv(void* input_data_buffer, int64_t input_buffer_bytes, @@ -113,7 +112,7 @@ protected: } } -private: + private: // Allocator for temporary buffer allocations AllocatorPtr allocator_ = nullptr; @@ -161,7 +160,7 @@ private: if (IsPowerOfTwo(size) == false) { ORT_THROW( - "Adasum doesn't currently support reduction among non-power-of-2 number of ranks."); + "Adasum doesn't currently support reduction among non-power-of-2 number of ranks."); } std::vector> nghrCountVec; @@ -213,10 +212,10 @@ private: tensor_counts[i] = 0; } else { nghrCountVec[nghrCountVec_index][i] = - nghrCount - nghrCountSoFar; // should not be negative + nghrCount - nghrCountSoFar; // should not be negative tensor_counts[i] = tensor_counts[i] - - (nghrCount - nghrCountSoFar); // should not be negative + (nghrCount - nghrCountSoFar); // should not be negative } } else { tensor_counts[i] = tensor_counts[i]; @@ -240,9 +239,9 @@ private: assert((myCount - myCountSoFar) >= 0); nghrCountVec[nghrCountVec_index][i] = tensor_counts[i] - - (myCount - myCountSoFar); // should not be negative + (myCount - myCountSoFar); // should not be negative tensor_counts[i] = - myCount - myCountSoFar; // should not be negative + myCount - myCountSoFar; // should not be negative } } else { nghrCountVec[nghrCountVec_index][i] = tensor_counts[i]; @@ -264,8 +263,8 @@ private: recv_buffer = &recv_buffer[nghrCount]; } FusedPairwiseReduceWithComm((uint8_t*)grad_buffer, (uint8_t*)recv_buffer, - data_type, tensor_counts, reduction_comms[comm_index], - (rank & level) == 0, normAndDots); + data_type, tensor_counts, reduction_comms[comm_index], + (rank & level) == 0, normAndDots); } for (level = (size >> 1); level > 0; level = (level >> 1)) { @@ -383,5 +382,5 @@ private: } }; -} // namespace training -} // namespace onnxruntime +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/communication/mpi/mpi_context.h b/orttraining/orttraining/core/framework/communication/mpi/mpi_context.h index b1ac54940e..bff1fddcf9 100644 --- a/orttraining/orttraining/core/framework/communication/mpi/mpi_context.h +++ b/orttraining/orttraining/core/framework/communication/mpi/mpi_context.h @@ -2,9 +2,11 @@ // Licensed under the MIT License. #pragma once +#ifndef SHARED_PROVIDER #include "core/common/common.h" -#include "orttraining/core/framework/distributed_run_context.h" #include "core/common/logging/logging.h" +#endif +#include "orttraining/core/framework/distributed_run_context.h" #if defined(USE_MPI) #include @@ -14,76 +16,76 @@ namespace onnxruntime { namespace training { #if defined(USE_MPI) -#define MPI_CHECK(condition) \ - do { \ - int error = (condition); \ - ORT_ENFORCE( \ - error == MPI_SUCCESS, \ - "MPI Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - error); \ +#define MPI_CHECK(condition) \ + do { \ + int error = (condition); \ + ORT_ENFORCE( \ + error == MPI_SUCCESS, \ + "MPI Error at: ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ": ", \ + error); \ } while (0) #endif struct MPIGroup { #if defined(USE_MPI) - MPI_Group mpi_group {MPI_GROUP_EMPTY}; // MPI group - MPI_Comm communicator {MPI_COMM_NULL}; // MPI communicator of this group + MPI_Group mpi_group{MPI_GROUP_EMPTY}; // MPI group + MPI_Comm communicator{MPI_COMM_NULL}; // MPI communicator of this group #endif - bool is_group_initialized {false}; // Whether it's initialized + bool is_group_initialized{false}; // Whether it's initialized }; class MPIContext { // https://stackoverflow.com/questions/1008019/c-singleton-design-pattern - public: - static MPIContext& GetInstance(); + public: + static MPIContext& GetInstance(); - MPIContext(MPIContext const&) = delete; - void operator=(MPIContext const&) = delete; + MPIContext(MPIContext const&) = delete; + void operator=(MPIContext const&) = delete; - // within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi(). - ~MPIContext(); + // within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi(). + ~MPIContext(); - void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group); + void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group); - const std::vector& GetAllMPIGroups() const { return mpi_groups_; } - - const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; } + const std::vector& GetAllMPIGroups() const { return mpi_groups_; } - int GetWorldRank() const { return world_rank_; } - int GetLocalRank() const { return local_rank_; } - int GetWorldSize() const { return world_size_; } - int GetLocalSize() const { return local_size_; } - - const static int MPI_TIMEOUT_IN_SECONDS = 10; + const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; } + + int GetWorldRank() const { return world_rank_; } + int GetLocalRank() const { return local_rank_; } + int GetWorldSize() const { return world_size_; } + int GetLocalSize() const { return local_size_; } + + const static int MPI_TIMEOUT_IN_SECONDS = 10; #if defined(USE_MPI) - // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices - // in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction - // shutdown_mpi shall be called specifically in user code. - static void shutdown_mpi(bool perform_graceful_exit = true); + // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices + // in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction + // shutdown_mpi shall be called specifically in user code. + static void shutdown_mpi(bool perform_graceful_exit = true); #endif - private: - MPIContext(); + private: + MPIContext(); - // Groups containing mpi communicator for any worker group. - std::vector mpi_groups_; + // Groups containing mpi communicator for any worker group. + std::vector mpi_groups_; #if defined(USE_MPI) - // Global counter for MPI groups - int mpi_group_id_ = 0; - void setup_mpi(); - void ReleaseComms(); + // Global counter for MPI groups + int mpi_group_id_ = 0; + void setup_mpi(); + void ReleaseComms(); #endif - int world_rank_; - int local_rank_; - int world_size_; - int local_size_; + int world_rank_; + int local_rank_; + int world_size_; + int local_size_; - const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger(); + const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger(); }; } // namespace training diff --git a/orttraining/orttraining/core/framework/communication/mpi/mpi_utilities.h b/orttraining/orttraining/core/framework/communication/mpi/mpi_utilities.h index e9910f2c46..ee966427a4 100644 --- a/orttraining/orttraining/core/framework/communication/mpi/mpi_utilities.h +++ b/orttraining/orttraining/core/framework/communication/mpi/mpi_utilities.h @@ -3,15 +3,18 @@ #pragma once +#ifndef SHARED_PROVIDER #include "core/framework/tensor.h" #include "core/framework/op_kernel.h" +#endif + #ifdef USE_MPI #include #endif namespace onnxruntime { namespace training { #ifdef USE_MPI -MPI_Datatype GetMPIDataType (MLDataType data_type); +MPI_Datatype GetMPIDataType(MLDataType data_type); int GetMPIRank(MPI_Comm comm); diff --git a/orttraining/orttraining/core/graph/optimizer_config.h b/orttraining/orttraining/core/graph/optimizer_config.h index d6b31809c5..089b84d4ce 100644 --- a/orttraining/orttraining/core/graph/optimizer_config.h +++ b/orttraining/orttraining/core/graph/optimizer_config.h @@ -5,10 +5,12 @@ #include #include +#ifndef SHARED_PROVIDER #include "core/common/logging/logging.h" #include "core/framework/framework_common.h" #include "core/framework/ml_value.h" #include "core/graph/node_arg.h" +#endif namespace onnxruntime { namespace training { diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 6d44578671..0e4fe1716d 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -16,7 +16,10 @@ #include "orttraining/models/runner/training_util.h" #include "orttraining/models/runner/data_loader.h" -#include "core/providers/cuda/cuda_execution_provider.h" +#include "core/providers/cuda/cuda_provider_factory_creator.h" +namespace onnxruntime { +std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* provider_options); +} #include #include @@ -111,8 +114,10 @@ int main(int argc, char* argv[]) { InferenceSession session_object{so, *env}; Status st; - CUDAExecutionProviderInfo xp_info{static_cast(world_rank)}; - st = session_object.RegisterExecutionProvider(std::make_unique(xp_info)); + OrtCUDAProviderOptions xp_info{}; + xp_info.device_id = static_cast(world_rank); + auto cuda_factory = CreateExecutionProviderFactory_Cuda(&xp_info); + st = session_object.RegisterExecutionProvider(cuda_factory->CreateProvider()); ORT_ENFORCE(st == Status::OK(), "MPI rank ", world_rank, ": ", st.ErrorMessage()); std::string model_at_rank; diff --git a/orttraining/orttraining/training_ops/communication_common.h b/orttraining/orttraining/training_ops/communication_common.h index 957ca03de9..74825f9867 100644 --- a/orttraining/orttraining/training_ops/communication_common.h +++ b/orttraining/orttraining/training_ops/communication_common.h @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifndef SHARED_PROVIDER #include "core/common/common.h" +#endif #pragma once #include "orttraining/core/framework/communication/mpi/mpi_context.h" @@ -191,7 +193,7 @@ inline void ReceiveShapeInfo( info_shapes.buffer, info_shapes.size, MPI_CHAR, info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); } -#endif // USE_MPI +#endif // USE_MPI inline void ComputeTensorSizeAndBufferLength(OpKernelContext* context, std::vector& tensor_element_counts, diff --git a/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.cc b/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.cc index 053d5fee1a..99e9eccfda 100644 --- a/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #ifdef USE_MPI +#include "core/providers/shared_library/provider_api.h" #include "orttraining/training_ops/cuda/collective/adasum_kernels.h" #include "orttraining/training_ops/communication_common.h" #include "orttraining/core/framework/communication/mpi/mpi_context.h" @@ -9,7 +10,6 @@ namespace onnxruntime { namespace cuda { Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const { - int vhdd_start_level = 1; if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction) { vhdd_start_level = training::DistributedRunContext::GetInstance().GroupSize(training::WorkerGroupType::NodeLocalDataParallel); @@ -37,23 +37,23 @@ Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const { for (int i = 0; i < num_tensors; ++i) { const Tensor* x_tensor = context->Input(i); CUDA_CALL(cudaMemcpyAsync((uint8_t*)data_buffer_ptr.get() + tensor_offsets[i], x_tensor->DataRaw(), - tensor_sizes[i], cudaMemcpyDeviceToHost, Stream())); + tensor_sizes[i], cudaMemcpyDeviceToHost, Stream())); } auto recv_buffer = allocator->Alloc(total_recv_buffer_len); BufferUniquePtr recv_buffer_ptr(recv_buffer, BufferDeleter(allocator)); ORT_RETURN_IF_ERROR(adasum_reducer_->DispatchFusedAllreduce((void*)data_buffer, recv_buffer, tensor_element_counts, - vhdd_start_level, // start level - training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator - 0, // tag - adasum_reducer_->GetReductionComms(), // reduction_comms - context->Input(0)->DataType())); + vhdd_start_level, // start level + training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator + 0, // tag + adasum_reducer_->GetReductionComms(), // reduction_comms + context->Input(0)->DataType())); for (int i = 0; i < num_tensors; i++) { Tensor* y_tensor = context->Output(i, context->Input(i)->Shape()); CUDA_CALL(cudaMemcpyAsync(y_tensor->MutableDataRaw(), (uint8_t*)data_buffer + tensor_offsets[i], - tensor_sizes[i], cudaMemcpyHostToDevice, Stream())); + tensor_sizes[i], cudaMemcpyHostToDevice, Stream())); } return Status::OK(); } @@ -63,11 +63,11 @@ ONNX_OPERATOR_KERNEL_EX( kMSDomain, 1, kCudaExecutionProvider, - KernelDefBuilder() + (*KernelDefBuilder::Create()) .VariadicAlias(0, 0) // outputs and inputs are mapped one to one .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), AdasumAllReduce); } // namespace cuda } // namespace onnxruntime -#endif // USE_MPI +#endif // USE_MPI diff --git a/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.h b/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.h index 3138771e93..693dcafe13 100644 --- a/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.h +++ b/orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.h @@ -4,25 +4,25 @@ #pragma once #include "orttraining/training_ops/cuda/collective/nccl_common.h" - #include "orttraining/core/framework/adasum/adasum_interface.h" #include "orttraining/core/framework/adasum/adasum_mpi.h" + namespace onnxruntime { namespace cuda { class AdasumAllReduce final : public NcclKernel { public: explicit AdasumAllReduce(const OpKernelInfo& info) : NcclKernel(info) { - int64_t adasum_reduce_algo; - info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast(0)); - adasum_reduce_algo_ = static_cast(adasum_reduce_algo); - if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction || - adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) { - adasum_reducer_ = std::make_unique(); - } - if(!adasum_reducer_->IsAdasumInitialized()) { - adasum_reducer_->InitializeVHDDReductionComms(); - } + int64_t adasum_reduce_algo; + info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast(0)); + adasum_reduce_algo_ = static_cast(adasum_reduce_algo); + if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction || + adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) { + adasum_reducer_ = std::make_unique(); + } + if (!adasum_reducer_->IsAdasumInitialized()) { + adasum_reducer_->InitializeVHDDReductionComms(); + } } Status ComputeInternal(OpKernelContext* context) const override; @@ -33,4 +33,4 @@ class AdasumAllReduce final : public NcclKernel { }; } // namespace cuda } // namespace onnxruntime -#endif // USE_MPI +#endif // USE_MPI diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.cc b/orttraining/orttraining/training_ops/cuda/communication/recv.cc index fb9b383cf0..46d682066b 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.cc @@ -90,16 +90,16 @@ void Recv::ReceiveData( // Copy data out from buffer. #if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, - tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream())); + tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream())); #else CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, - tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream())); + tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream())); #endif #ifndef NDEBUG - // In addition to the first output, other tensors are allocated on GPU. - // We check if the allocated memory is on the current CUDA device. - CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw()); + // In addition to the first output, other tensors are allocated on GPU. + // We check if the allocated memory is on the current CUDA device. + CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw()); #endif tensor_offset_in_bytes += tensor->SizeInBytes(); } @@ -121,10 +121,10 @@ ONNX_OPERATOR_KERNEL_EX( kMSDomain, 1, kCudaExecutionProvider, - KernelDefBuilder() - .InputMemoryType(0) /* CPU variable */ - .InputMemoryType(1) /* CPU variable */ - .OutputMemoryType(0) /* CPU variable */ + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */ + .InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */ + .OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */ .TypeConstraint("TBool", DataTypeImpl::GetTensorType()) .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), @@ -274,4 +274,4 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { } // namespace cuda } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.cc b/orttraining/orttraining/training_ops/cuda/communication/send.cc index 6a5bc71fd3..7a3a21a63c 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/send.cc @@ -22,10 +22,10 @@ ONNX_OPERATOR_KERNEL_EX( kMSDomain, 1, kCudaExecutionProvider, - KernelDefBuilder() - .InputMemoryType(0) /* CPU variable */ - .InputMemoryType(1) /* CPU variable */ - .OutputMemoryType(0) /* CPU variable */ + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */ + .InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */ + .OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */ .TypeConstraint("TBool", DataTypeImpl::GetTensorType()) .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), @@ -67,10 +67,10 @@ void Send::SendData( #if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), - tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream())); + tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream())); #else CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), - tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream())); + tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream())); #endif } @@ -226,4 +226,4 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const { } // namespace cuda } // namespace onnxruntime -#endif \ No newline at end of file +#endif