MPI Kernels for Training

This commit is contained in:
Ryan Hill 2021-05-11 20:16:42 -07:00
parent 335562def6
commit 66babe4ed3
13 changed files with 162 additions and 109 deletions

View file

@ -631,16 +631,20 @@ struct ProviderHostImpl : ProviderHost {
// OpKernelContext (wrapped)
const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) override { return p->Input<Tensor>(index); }
const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) override { return p->RequiredInput<Tensor>(index); }
Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) override { return p->Output<Tensor>(index); }
Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) override { return p->Output(index, shape); }
Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) override { return p->RequiredOutput(index, shape); }
int OpKernelContext__InputCount(const OpKernelContext* p) override { return p->InputCount(); }
int OpKernelContext__OutputCount(const OpKernelContext* p) override { return p->OutputCount(); }
Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) override { return p->GetTempSpaceAllocator(output); }
bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) override { return p->GetUseDeterministicCompute(); }
bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredOutputShape(index, shape); }
bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) override { return p->TryGetInferredInputShape(index, shape); }
// OpKernelInfo (wrapped)
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) override { return onnxruntime::CopyOpKernelInfo(info); }
void OpKernelInfo__operator_delete(OpKernelInfo* p) override { delete p; }
AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) override { return p->GetAllocator(device_id, mem_type); }
const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) override { return p->GetExecutionProvider(); }
Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) override { return p->GetAttr(name, value); }
Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) override { return p->GetAttr(name, value); }

View file

@ -188,6 +188,8 @@ using MLDataType = const DataTypeImpl*;
// be used with class MLValue
using DeleteFunc = void (*)(void*);
using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto;
using NameMLValMap = std::unordered_map<std::string, OrtValue>;
} // namespace onnxruntime
#include "core/platform/threadpool.h"
@ -204,6 +206,23 @@ using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto;
namespace onnxruntime {
// From Tensor.h
class BufferDeleter {
public:
BufferDeleter() : alloc_(nullptr) {}
BufferDeleter(AllocatorPtr alloc) : alloc_(alloc) {}
void operator()(void* p) const {
if (alloc_)
alloc_->Free(p);
}
private:
AllocatorPtr alloc_;
};
using BufferUniquePtr = std::unique_ptr<void, BufferDeleter>;
// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own
// Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example.
void RunOnUnload(std::function<void()> function);

View file

@ -542,16 +542,20 @@ struct ProviderHost {
// OpKernelContext
virtual const Tensor* OpKernelContext__Input_Tensor(const OpKernelContext* p, int index) = 0;
virtual const Tensor& OpKernelContext__RequiredInput_Tensor(const OpKernelContext* p, int index) = 0;
virtual Tensor* OpKernelContext__Output_Tensor(OpKernelContext* p, int index) = 0;
virtual Tensor* OpKernelContext__Output(OpKernelContext* p, int index, const TensorShape& shape) = 0;
virtual Tensor& OpKernelContext__RequiredOutput(OpKernelContext* p, int index, const TensorShape& shape) = 0;
virtual int OpKernelContext__InputCount(const OpKernelContext* p) = 0;
virtual int OpKernelContext__OutputCount(const OpKernelContext* p) = 0;
virtual Status OpKernelContext__GetTempSpaceAllocator(const OpKernelContext* p, AllocatorPtr* output) = 0;
virtual bool OpKernelContext__GetUseDeterministicCompute(const OpKernelContext* p) = 0;
virtual bool OpKernelContext__TryGetInferredOutputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0;
virtual bool OpKernelContext__TryGetInferredInputShape(const OpKernelContext* p, int index, TensorShape& shape) = 0;
// OpKernelInfo
virtual std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info) = 0;
virtual void OpKernelInfo__operator_delete(OpKernelInfo* p) = 0;
virtual AllocatorPtr OpKernelInfo__GetAllocator(const OpKernelInfo* p, int device_id, OrtMemType mem_type) = 0;
virtual const IExecutionProvider* OpKernelInfo__GetExecutionProvider(const OpKernelInfo* p) = 0;
virtual Status OpKernelInfo__GetAttr_int64(const OpKernelInfo* p, const std::string& name, int64_t* value) = 0;
virtual Status OpKernelInfo__GetAttr_float(const OpKernelInfo* p, const std::string& name, float* value) = 0;
@ -1406,6 +1410,9 @@ struct OpKernelContext final {
const T* Input(int index) const;
int InputCount() const { return g_host->OpKernelContext__InputCount(this); }
template <typename T>
T* Output(int index);
Tensor* Output(int index, const TensorShape& shape) { return g_host->OpKernelContext__Output(this, index, shape); }
int OutputCount() const { return g_host->OpKernelContext__OutputCount(this); }
@ -1413,6 +1420,9 @@ struct OpKernelContext final {
bool GetUseDeterministicCompute() const { return g_host->OpKernelContext__GetUseDeterministicCompute(this); }
bool TryGetInferredOutputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredOutputShape(this, index, shape); }
bool TryGetInferredInputShape(int index, TensorShape& shape) const { return g_host->OpKernelContext__TryGetInferredInputShape(this, index, shape); }
PROVIDER_DISALLOW_ALL(OpKernelContext)
};
@ -1421,6 +1431,11 @@ inline const Tensor* OpKernelContext::Input<Tensor>(int index) const {
return g_host->OpKernelContext__Input_Tensor(this, index);
}
template <>
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
return g_host->OpKernelContext__Output_Tensor(this, index);
}
template <>
inline const Tensor& OpKernelContext::RequiredInput(int index) const {
return g_host->OpKernelContext__RequiredInput_Tensor(this, index);
@ -1429,6 +1444,8 @@ inline const Tensor& OpKernelContext::RequiredInput(int index) const {
struct OpKernelInfo final {
static void operator delete(void* p) { g_host->OpKernelInfo__operator_delete(reinterpret_cast<OpKernelInfo*>(p)); }
AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const { return g_host->OpKernelInfo__GetAllocator(this, device_id, mem_type); }
const IExecutionProvider* GetExecutionProvider() const noexcept { return g_host->OpKernelInfo__GetExecutionProvider(this); }
template <typename T>

View file

@ -7,11 +7,12 @@
#include <float.h>
#include <cmath>
#ifndef SHARED_PROVIDER
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#endif
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/distributed_run_context.h"
#ifdef ENABLE_CPU_FP16_TRAINING_OPS
@ -21,16 +22,14 @@
namespace onnxruntime {
namespace training {
static inline bool IsPowerOfTwo(ulong x)
{
static inline bool IsPowerOfTwo(unsigned x) {
return (x != 0) && ((x & (x - 1)) == 0);
}
// Interface for Adasum algorithm
template <typename Communicator_type>
class AdasumInterface {
public:
public:
Status DispatchFusedAllreduce(void* grad_buffer, void* recv_buffer,
std::vector<int>& tensor_counts, int start_level,
Communicator_type communicator, int tag,
@ -44,7 +43,7 @@ public:
FusedAllreduce((float*)grad_buffer, (float*)recv_buffer,
data_type, tensor_counts, start_level, communicator, tag,
reduction_comms);
} else if(data_type == DataTypeImpl::GetType<double>()) {
} else if (data_type == DataTypeImpl::GetType<double>()) {
FusedAllreduce((double*)grad_buffer, (double*)recv_buffer,
data_type, tensor_counts, start_level, communicator, tag,
reduction_comms);
@ -61,7 +60,7 @@ public:
virtual const Communicator_type* GetReductionComms() = 0;
protected:
protected:
// Communication primitives required for Adasum algorithm
virtual void PointToPointSendRecv(void* input_data_buffer,
int64_t input_buffer_bytes,
@ -113,7 +112,7 @@ protected:
}
}
private:
private:
// Allocator for temporary buffer allocations
AllocatorPtr allocator_ = nullptr;
@ -161,7 +160,7 @@ private:
if (IsPowerOfTwo(size) == false) {
ORT_THROW(
"Adasum doesn't currently support reduction among non-power-of-2 number of ranks.");
"Adasum doesn't currently support reduction among non-power-of-2 number of ranks.");
}
std::vector<std::vector<int>> nghrCountVec;
@ -213,10 +212,10 @@ private:
tensor_counts[i] = 0;
} else {
nghrCountVec[nghrCountVec_index][i] =
nghrCount - nghrCountSoFar; // should not be negative
nghrCount - nghrCountSoFar; // should not be negative
tensor_counts[i] =
tensor_counts[i] -
(nghrCount - nghrCountSoFar); // should not be negative
(nghrCount - nghrCountSoFar); // should not be negative
}
} else {
tensor_counts[i] = tensor_counts[i];
@ -240,9 +239,9 @@ private:
assert((myCount - myCountSoFar) >= 0);
nghrCountVec[nghrCountVec_index][i] =
tensor_counts[i] -
(myCount - myCountSoFar); // should not be negative
(myCount - myCountSoFar); // should not be negative
tensor_counts[i] =
myCount - myCountSoFar; // should not be negative
myCount - myCountSoFar; // should not be negative
}
} else {
nghrCountVec[nghrCountVec_index][i] = tensor_counts[i];
@ -264,8 +263,8 @@ private:
recv_buffer = &recv_buffer[nghrCount];
}
FusedPairwiseReduceWithComm((uint8_t*)grad_buffer, (uint8_t*)recv_buffer,
data_type, tensor_counts, reduction_comms[comm_index],
(rank & level) == 0, normAndDots);
data_type, tensor_counts, reduction_comms[comm_index],
(rank & level) == 0, normAndDots);
}
for (level = (size >> 1); level > 0; level = (level >> 1)) {
@ -383,5 +382,5 @@ private:
}
};
} // namespace training
} // namespace onnxruntime
} // namespace training
} // namespace onnxruntime

View file

@ -2,9 +2,11 @@
// Licensed under the MIT License.
#pragma once
#ifndef SHARED_PROVIDER
#include "core/common/common.h"
#include "orttraining/core/framework/distributed_run_context.h"
#include "core/common/logging/logging.h"
#endif
#include "orttraining/core/framework/distributed_run_context.h"
#if defined(USE_MPI)
#include <mpi.h>
@ -14,76 +16,76 @@ namespace onnxruntime {
namespace training {
#if defined(USE_MPI)
#define MPI_CHECK(condition) \
do { \
int error = (condition); \
ORT_ENFORCE( \
error == MPI_SUCCESS, \
"MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
#define MPI_CHECK(condition) \
do { \
int error = (condition); \
ORT_ENFORCE( \
error == MPI_SUCCESS, \
"MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
} while (0)
#endif
struct MPIGroup {
#if defined(USE_MPI)
MPI_Group mpi_group {MPI_GROUP_EMPTY}; // MPI group
MPI_Comm communicator {MPI_COMM_NULL}; // MPI communicator of this group
MPI_Group mpi_group{MPI_GROUP_EMPTY}; // MPI group
MPI_Comm communicator{MPI_COMM_NULL}; // MPI communicator of this group
#endif
bool is_group_initialized {false}; // Whether it's initialized
bool is_group_initialized{false}; // Whether it's initialized
};
class MPIContext {
// https://stackoverflow.com/questions/1008019/c-singleton-design-pattern
public:
static MPIContext& GetInstance();
public:
static MPIContext& GetInstance();
MPIContext(MPIContext const&) = delete;
void operator=(MPIContext const&) = delete;
MPIContext(MPIContext const&) = delete;
void operator=(MPIContext const&) = delete;
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
~MPIContext();
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
~MPIContext();
void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group);
void AddMPIGroup(WorkerGroupType group_type, WorkerGroup& group);
const std::vector<MPIGroup>& GetAllMPIGroups() const { return mpi_groups_; }
const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; }
const std::vector<MPIGroup>& GetAllMPIGroups() const { return mpi_groups_; }
int GetWorldRank() const { return world_rank_; }
int GetLocalRank() const { return local_rank_; }
int GetWorldSize() const { return world_size_; }
int GetLocalSize() const { return local_size_; }
const static int MPI_TIMEOUT_IN_SECONDS = 10;
const MPIGroup& GetMPIGroup(WorkerGroupType group_type) const { return mpi_groups_[group_type]; }
int GetWorldRank() const { return world_rank_; }
int GetLocalRank() const { return local_rank_; }
int GetWorldSize() const { return world_size_; }
int GetLocalSize() const { return local_size_; }
const static int MPI_TIMEOUT_IN_SECONDS = 10;
#if defined(USE_MPI)
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
// shutdown_mpi shall be called specifically in user code.
static void shutdown_mpi(bool perform_graceful_exit = true);
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
// shutdown_mpi shall be called specifically in user code.
static void shutdown_mpi(bool perform_graceful_exit = true);
#endif
private:
MPIContext();
private:
MPIContext();
// Groups containing mpi communicator for any worker group.
std::vector<MPIGroup> mpi_groups_;
// Groups containing mpi communicator for any worker group.
std::vector<MPIGroup> mpi_groups_;
#if defined(USE_MPI)
// Global counter for MPI groups
int mpi_group_id_ = 0;
void setup_mpi();
void ReleaseComms();
// Global counter for MPI groups
int mpi_group_id_ = 0;
void setup_mpi();
void ReleaseComms();
#endif
int world_rank_;
int local_rank_;
int world_size_;
int local_size_;
int world_rank_;
int local_rank_;
int world_size_;
int local_size_;
const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger();
const logging::Logger& logger_ = logging::LoggingManager::DefaultLogger();
};
} // namespace training

View file

@ -3,15 +3,18 @@
#pragma once
#ifndef SHARED_PROVIDER
#include "core/framework/tensor.h"
#include "core/framework/op_kernel.h"
#endif
#ifdef USE_MPI
#include <mpi.h>
#endif
namespace onnxruntime {
namespace training {
#ifdef USE_MPI
MPI_Datatype GetMPIDataType (MLDataType data_type);
MPI_Datatype GetMPIDataType(MLDataType data_type);
int GetMPIRank(MPI_Comm comm);

View file

@ -5,10 +5,12 @@
#include <string>
#include <unordered_map>
#ifndef SHARED_PROVIDER
#include "core/common/logging/logging.h"
#include "core/framework/framework_common.h"
#include "core/framework/ml_value.h"
#include "core/graph/node_arg.h"
#endif
namespace onnxruntime {
namespace training {

View file

@ -16,7 +16,10 @@
#include "orttraining/models/runner/training_util.h"
#include "orttraining/models/runner/data_loader.h"
#include "core/providers/cuda/cuda_execution_provider.h"
#include "core/providers/cuda/cuda_provider_factory_creator.h"
namespace onnxruntime {
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* provider_options);
}
#include <condition_variable>
#include <mutex>
@ -111,8 +114,10 @@ int main(int argc, char* argv[]) {
InferenceSession session_object{so, *env};
Status st;
CUDAExecutionProviderInfo xp_info{static_cast<OrtDevice::DeviceId>(world_rank)};
st = session_object.RegisterExecutionProvider(std::make_unique<CUDAExecutionProvider>(xp_info));
OrtCUDAProviderOptions xp_info{};
xp_info.device_id = static_cast<OrtDevice::DeviceId>(world_rank);
auto cuda_factory = CreateExecutionProviderFactory_Cuda(&xp_info);
st = session_object.RegisterExecutionProvider(cuda_factory->CreateProvider());
ORT_ENFORCE(st == Status::OK(), "MPI rank ", world_rank, ": ", st.ErrorMessage());
std::string model_at_rank;

View file

@ -1,7 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef SHARED_PROVIDER
#include "core/common/common.h"
#endif
#pragma once
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
@ -191,7 +193,7 @@ inline void ReceiveShapeInfo(
info_shapes.buffer, info_shapes.size, MPI_CHAR,
info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
}
#endif // USE_MPI
#endif // USE_MPI
inline void ComputeTensorSizeAndBufferLength(OpKernelContext* context,
std::vector<int>& tensor_element_counts,

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_MPI
#include "core/providers/shared_library/provider_api.h"
#include "orttraining/training_ops/cuda/collective/adasum_kernels.h"
#include "orttraining/training_ops/communication_common.h"
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
@ -9,7 +10,6 @@ namespace onnxruntime {
namespace cuda {
Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const {
int vhdd_start_level = 1;
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction) {
vhdd_start_level = training::DistributedRunContext::GetInstance().GroupSize(training::WorkerGroupType::NodeLocalDataParallel);
@ -37,23 +37,23 @@ Status AdasumAllReduce::ComputeInternal(OpKernelContext* context) const {
for (int i = 0; i < num_tensors; ++i) {
const Tensor* x_tensor = context->Input<Tensor>(i);
CUDA_CALL(cudaMemcpyAsync((uint8_t*)data_buffer_ptr.get() + tensor_offsets[i], x_tensor->DataRaw(),
tensor_sizes[i], cudaMemcpyDeviceToHost, Stream()));
tensor_sizes[i], cudaMemcpyDeviceToHost, Stream()));
}
auto recv_buffer = allocator->Alloc(total_recv_buffer_len);
BufferUniquePtr recv_buffer_ptr(recv_buffer, BufferDeleter(allocator));
ORT_RETURN_IF_ERROR(adasum_reducer_->DispatchFusedAllreduce((void*)data_buffer, recv_buffer, tensor_element_counts,
vhdd_start_level, // start level
training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator
0, // tag
adasum_reducer_->GetReductionComms(), // reduction_comms
context->Input<Tensor>(0)->DataType()));
vhdd_start_level, // start level
training::MPIContext::GetInstance().GetMPIGroup(training::WorkerGroupType::GlobalParallel).communicator, // communicator
0, // tag
adasum_reducer_->GetReductionComms(), // reduction_comms
context->Input<Tensor>(0)->DataType()));
for (int i = 0; i < num_tensors; i++) {
Tensor* y_tensor = context->Output(i, context->Input<Tensor>(i)->Shape());
CUDA_CALL(cudaMemcpyAsync(y_tensor->MutableDataRaw(), (uint8_t*)data_buffer + tensor_offsets[i],
tensor_sizes[i], cudaMemcpyHostToDevice, Stream()));
tensor_sizes[i], cudaMemcpyHostToDevice, Stream()));
}
return Status::OK();
}
@ -63,11 +63,11 @@ ONNX_OPERATOR_KERNEL_EX(
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
(*KernelDefBuilder::Create())
.VariadicAlias(0, 0) // outputs and inputs are mapped one to one
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
AdasumAllReduce);
} // namespace cuda
} // namespace onnxruntime
#endif // USE_MPI
#endif // USE_MPI

View file

@ -4,25 +4,25 @@
#pragma once
#include "orttraining/training_ops/cuda/collective/nccl_common.h"
#include "orttraining/core/framework/adasum/adasum_interface.h"
#include "orttraining/core/framework/adasum/adasum_mpi.h"
namespace onnxruntime {
namespace cuda {
class AdasumAllReduce final : public NcclKernel {
public:
explicit AdasumAllReduce(const OpKernelInfo& info) : NcclKernel(info) {
int64_t adasum_reduce_algo;
info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast<int64_t>(0));
adasum_reduce_algo_ = static_cast<training::AdasumReductionType>(adasum_reduce_algo);
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction ||
adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) {
adasum_reducer_ = std::make_unique<training::AdasumMPI>();
}
if(!adasum_reducer_->IsAdasumInitialized()) {
adasum_reducer_->InitializeVHDDReductionComms();
}
int64_t adasum_reduce_algo;
info.GetAttrOrDefault("reduce_algo", &adasum_reduce_algo, static_cast<int64_t>(0));
adasum_reduce_algo_ = static_cast<training::AdasumReductionType>(adasum_reduce_algo);
if (adasum_reduce_algo_ == training::AdasumReductionType::GpuHierarchicalReduction ||
adasum_reduce_algo_ == training::AdasumReductionType::CpuReduction) {
adasum_reducer_ = std::make_unique<training::AdasumMPI>();
}
if (!adasum_reducer_->IsAdasumInitialized()) {
adasum_reducer_->InitializeVHDDReductionComms();
}
}
Status ComputeInternal(OpKernelContext* context) const override;
@ -33,4 +33,4 @@ class AdasumAllReduce final : public NcclKernel {
};
} // namespace cuda
} // namespace onnxruntime
#endif // USE_MPI
#endif // USE_MPI

View file

@ -90,16 +90,16 @@ void Recv::ReceiveData(
// Copy data out from buffer.
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
#else
CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes,
tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream()));
tensor->SizeInBytes(), cudaMemcpyHostToDevice, Stream()));
#endif
#ifndef NDEBUG
// In addition to the first output, other tensors are allocated on GPU.
// We check if the allocated memory is on the current CUDA device.
CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw());
// In addition to the first output, other tensors are allocated on GPU.
// We check if the allocated memory is on the current CUDA device.
CheckIfMemoryOnCurrentGpuDevice(tensor->DataRaw());
#endif
tensor_offset_in_bytes += tensor->SizeInBytes();
}
@ -121,10 +121,10 @@ ONNX_OPERATOR_KERNEL_EX(
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) /* CPU variable */
.InputMemoryType<OrtMemTypeCPUInput>(1) /* CPU variable */
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* CPU variable */
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */
.InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */
.TypeConstraint("TBool", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
@ -274,4 +274,4 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const {
} // namespace cuda
} // namespace onnxruntime
#endif
#endif

View file

@ -22,10 +22,10 @@ ONNX_OPERATOR_KERNEL_EX(
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) /* CPU variable */
.InputMemoryType<OrtMemTypeCPUInput>(1) /* CPU variable */
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* CPU variable */
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) /* CPU variable */
.InputMemoryType(OrtMemTypeCPUInput, 1) /* CPU variable */
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* CPU variable */
.TypeConstraint("TBool", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
@ -67,10 +67,10 @@ void Send::SendData(
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream()));
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToDevice, Stream()));
#else
CUDA_CALL(cudaMemcpyAsync(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(),
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream()));
tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost, Stream()));
#endif
}
@ -226,4 +226,4 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const {
} // namespace cuda
} // namespace onnxruntime
#endif
#endif