mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
MPI Kernels for Training
This commit is contained in:
parent
335562def6
commit
66babe4ed3
13 changed files with 162 additions and 109 deletions
|
|
@ -631,16 +631,20 @@ struct ProviderHostImpl : ProviderHost {
|
|||
// OpKernelContext (wrapped)
|
||||
const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input<Tensor>(index); }
|
||||
const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) override { return p->RequiredInput<Tensor>(index); }
|
||||
Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) override { return p->Output<Tensor>(index); }
|
||||
Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); }
|
||||
Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) override { return p->RequiredOutput(index, shape); }
|
||||
int OpKernelContext__InputCount(const OpKernelContext* p) override { return p->InputCount(); }
|
||||
int OpKernelContext__OutputCount(const OpKernelContext* p) override { return p->OutputCount(); }
|
||||
Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) override { return p->GetTempSpaceAllocator(output); }
|
||||
bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) override { return p->GetUseDeterministicCompute(); }
|
||||
bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredOutputShape(index, shape); }
|
||||
bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredInputShape(index, shape); }
|
||||
|
||||
// OpKernelInfo (wrapped)
|
||||
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); }
|
||||
void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; }
|
||||
AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) override { return p->GetAllocator(device_id, mem_type); }
|
||||
const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) override { return p->GetExecutionProvider(); }
|
||||
Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); }
|
||||
Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); }
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ using MLDataType = const DataTypeImpl*;
|
|||
// be used with class MLValue
|
||||
using DeleteFunc = void (*)(void*);
|
||||
using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto;
|
||||
|
||||
using NameMLValMap = std::unordered_map<std::string, OrtValue>;
|
||||
} // namespace onnxruntime
|
||||
|
||||
#include "core/platform/threadpool.h"
|
||||
|
|
@ -204,6 +206,23 @@ using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// From Tensor.h
|
||||
class BufferDeleter {
|
||||
public:
|
||||
BufferDeleter() : alloc_(nullptr) {}
|
||||
BufferDeleter(AllocatorPtr alloc) : alloc_(alloc) {}
|
||||
|
||||
void operator()(void* p) const {
|
||||
if (alloc_)
|
||||
alloc_->Free(p);
|
||||
}
|
||||
|
||||
private:
|
||||
AllocatorPtr alloc_;
|
||||
};
|
||||
|
||||
using BufferUniquePtr = std::unique_ptr<void, BufferDeleter>;
|
||||
|
||||
// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own
|
||||
// Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example.
|
||||
void RunOnUnload(std::function<void()> function);
|
||||
|
|
|
|||
|
|
@ -542,16 +542,20 @@ struct ProviderHost {
|
|||
// OpKernelContext
|
||||
virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0;
|
||||
virtual const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) = 0;
|
||||
virtual Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) = 0;
|
||||
virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0;
|
||||
virtual Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) = 0;
|
||||
virtual int OpKernelContext__InputCount(const OpKernelContext* p) = 0;
|
||||
virtual int OpKernelContext__OutputCount(const OpKernelContext* p) = 0;
|
||||
virtual Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) = 0;
|
||||
virtual bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) = 0;
|
||||
virtual bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0;
|
||||
virtual bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0;
|
||||
|
||||
// OpKernelInfo
|
||||
virtual std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) = 0;
|
||||
virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0;
|
||||
virtual AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) = 0;
|
||||
virtual const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) = 0;
|
||||
virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0;
|
||||
virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0;
|
||||
|
|
@ -1406,6 +1410,9 @@ struct OpKernelContext final {
|
|||
const T* Input(int index) const;
|
||||
int InputCount() const { return g_host->OpKernelContext__InputCount(this); }
|
||||
|
||||
template <typename T>
|
||||
T* Output(int index);
|
||||
|
||||
Tensor* Output(int index, const TensorShape& shape) { return g_host->OpKernelContext__Output(this, index, shape); }
|
||||
int OutputCount() const { return g_host->OpKernelContext__OutputCount(this); }
|
||||
|
||||
|
|
@ -1413,6 +1420,9 @@ struct OpKernelContext final {
|
|||
|
||||
bool GetUseDeterministicCompute() const { return g_host->OpKernelContext__GetUseDeterministicCompute(this); }
|
||||
|
||||
bool TryGetInferredOutputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredOutputShape(this, index, shape); }
|
||||
bool TryGetInferredInputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredInputShape(this, index, shape); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(OpKernelContext)
|
||||
};
|
||||
|
||||
|
|
@ -1421,6 +1431,11 @@ inline const Tensor* OpKernelContext::Input<Tensor>(int index) const {
|
|||
return g_host->OpKernelContext__Input_Tensor(this, index);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
|
||||
return g_host->OpKernelContext__Output_Tensor(this, index);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const Tensor& OpKernelContext::RequiredInput(int index) const {
|
||||
return g_host->OpKernelContext__RequiredInput_Tensor(this, index);
|
||||
|
|
@ -1429,6 +1444,8 @@ inline const Tensor& OpKernelContext::RequiredInput(int index) const {
|
|||
struct OpKernelInfo final {
|
||||
static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast<OpKernelInfo*>(p)); }
|
||||
|
||||
AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const { return g_host->OpKernelInfo__GetAllocator(this, device_id, mem_type); }
|
||||
|
||||
const IExecutionProvider* GetExecutionProvider() const noexcept { return g_host->OpKernelInfo__GetExecutionProvider(this); }
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
#include <float.h>
|
||||
#include <cmath>
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#endif
|
||||
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
|
||||
#ifdef ENABLE_CPU_FP16_TRAINING_OPS
|
||||
|
|
@ -21,16 +22,14 @@
|
|||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
static inline bool IsPowerOfTwo(ulong x)
|
||||
{
|
||||
static inline bool IsPowerOfTwo(unsigned x) {
|
||||
return (x != 0) && ((x & (x - 1)) == 0);
|
||||
}
|
||||
|
||||
// Interface for Adasum algorithm
|
||||
template <typename Communicator_type>
|
||||
class AdasumInterface {
|
||||
public:
|
||||
|
||||
public:
|
||||
Status DispatchFusedAllreduce(void* grad_buffer, void* recv_buffer,
|
||||
std::vector<int>& tensor_counts, int start_level,
|
||||
Communicator_type communicator, int tag,
|
||||
|
|
@ -44,7 +43,7 @@ public:
|
|||
FusedAllreduce((float*)grad_buffer, (float*)recv_buffer,
|
||||
data_type, tensor_counts, start_level, communicator, tag,
|
||||
reduction_comms);
|
||||
} else if(data_type == DataTypeImpl::GetType<double>()) {
|
||||
} else if (data_type == DataTypeImpl::GetType<double>()) {
|
||||
FusedAllreduce((double*)grad_buffer, (double*)recv_buffer,
|
||||
data_type, tensor_counts, start_level, communicator, tag,
|
||||
reduction_comms);
|
||||
|
|
@ -61,7 +60,7 @@ public:
|
|||
|
||||
virtual const Communicator_type* GetReductionComms() = 0;
|
||||
|
||||
protected:
|
||||
protected:
|
||||
// Communication primitives required for Adasum algorithm
|
||||
virtual void PointToPointSendRecv(void* input_data_buffer,
|
||||
int64_t input_buffer_bytes,
|
||||
|
|
@ -113,7 +112,7 @@ protected:
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
// Allocator for temporary buffer allocations
|
||||
AllocatorPtr allocator_ = nullptr;
|
||||
|
||||
|
|
@ -161,7 +160,7 @@ private:
|
|||
|
||||
if (IsPowerOfTwo(size) == false) {
|
||||
ORT_THROW(
|
||||
"Adasum doesn't currently support reduction among non-power-of-2 number of ranks.");
|
||||
"Adasum doesn't currently support reduction among non-power-of-2 number of ranks.");
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> nghrCountVec;
|
||||
|
|
@ -213,10 +212,10 @@ private:
|
|||
tensor_counts[i] = 0;
|
||||
} else {
|
||||
nghrCountVec[nghrCountVec_index][i] =
|
||||
nghrCount - nghrCountSoFar; // should not be negative
|
||||
nghrCount - nghrCountSoFar; // should not be negative
|
||||
tensor_counts[i] =
|
||||
tensor_counts[i] -
|
||||
(nghrCount - nghrCountSoFar); // should not be negative
|
||||
(nghrCount - nghrCountSoFar); // should not be negative
|
||||
}
|
||||
} else {
|
||||
tensor_counts[i] = tensor_counts[i];
|
||||
|
|
@ -240,9 +239,9 @@ private:
|
|||
assert((myCount - myCountSoFar) >= 0);
|
||||
nghrCountVec[nghrCountVec_index][i] =
|
||||
tensor_counts[i] -
|
||||
(myCount - myCountSoFar); // should not be negative
|
||||
(myCount - myCountSoFar); // should not be negative
|
||||
tensor_counts[i] =
|
||||
myCount - myCountSoFar; // should not be negative
|
||||
myCount - myCountSoFar; // should not be negative
|
||||
}
|
||||
} else {
|
||||
nghrCountVec[nghrCountVec_index][i] = tensor_counts[i];
|
||||
|
|
@ -264,8 +263,8 @@ private:
|
|||
recv_buffer = &recv_buffer[nghrCount];
|
||||
}
|
||||
FusedPairwiseReduceWithComm((uint8_t*)grad_buffer, (uint8_t*)recv_buffer,
|
||||
data_type, tensor_counts, reduction_comms[comm_index],
|
||||
(rank & level) == 0, normAndDots);
|
||||
data_type, tensor_counts, reduction_comms[comm_index],
|
||||
(rank & level) == 0, normAndDots);
|
||||
}
|
||||
|
||||
for (level = (size >> 1); level > 0; level = (level >> 1)) {
|
||||
|
|
@ -383,5 +382,5 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/common/common.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#endif
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
|
||||
#if defined(USE_MPI)
|
||||
#include <mpi.h>
|
||||
|
|
@ -14,76 +16,76 @@ namespace onnxruntime {
|
|||
namespace training {
|
||||
|
||||
#if defined(USE_MPI)
|
||||
#define MPI_CHECK(condition) \
|
||||
do { \
|
||||
int error = (condition); \
|
||||
ORT_ENFORCE( \
|
||||
error == MPI_SUCCESS, \
|
||||
"MPI Error at: ", \
|
||||
__FILE__, \
|
||||
":", \
|
||||
__LINE__, \
|
||||
": ", \
|
||||
error); \
|
||||
#define MPI_CHECK(condition) \
|
||||
do { \
|
||||
int error = (condition); \
|
||||
ORT_ENFORCE( \
|
||||
error == MPI_SUCCESS, \
|
||||
"MPI Error at: ", \
|
||||
__FILE__, \
|
||||
":", \
|
||||
__LINE__, \
|
||||
": ", \
|
||||
error); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
struct MPIGroup {
|
||||
#if defined(USE_MPI)
|
||||
MPI_Group mpi_group {MPI_GROUP_EMPTY}; // MPI group
|
||||
MPI_Comm communicator {MPI_COMM_NULL}; // MPI communicator of this group
|
||||
MPI_Group mpi_group{MPI_GROUP_EMPTY}; // MPI group
|
||||
MPI_Comm communicator{MPI_COMM_NULL}; // MPI communicator of this group
|
||||
#endif
|
||||
bool is_group_initialized {false}; // Whether it's initialized
|
||||
bool is_group_initialized{false}; // Whether it's initialized
|
||||
};
|
||||
|
||||
class MPIContext {
|
||||
// https://stackoverflow.com/questions/1008019/c-singleton-design-pattern
|
||||
public:
|
||||
static MPIContext& GetInstance();
|
||||
public:
|
||||
static MPIContext& GetInstance();
|
||||
|
||||
MPIContext(MPIContext const&) = delete;
|
||||
void operator=(MPIContext const&) = delete;
|
||||
MPIContext(MPIContext const&) = delete;
|
||||
void operator=(MPIContext const&) = delete;
|
||||
|
||||
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
|
||||
~MPIContext();
|
||||
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
|
||||
~MPIContext();
|
||||
|
||||
void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group);
|
||||
void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group);
|
||||
|
||||
const std::vector<MPIGroup>& GetAllMPIGroups() const { return mpi_groups_; }
|
||||
|
||||
const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; }
|
||||
const std::vector<MPIGroup>& GetAllMPIGroups() const { return mpi_groups_; }
|
||||
|
||||
int GetWorldRank() const { return world_rank_; }
|
||||
int GetLocalRank() const { return local_rank_; }
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
int GetLocalSize() const { return local_size_; }
|
||||
|
||||
const static int MPI_TIMEOUT_IN_SECONDS = 10;
|
||||
const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; }
|
||||
|
||||
int GetWorldRank() const { return world_rank_; }
|
||||
int GetLocalRank() const { return local_rank_; }
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
int GetLocalSize() const { return local_size_; }
|
||||
|
||||
const static int MPI_TIMEOUT_IN_SECONDS = 10;
|
||||
|
||||
#if defined(USE_MPI)
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
|
||||
// shutdown_mpi shall be called specifically in user code.
|
||||
static void shutdown_mpi(bool perform_graceful_exit = true);
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
|
||||
// shutdown_mpi shall be called specifically in user code.
|
||||
static void shutdown_mpi(bool perform_graceful_exit = true);
|
||||
#endif
|
||||
|
||||
private:
|
||||
MPIContext();
|
||||
private:
|
||||
MPIContext();
|
||||
|
||||
// Groups containing mpi communicator for any worker group.
|
||||
std::vector<MPIGroup> mpi_groups_;
|
||||
// Groups containing mpi communicator for any worker group.
|
||||
std::vector<MPIGroup> mpi_groups_;
|
||||
#if defined(USE_MPI)
|
||||
// Global counter for MPI groups
|
||||
int mpi_group_id_ = 0;
|
||||
void setup_mpi();
|
||||
void ReleaseComms();
|
||||
// Global counter for MPI groups
|
||||
int mpi_group_id_ = 0;
|
||||
void setup_mpi();
|
||||
void ReleaseComms();
|
||||
#endif
|
||||
int world_rank_;
|
||||
int local_rank_;
|
||||
int world_size_;
|
||||
int local_size_;
|
||||
int world_rank_;
|
||||
int local_rank_;
|
||||
int world_size_;
|
||||
int local_size_;
|
||||
|
||||
const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger();
|
||||
const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger();
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -3,15 +3,18 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#endif
|
||||
|
||||
#ifdef USE_MPI
|
||||
#include <mpi.h>
|
||||
#endif
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
#ifdef USE_MPI
|
||||
MPI_Datatype GetMPIDataType (MLDataType data_type);
|
||||
MPI_Datatype GetMPIDataType(MLDataType data_type);
|
||||
|
||||
int GetMPIRank(MPI_Comm comm);
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@
|
|||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/graph/node_arg.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,10 @@
|
|||
#include "orttraining/models/runner/training_util.h"
|
||||
#include "orttraining/models/runner/data_loader.h"
|
||||
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#include "core/providers/cuda/cuda_provider_factory_creator.h"
|
||||
namespace onnxruntime {
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* provider_options);
|
||||
}
|
||||
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
|
|
@ -111,8 +114,10 @@ int main(int argc, char* argv[]) {
|
|||
InferenceSession session_object{so, *env};
|
||||
|
||||
Status st;
|
||||
CUDAExecutionProviderInfo xp_info{static_cast<OrtDevice::DeviceId>(world_rank)};
|
||||
st = session_object.RegisterExecutionProvider(std::make_unique<CUDAExecutionProvider>(xp_info));
|
||||
OrtCUDAProviderOptions xp_info{};
|
||||
xp_info.device_id = static_cast<OrtDevice::DeviceId>(world_rank);
|
||||
auto cuda_factory = CreateExecutionProviderFactory_Cuda(&xp_info);
|
||||
st = session_object.RegisterExecutionProvider(cuda_factory->CreateProvider());
|
||||
ORT_ENFORCE(st == Status::OK(), "MPI rank ", world_rank, ": ", st.ErrorMessage());
|
||||
|
||||
std::string model_at_rank;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/common/common.h"
|
||||
#endif
|
||||
#pragma once
|
||||
|
||||
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
|
||||
|
|
@ -191,7 +193,7 @@ inline void ReceiveShapeInfo(
|
|||
info_shapes.buffer, info_shapes.size, MPI_CHAR,
|
||||
info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
|
||||
}
|
||||
#endif // USE_MPI
|
||||
#endif // USE_MPI
|
||||
|
||||
inline void ComputeTensorSizeAndBufferLength(OpKernelContext* context,
|
||||
std::vector<int>& tensor_element_counts,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#ifdef USE_MPI
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "orttraining/training_ops/cuda/collective/adasum_kernels.h"
|
||||
#include "orttraining/training_ops/communication_common.h"
|
||||
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
|
||||
|
|
@ -9,7 +10,6 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const {
|
||||
|
||||
int vhdd_start_level = 1;
|
||||
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction) {
|
||||
vhdd_start_level = training::DistributedRunContext::GetInstance().GroupSize(training::WorkerGroupType::NodeLocalDataParallel);
|
||||
|
|
@ -37,23 +37,23 @@ Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const {
|
|||
for (int i = 0; i < num_tensors; ++i) {
|
||||
const Tensor* x_tensor = context->Input<Tensor>(i);
|
||||
CUDA_CALL(cudaMemcpyAsync((uint8_t*)data_buffer_ptr.get() + tensor_offsets[i], x_tensor->DataRaw(),
|
||||
tensor_sizes[i], cudaMemcpyDeviceToHost, Stream()));
|
||||
tensor_sizes[i], cudaMemcpyDeviceToHost, Stream()));
|
||||
}
|
||||
|
||||
auto recv_buffer = allocator->Alloc(total_recv_buffer_len);
|
||||
BufferUniquePtr recv_buffer_ptr(recv_buffer, BufferDeleter(allocator));
|
||||
|
||||
ORT_RETURN_IF_ERROR(adasum_reducer_->DispatchFusedAllreduce((void*)data_buffer, recv_buffer, tensor_element_counts,
|
||||
vhdd_start_level, // start level
|
||||
training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator
|
||||
0, // tag
|
||||
adasum_reducer_->GetReductionComms(), // reduction_comms
|
||||
context->Input<Tensor>(0)->DataType()));
|
||||
vhdd_start_level, // start level
|
||||
training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator
|
||||
0, // tag
|
||||
adasum_reducer_->GetReductionComms(), // reduction_comms
|
||||
context->Input<Tensor>(0)->DataType()));
|
||||
|
||||
for (int i = 0; i < num_tensors; i++) {
|
||||
Tensor* y_tensor = context->Output(i, context->Input<Tensor>(i)->Shape());
|
||||
CUDA_CALL(cudaMemcpyAsync(y_tensor->MutableDataRaw(), (uint8_t*)data_buffer + tensor_offsets[i],
|
||||
tensor_sizes[i], cudaMemcpyHostToDevice, Stream()));
|
||||
tensor_sizes[i], cudaMemcpyHostToDevice, Stream()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -63,11 +63,11 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
(*KernelDefBuilder::Create())
|
||||
.VariadicAlias(0, 0) // outputs and inputs are mapped one to one
|
||||
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
|
||||
AdasumAllReduce);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
#endif // USE_MPI
|
||||
#endif // USE_MPI
|
||||
|
|
|
|||
|
|
@ -4,25 +4,25 @@
|
|||
|
||||
#pragma once
|
||||
#include "orttraining/training_ops/cuda/collective/nccl_common.h"
|
||||
|
||||
#include "orttraining/core/framework/adasum/adasum_interface.h"
|
||||
#include "orttraining/core/framework/adasum/adasum_mpi.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class AdasumAllReduce final : public NcclKernel {
|
||||
public:
|
||||
explicit AdasumAllReduce(const OpKernelInfo& info) : NcclKernel(info) {
|
||||
int64_t adasum_reduce_algo;
|
||||
info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast<int64_t>(0));
|
||||
adasum_reduce_algo_ = static_cast<training::AdasumReductionType>(adasum_reduce_algo);
|
||||
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction ||
|
||||
adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) {
|
||||
adasum_reducer_ = std::make_unique<training::AdasumMPI>();
|
||||
}
|
||||
if(!adasum_reducer_->IsAdasumInitialized()) {
|
||||
adasum_reducer_->InitializeVHDDReductionComms();
|
||||
}
|
||||
int64_t adasum_reduce_algo;
|
||||
info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast<int64_t>(0));
|
||||
adasum_reduce_algo_ = static_cast<training::AdasumReductionType>(adasum_reduce_algo);
|
||||
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction ||
|
||||
adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) {
|
||||
adasum_reducer_ = std::make_unique<training::AdasumMPI>();
|
||||
}
|
||||
if (!adasum_reducer_->IsAdasumInitialized()) {
|
||||
adasum_reducer_->InitializeVHDDReductionComms();
|
||||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -33,4 +33,4 @@ class AdasumAllReduce final : public NcclKernel {
|
|||
};
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
#endif // USE_MPI
|
||||
#endif // USE_MPI
|
||||
|
|
|
|||
|
|
@ -90,16 +90,16 @@ void Recv::ReceiveData(
|
|||
// Copy data out from buffer.
|
||||
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
|
||||
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
|
||||
tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
|
||||
tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
|
||||
#else
|
||||
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
|
||||
tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream()));
|
||||
tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream()));
|
||||
#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
// In addition to the first output, other tensors are allocated on GPU.
|
||||
// We check if the allocated memory is on the current CUDA device.
|
||||
CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw());
|
||||
// In addition to the first output, other tensors are allocated on GPU.
|
||||
// We check if the allocated memory is on the current CUDA device.
|
||||
CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw());
|
||||
#endif
|
||||
tensor_offset_in_bytes += tensor->SizeInBytes();
|
||||
}
|
||||
|
|
@ -121,10 +121,10 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* CPU variable */
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) /* CPU variable */
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* CPU variable */
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */
|
||||
.TypeConstraint("TBool", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
|
|
@ -274,4 +274,4 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const {
|
|||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* CPU variable */
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) /* CPU variable */
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* CPU variable */
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */
|
||||
.TypeConstraint("TBool", DataTypeImpl::GetTensorType<bool>())
|
||||
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
|
|
@ -67,10 +67,10 @@ void Send::SendData(
|
|||
|
||||
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
|
||||
CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
|
||||
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream()));
|
||||
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream()));
|
||||
#else
|
||||
CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
|
||||
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream()));
|
||||
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream()));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -226,4 +226,4 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const {
|
|||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue