mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Allow CUDA custom ops to allocate deferred CPU memory (#17893)
Expose a new allocator from the CUDA stream. The allocator manages deferred CPU memory, which only gets recycled before stream destruction. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
parent
2f57625cb0
commit
009cd4ea2e
9 changed files with 100 additions and 17 deletions
|
|
@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
|
|||
cudaStream_t cuda_stream = {};
|
||||
cudnnHandle_t cudnn_handle = {};
|
||||
cublasHandle_t cublas_handle = {};
|
||||
OrtAllocator* deferred_cpu_allocator = {};
|
||||
|
||||
void Init(const OrtKernelContext& kernel_ctx) override {
|
||||
const auto& ort_api = Ort::GetApi();
|
||||
|
|
@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
|
|||
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
|
||||
|
||||
resource = {};
|
||||
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
|
||||
if (status) {
|
||||
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
|
||||
}
|
||||
|
||||
void* AllocDeferredCpuMem(size_t size) const {
|
||||
if (0 == size) {
|
||||
return {};
|
||||
}
|
||||
const auto& ort_api = Ort::GetApi();
|
||||
void* mem = {};
|
||||
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
|
||||
if (status) {
|
||||
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return mem;
|
||||
}
|
||||
|
||||
void FreeDeferredCpuMem(void* mem) const {
|
||||
if (mem) {
|
||||
const auto& ort_api = Ort::GetApi();
|
||||
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
|
||||
if (status) {
|
||||
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,11 @@
|
|||
|
||||
#include "core/providers/resource.h"
|
||||
|
||||
#define ORT_CUDA_RESOUCE_VERSION 1
|
||||
#define ORT_CUDA_RESOUCE_VERSION 2
|
||||
|
||||
enum CudaResource : int {
|
||||
cuda_stream_t = cuda_resource_offset,
|
||||
cudnn_handle_t,
|
||||
cublas_handle_t
|
||||
cublas_handle_t,
|
||||
deferred_cpu_allocator_t,
|
||||
};
|
||||
|
|
@ -7,6 +7,25 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Binds this allocator to `cuda_stream` and fills in the OrtAllocator
// callback table:
//   - Alloc forwards to the stream's CPU allocator.
//   - Free does not release the buffer itself; it hands the pointer to
//     CudaStream::EnqueDeferredCPUBuffer.
//   - Info returns the memory info of the stream's CPU allocator.
DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
  OrtAllocator::version = ORT_API_VERSION;

  OrtAllocator::Alloc = [](OrtAllocator* allocator, size_t size) {
    auto* deferred = reinterpret_cast<DeferredCpuAllocator*>(allocator);
    auto* cpu_allocator = deferred->cuda_stream_.GetCpuAllocator();
    return cpu_allocator->Alloc(size);
  };

  OrtAllocator::Free = [](OrtAllocator* allocator, void* buffer) {
    auto* deferred = reinterpret_cast<DeferredCpuAllocator*>(allocator);
    deferred->cuda_stream_.EnqueDeferredCPUBuffer(buffer);
  };

  OrtAllocator::Info = [](const OrtAllocator* allocator) {
    const auto* deferred = reinterpret_cast<const DeferredCpuAllocator*>(allocator);
    const auto& mem_info = deferred->cuda_stream_.GetCpuAllocator()->Info();
    return &mem_info;
  };
}
|
||||
|
||||
struct CudaNotification : public synchronize::Notification {
|
||||
CudaNotification(Stream& s) : Notification(s) {
|
||||
CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
|
||||
|
|
@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream,
|
|||
cublasHandle_t external_cublas_handle) : Stream(stream, device),
|
||||
own_stream_(own_flag),
|
||||
cpu_allocator_(cpu_allocator),
|
||||
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
|
||||
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
|
||||
deferred_cpu_allocator_(*this) {
|
||||
if (own_flag) {
|
||||
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
|
||||
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
|
||||
|
|
@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const {
|
|||
case CudaResource::cublas_handle_t:
|
||||
return reinterpret_cast<void*>(cublas_handle_);
|
||||
break;
|
||||
case CudaResource::deferred_cpu_allocator_t:
|
||||
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,13 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CudaStream;
|
||||
|
||||
struct DeferredCpuAllocator : public OrtAllocator {
|
||||
DeferredCpuAllocator(CudaStream&);
|
||||
CudaStream& cuda_stream_;
|
||||
};
|
||||
|
||||
struct CudaStream : Stream {
|
||||
CudaStream(cudaStream_t stream,
|
||||
const OrtDevice& device,
|
||||
|
|
@ -36,10 +43,13 @@ struct CudaStream : Stream {
|
|||
|
||||
void* GetResource(int version, int id) const override;
|
||||
|
||||
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
|
||||
|
||||
private:
|
||||
std::vector<void*> deferred_cpu_buffers_;
|
||||
AllocatorPtr cpu_allocator_;
|
||||
bool release_cpu_buffer_on_cuda_stream_{true};
|
||||
DeferredCpuAllocator deferred_cpu_allocator_;
|
||||
};
|
||||
|
||||
void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
|
||||
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
|
|
@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
|
|||
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
|
||||
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
|
||||
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
|
||||
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
|
||||
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
|
||||
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
|
||||
auto z_raw = Z.Allocate(input_shape);
|
||||
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
|
||||
}
|
||||
|
|
@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
|
|||
|
||||
} // namespace Cuda
|
||||
|
||||
#else
|
||||
|
||||
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
|
||||
|
||||
#endif
|
||||
|
|
@ -5,6 +5,14 @@
|
|||
|
||||
namespace Cuda {
|
||||
|
||||
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
|
||||
|
||||
void RegisterOps(Ort::CustomOpDomain& domain);
|
||||
|
||||
}
|
||||
#else
|
||||
|
||||
// No-op fallback selected by the #else branch, i.e. when the
// `defined(USE_CUDA) && !defined(ENABLE_TRAINING)` condition is false.
// Marked `inline` because this definition appears to live in a header: a
// non-inline definition would produce multiple-definition link errors once the
// header is included from more than one translation unit. This also matches
// the Rocm fallback stub.
inline void RegisterOps(Ort::CustomOpDomain&) {}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace Cuda
|
||||
|
|
@ -13,6 +13,8 @@
|
|||
#include "core/framework/ortdevice.h"
|
||||
#include "core/framework/ortmemoryinfo.h"
|
||||
#include "cpu/cpu_ops.h"
|
||||
#include "cuda/cuda_ops.h"
|
||||
#include "rocm/rocm_ops.h"
|
||||
#include "onnxruntime_lite_custom_op.h"
|
||||
|
||||
static const char* c_OpDomain = "test.customop";
|
||||
|
|
@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
|
|||
ORT_TRY {
|
||||
Ort::CustomOpDomain domain{c_OpDomain};
|
||||
Cpu::RegisterOps(domain);
|
||||
|
||||
Ort::CustomOpDomain domain_v2{"v2"};
|
||||
Cpu::RegisterOps(domain_v2);
|
||||
|
||||
Cuda::RegisterOps(domain);
|
||||
Cuda::RegisterOps(domain_v2);
|
||||
|
||||
Rocm::RegisterOps(domain);
|
||||
Rocm::RegisterOps(domain_v2);
|
||||
|
||||
Ort::UnownedSessionOptions session_options(options);
|
||||
session_options.Add(domain);
|
||||
session_options.Add(domain_v2);
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ using namespace Ort::Custom;
|
|||
throw std::runtime_error(msg); \
|
||||
}
|
||||
|
||||
namespace Cuda {
|
||||
namespace Rocm {
|
||||
|
||||
void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
|
||||
const Ort::Custom::Tensor<float>& X,
|
||||
|
|
@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
|
|||
domain.Add(c_CustomOpOne.get());
|
||||
}
|
||||
|
||||
} // namespace Cuda
|
||||
|
||||
#else
|
||||
|
||||
void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
|
||||
} // namespace Rocm
|
||||
|
||||
#endif
|
||||
|
|
@ -5,6 +5,14 @@
|
|||
|
||||
namespace Rocm {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
void RegisterOps(Ort::CustomOpDomain& domain);
|
||||
|
||||
}
|
||||
#else
|
||||
|
||||
// Fallback for builds without USE_ROCM: accepts the domain and registers
// nothing.
inline void RegisterOps(Ort::CustomOpDomain& /*domain*/) {
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace Rocm
|
||||
|
|
|
|||
Loading…
Reference in a new issue