From 009cd4ea2e0621459806010cea7d7533d0acb39d Mon Sep 17 00:00:00 2001
From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Date: Fri, 20 Oct 2023 16:12:21 -0700
Subject: [PATCH] Allow cuda custom ops allocate deferred cpu mem (#17893)

Expose a new allocator from cuda stream.
The allocator manages deferred cpu memory which only gets recycled before
stream destruction.

---------

Co-authored-by: Randy Shuai
---
 .../core/providers/cuda/cuda_context.h        | 31 +++++++++++++++++++++++++++++++
 .../core/providers/cuda/cuda_resource.h       |  5 +--
 .../core/providers/cuda/cuda_stream_handle.cc | 25 ++++++++++++++++++++++++-
 .../core/providers/cuda/cuda_stream_handle.h  | 10 ++++++++++
 .../custom_op_library/cuda/cuda_ops.cc        |  9 ++++-----
 .../custom_op_library/cuda/cuda_ops.h         | 10 +++++++++-
 .../custom_op_library/custom_op_library.cc    |  9 ++++++++-
 .../custom_op_library/rocm/rocm_ops.cc        |  8 ++-----
 .../custom_op_library/rocm/rocm_ops.h         | 10 +++++++++-
 9 files changed, 100 insertions(+), 17 deletions(-)

diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h
index 13c176dad3..646f33ed95 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_context.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
   cudaStream_t cuda_stream = {};
   cudnnHandle_t cudnn_handle = {};
   cublasHandle_t cublas_handle = {};
+  OrtAllocator* deferred_cpu_allocator = {};  // set by Init() from CudaResource::deferred_cpu_allocator_t
 
   void Init(const OrtKernelContext& kernel_ctx) override {
     const auto& ort_api = Ort::GetApi();
@@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
       ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
     }
     cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
+
+    resource = {};
+    status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
+    if (status) {
+      ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
+  }
+
+  void* AllocDeferredCpuMem(size_t size) const {  // nullptr when size is 0; throws on allocator error
+    if (0 == size) {
+      return {};
+    }
+    const auto& ort_api = Ort::GetApi();
+    void* mem = {};
+    auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
+    if (status) {
+      ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    return mem;
+  }
+
+  void FreeDeferredCpuMem(void* mem) const {  // no-op for nullptr; throws on allocator error
+    if (mem) {
+      const auto& ort_api = Ort::GetApi();
+      auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
+      if (status) {
+        ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+      }
+    }
   }
 };
 
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index e46fc5b421..8c3ed46ade 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,10 +3,11 @@
 
 #include "core/providers/resource.h"
 
-#define ORT_CUDA_RESOUCE_VERSION 1
+#define ORT_CUDA_RESOUCE_VERSION 2
 
 enum CudaResource : int {
   cuda_stream_t = cuda_resource_offset,
   cudnn_handle_t,
-  cublas_handle_t
+  cublas_handle_t,
+  deferred_cpu_allocator_t,
 };
\ No newline at end of file
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
index e855a515f4..5f1dbd30f6 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -7,6 +7,25 @@
 
 namespace onnxruntime {
 
+DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
+  OrtAllocator::version = ORT_API_VERSION;
+  OrtAllocator::Alloc =
+      [](OrtAllocator* this_, size_t size) {
+        auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
+        return self->cuda_stream_.GetCpuAllocator()->Alloc(size);
+      };
+  OrtAllocator::Free =
+      [](OrtAllocator* this_, void* p) {
+        auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
+        self->cuda_stream_.EnqueDeferredCPUBuffer(p);  // queued on the stream instead of being freed here
+      };
+  OrtAllocator::Info =
+      [](const OrtAllocator* this_) {
+        auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
+        return &self->cuda_stream_.GetCpuAllocator()->Info();
+      };
+}
+
 struct CudaNotification : public synchronize::Notification {
   CudaNotification(Stream& s) : Notification(s) {
     CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
@@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream,
                        cublasHandle_t external_cublas_handle) : Stream(stream, device),
                                                                 own_stream_(own_flag),
                                                                 cpu_allocator_(cpu_allocator),
-                                                                release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
+                                                                release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
+                                                                deferred_cpu_allocator_(*this) {
   if (own_flag) {
     CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
     CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const {
     case CudaResource::cublas_handle_t:
       return reinterpret_cast<void*>(cublas_handle_);
       break;
+    case CudaResource::deferred_cpu_allocator_t:
+      return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
+      break;
     default:
       break;
   }
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
index 9c62b029b7..917702fae0 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
@@ -9,6 +9,13 @@
 
 namespace onnxruntime {
 
+struct CudaStream;
+
+struct DeferredCpuAllocator : public OrtAllocator {  // OrtAllocator adapter over a CudaStream's cpu allocator
+  DeferredCpuAllocator(CudaStream&);
+  CudaStream& cuda_stream_;  // non-owning; the stream holds this allocator as a member
+};
+
 struct CudaStream : Stream {
   CudaStream(cudaStream_t stream,
              const OrtDevice& device,
@@ -36,10 +43,13 @@ struct CudaStream : Stream {
 
   void* GetResource(int version, int id) const override;
 
+  onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
+
  private:
   std::vector<void*> deferred_cpu_buffers_;
   AllocatorPtr cpu_allocator_;
   bool release_cpu_buffer_on_cuda_stream_{true};
+  DeferredCpuAllocator deferred_cpu_allocator_;
 };
 
 void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
index aba35b33b7..3d561d378c 100644
--- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
+++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
 
 #define ORT_API_MANUAL_INIT
 #include "onnxruntime_cxx_api.h"
@@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
   CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
   CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
   CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
+  void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
+  CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu memory");
+  cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
   auto z_raw = Z.Allocate(input_shape);
   cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
 }
@@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
 
 }  // namespace Cuda
 
-#else
-
-void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
-
 #endif
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h
index c0287c4932..35cd36fcd4 100644
--- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h
+++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h
@@ -5,6 +5,14 @@
 
 namespace Cuda {
 
+#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)
+
 void RegisterOps(Ort::CustomOpDomain& domain);
 
-}
\ No newline at end of file
+#else
+
+inline void RegisterOps(Ort::CustomOpDomain&) {}  // no-op stub; inline because it is defined in a header
+
+#endif
+
+}  // namespace Cuda
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
index 40fb127eb0..2d5ffc3c81 100644
--- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
+++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
@@ -13,6 +13,8 @@
 #include "core/framework/ortdevice.h"
 #include "core/framework/ortmemoryinfo.h"
 #include "cpu/cpu_ops.h"
+#include "cuda/cuda_ops.h"
+#include "rocm/rocm_ops.h"
 #include "onnxruntime_lite_custom_op.h"
 
 static const char* c_OpDomain = "test.customop";
@@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
   ORT_TRY {
     Ort::CustomOpDomain domain{c_OpDomain};
     Cpu::RegisterOps(domain);
-
     Ort::CustomOpDomain domain_v2{"v2"};
     Cpu::RegisterOps(domain_v2);
 
+    Cuda::RegisterOps(domain);
+    Cuda::RegisterOps(domain_v2);
+
+    Rocm::RegisterOps(domain);
+    Rocm::RegisterOps(domain_v2);
+
     Ort::UnownedSessionOptions session_options(options);
     session_options.Add(domain);
     session_options.Add(domain_v2);
diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc
index 113bfb8545..069246b420 100644
--- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc
+++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc
@@ -19,7 +19,7 @@ using namespace Ort::Custom;
     throw std::runtime_error(msg); \
   }
 
-namespace Cuda {
+namespace Rocm {
 
 void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
                const Ort::Custom::Tensor<float>& X,
@@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
   domain.Add(c_CustomOpOne.get());
 }
 
-}  // namespace Cuda
-
-#else
-
-void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
+}  // namespace Rocm
 
 #endif
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h
index 4e8958cd9d..d3e9e4040a 100644
--- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h
+++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h
@@ -5,6 +5,14 @@
 
 namespace Rocm {
 
+#ifdef USE_ROCM
+
 void RegisterOps(Ort::CustomOpDomain& domain);
 
-}
\ No newline at end of file
+#else
+
+inline void RegisterOps(Ort::CustomOpDomain&) {}
+
+#endif
+
+}  // namespace Rocm