Add new API KernelContext_GetScratchBuffer (#19809)

### Description
<!-- Describe your changes. -->
add new API KernelContext_GetScratchBuffer to get scratch buffer from
kernel context


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
add new API KernelContext_GetScratchBuffer to get scratch buffer from
kernel context which will be used in ORT extension project for
GroupQueryAttention custom op
This commit is contained in:
cao lei 2024-03-13 19:41:15 -07:00 committed by GitHub
parent 18ad8587a6
commit 2c525a79b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 38 additions and 3 deletions

View file

@ -100,6 +100,8 @@ class Stream {
return nullptr;
}
virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; }
private:
StreamHandle handle_;
const OrtDevice& device_;

View file

@ -4590,6 +4590,16 @@ struct OrtApi {
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);
/** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object.
* NOTE: callers are responsible to release this scratch buffer from the corresponding allocator
* \param[in] context OrtKernelContext instance
* \param[in] mem_info OrtMemoryInfo instance
* \param[in] count_or_bytes How many bytes is this scratch buffer
* \param[out] out A pointer to the scrach buffer
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
};
/*

View file

@ -12,6 +12,7 @@
#include "core/providers/cann/cann_call.h"
namespace onnxruntime {
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
struct CannStream : Stream {
CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
@ -23,10 +24,11 @@ struct CannStream : Stream {
void Flush() override;
bool own_stream_{true};
WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; }
};
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
const OrtDevice::DeviceType device_type);
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime

View file

@ -11,6 +11,7 @@
namespace onnxruntime {
struct CudaStream;
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(CudaStream&);
@ -47,6 +48,8 @@ struct CudaStream : Stream {
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; }
private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime

View file

@ -8,6 +8,7 @@
#include "core/framework/stream_handles.h"
namespace onnxruntime {
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
struct RocmStream : Stream {
RocmStream(hipStream_t stream,
@ -36,6 +37,8 @@ struct RocmStream : Stream {
void* GetResource(int version, int id) const override;
WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; }
private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
bool use_existing_stream,
miopenHandle_t external_miopen_handle,
rocblas_handle external_rocblas_handle);
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime

View file

@ -736,6 +736,20 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
});
}
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
if (count_or_bytes == 0) {
*out = nullptr;
return nullptr;
}
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetAllocator(mem_info->device);
if (!allocator) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
}
onnxruntime::Stream* stream = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetComputeStream();
*out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn());
return nullptr;
};
#if ENABLE_CUSTOM_OP_API
#include "core/framework/customregistry.h"
namespace onnxruntime {

View file

@ -2725,6 +2725,7 @@ static constexpr OrtApi ort_api_1_to_18 = {
&OrtApis::KernelContext_ParallelFor,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
&OrtApis::KernelContext_GetScratchBuffer,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

View file

@ -513,4 +513,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options,
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
} // namespace OrtApis