From 2c525a79b1a53339664e2dfd21f604070f150781 Mon Sep 17 00:00:00 2001 From: cao lei Date: Wed, 13 Mar 2024 19:41:15 -0700 Subject: [PATCH] Add new API KernelContext_GetScratchBuffer (#19809) ### Description add new API KernelContext_GetScratchBuffer to get scratch buffer from kernel context ### Motivation and Context add new API KernelContext_GetScratchBuffer to get scratch buffer from kernel context which will be used in ORT extension project for GroupQueryAttention custom op --- .../onnxruntime/core/framework/stream_handles.h | 2 ++ .../onnxruntime/core/session/onnxruntime_c_api.h | 10 ++++++++++ .../core/providers/cann/cann_stream_handle.h | 4 +++- .../core/providers/cuda/cuda_stream_handle.h | 4 +++- .../core/providers/rocm/rocm_stream_handle.h | 4 +++- onnxruntime/core/session/custom_ops.cc | 14 ++++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 2 ++ 8 files changed, 38 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index c235ee9047..26d78133b5 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -100,6 +100,8 @@ class Stream { return nullptr; } + virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } + private: StreamHandle handle_; const OrtDevice& device_; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 144ee1205e..2e2a903da2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4590,6 +4590,16 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object. + * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes How many bytes is this scratch buffer + * \param[out] out A pointer to the scrach buffer + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); }; /* diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index 4d03fe5201..5d822d23f9 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -12,6 +12,7 @@ #include "core/providers/cann/cann_call.h" namespace onnxruntime { +void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct CannStream : Stream { CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag); @@ -23,10 +24,11 @@ struct CannStream : Stream { void Flush() override; bool own_stream_{true}; + + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; } }; void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type); -void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index b02c167e9e..15e7a0553c 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -11,6 +11,7 @@ namespace onnxruntime { struct CudaStream; +void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct DeferredCpuAllocator : public OrtAllocator { DeferredCpuAllocator(CudaStream&); @@ -47,6 +48,8 @@ struct CudaStream : Stream { onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; @@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis cudnnHandle_t external_cudnn_handle, cublasHandle_t external_cublass_handle, const CUDAExecutionProviderInfo& ep_info); -void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 1f3e5b7554..30983ce035 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -8,6 +8,7 @@ #include "core/framework/stream_handles.h" namespace onnxruntime { +void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); struct RocmStream : Stream { RocmStream(hipStream_t stream, @@ -36,6 +37,8 @@ struct RocmStream : Stream { void* GetResource(int version, int id) const override; + WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; @@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis bool use_existing_stream, miopenHandle_t external_miopen_handle, rocblas_handle external_rocblas_handle); -void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); } // namespace onnxruntime diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 7a233c57cf..9a09c74cd0 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -736,6 +736,20 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf }); } +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) { + if (count_or_bytes == 0) { + *out = nullptr; + return nullptr; + } + onnxruntime::AllocatorPtr allocator = reinterpret_cast(context)->GetAllocator(mem_info->device); + if (!allocator) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); + } + onnxruntime::Stream* stream = reinterpret_cast(context)->GetComputeStream(); + *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn()); + return nullptr; +}; + #if ENABLE_CUSTOM_OP_API #include "core/framework/customregistry.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index dec8754ea2..273f94ae5d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2725,6 +2725,7 @@ static constexpr OrtApi ort_api_1_to_18 = { &OrtApis::KernelContext_ParallelFor, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + &OrtApis::KernelContext_GetScratchBuffer, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ce94ba89a..adb89f7f85 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -513,4 +513,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); } // namespace OrtApis