Allow cuda custom ops allocate deferred cpu mem (#17893)

Expose a new allocator from cuda stream.
The allocator manages deferred cpu memory which only get recycled before
stream destruction.

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-10-20 16:12:21 -07:00 committed by GitHub
parent 2f57625cb0
commit 009cd4ea2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 100 additions and 17 deletions

View file

@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};
void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
}
void* AllocDeferredCpuMem(size_t size) const {
if (0 == size) {
return {};
}
const auto& ort_api = Ort::GetApi();
void* mem = {};
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
if (status) {
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
return mem;
}
void FreeDeferredCpuMem(void* mem) const {
if (mem) {
const auto& ort_api = Ort::GetApi();
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
if (status) {
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
}
}
};

View file

@ -3,10 +3,11 @@
#include "core/providers/resource.h"
#define ORT_CUDA_RESOUCE_VERSION 1
#define ORT_CUDA_RESOUCE_VERSION 2
enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
cublas_handle_t
cublas_handle_t,
deferred_cpu_allocator_t,
};

View file

@ -7,6 +7,25 @@
namespace onnxruntime {
DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
OrtAllocator::version = ORT_API_VERSION;
OrtAllocator::Alloc =
[](OrtAllocator* this_, size_t size) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
return self->cuda_stream_.GetCpuAllocator()->Alloc(size);
};
OrtAllocator::Free =
[](OrtAllocator* this_, void* p) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
self->cuda_stream_.EnqueDeferredCPUBuffer(p);
};
OrtAllocator::Info =
[](const OrtAllocator* this_) {
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
return &self->cuda_stream_.GetCpuAllocator()->Info();
};
}
struct CudaNotification : public synchronize::Notification {
CudaNotification(Stream& s) : Notification(s) {
CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream,
cublasHandle_t external_cublas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const {
case CudaResource::cublas_handle_t:
return reinterpret_cast<void*>(cublas_handle_);
break;
case CudaResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
default:
break;
}

View file

@ -9,6 +9,13 @@
namespace onnxruntime {
struct CudaStream;
struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(CudaStream&);
CudaStream& cuda_stream_;
};
struct CudaStream : Stream {
CudaStream(cudaStream_t stream,
const OrtDevice& device,
@ -36,10 +43,13 @@ struct CudaStream : Stream {
void* GetResource(int version, int id) const override;
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
};
void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef USE_CUDA
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
auto z_raw = Z.Allocate(input_shape);
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
}
@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
} // namespace Cuda
#else
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
#endif

View file

@ -5,6 +5,14 @@
namespace Cuda {
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
void RegisterOps(Ort::CustomOpDomain& domain);
}
#else
void RegisterOps(Ort::CustomOpDomain&) {}
#endif
} // namespace Cuda

View file

@ -13,6 +13,8 @@
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "cpu/cpu_ops.h"
#include "cuda/cuda_ops.h"
#include "rocm/rocm_ops.h"
#include "onnxruntime_lite_custom_op.h"
static const char* c_OpDomain = "test.customop";
@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};
Cpu::RegisterOps(domain);
Ort::CustomOpDomain domain_v2{"v2"};
Cpu::RegisterOps(domain_v2);
Cuda::RegisterOps(domain);
Cuda::RegisterOps(domain_v2);
Rocm::RegisterOps(domain);
Rocm::RegisterOps(domain_v2);
Ort::UnownedSessionOptions session_options(options);
session_options.Add(domain);
session_options.Add(domain_v2);

View file

@ -19,7 +19,7 @@ using namespace Ort::Custom;
throw std::runtime_error(msg); \
}
namespace Cuda {
namespace Rocm {
void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
const Ort::Custom::Tensor<float>& X,
@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
domain.Add(c_CustomOpOne.get());
}
} // namespace Cuda
#else
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
} // namespace Rocm
#endif

View file

@ -5,6 +5,14 @@
namespace Rocm {
#ifdef USE_ROCM
void RegisterOps(Ort::CustomOpDomain& domain);
}
#else
inline void RegisterOps(Ort::CustomOpDomain&) {}
#endif
} // namespace Rocm