mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add new API KernelContext_GetScratchBuffer (#19809)
### Description <!-- Describe your changes. --> Add a new API, KernelContext_GetScratchBuffer, to get a scratch buffer from the kernel context. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> The new API KernelContext_GetScratchBuffer gets a scratch buffer from the kernel context; it will be used in the ORT extension project for the GroupQueryAttention custom op.
This commit is contained in:
parent
18ad8587a6
commit
2c525a79b1
8 changed files with 38 additions and 3 deletions
|
|
@ -100,6 +100,8 @@ class Stream {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; }
|
||||
|
||||
private:
|
||||
StreamHandle handle_;
|
||||
const OrtDevice& device_;
|
||||
|
|
|
|||
|
|
@ -4590,6 +4590,16 @@ struct OrtApi {
|
|||
_In_reads_(num_keys) const char* const* provider_options_keys,
|
||||
_In_reads_(num_keys) const char* const* provider_options_values,
|
||||
_In_ size_t num_keys);
|
||||
|
||||
/** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object.
|
||||
* NOTE: callers are responsible for releasing this scratch buffer through the corresponding allocator
|
||||
* \param[in] context OrtKernelContext instance
|
||||
* \param[in] mem_info OrtMemoryInfo instance
|
||||
* \param[in] count_or_bytes Size of the scratch buffer, in bytes
|
||||
* \param[out] out A pointer to the scratch buffer
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/providers/cann/cann_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
|
||||
struct CannStream : Stream {
|
||||
CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
|
||||
|
|
@ -23,10 +24,11 @@ struct CannStream : Stream {
|
|||
void Flush() override;
|
||||
|
||||
bool own_stream_{true};
|
||||
|
||||
WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; }
|
||||
};
|
||||
|
||||
void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
|
||||
const OrtDevice::DeviceType device_type);
|
||||
|
||||
void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
struct CudaStream;
|
||||
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
|
||||
struct DeferredCpuAllocator : public OrtAllocator {
|
||||
DeferredCpuAllocator(CudaStream&);
|
||||
|
|
@ -47,6 +48,8 @@ struct CudaStream : Stream {
|
|||
|
||||
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
|
||||
|
||||
WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; }
|
||||
|
||||
private:
|
||||
std::vector<void*> deferred_cpu_buffers_;
|
||||
AllocatorPtr cpu_allocator_;
|
||||
|
|
@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
|
|||
cudnnHandle_t external_cudnn_handle,
|
||||
cublasHandle_t external_cublass_handle,
|
||||
const CUDAExecutionProviderInfo& ep_info);
|
||||
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/framework/stream_handles.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
|
||||
struct RocmStream : Stream {
|
||||
RocmStream(hipStream_t stream,
|
||||
|
|
@ -36,6 +37,8 @@ struct RocmStream : Stream {
|
|||
|
||||
void* GetResource(int version, int id) const override;
|
||||
|
||||
WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; }
|
||||
|
||||
private:
|
||||
std::vector<void*> deferred_cpu_buffers_;
|
||||
AllocatorPtr cpu_allocator_;
|
||||
|
|
@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
|
|||
bool use_existing_stream,
|
||||
miopenHandle_t external_miopen_handle,
|
||||
rocblas_handle external_rocblas_handle);
|
||||
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -736,6 +736,20 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
|
|||
});
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
|
||||
if (count_or_bytes == 0) {
|
||||
*out = nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetAllocator(mem_info->device);
|
||||
if (!allocator) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
|
||||
}
|
||||
onnxruntime::Stream* stream = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetComputeStream();
|
||||
*out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn());
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
#if ENABLE_CUSTOM_OP_API
|
||||
#include "core/framework/customregistry.h"
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -2725,6 +2725,7 @@ static constexpr OrtApi ort_api_1_to_18 = {
|
|||
&OrtApis::KernelContext_ParallelFor,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
|
||||
&OrtApis::KernelContext_GetScratchBuffer,
|
||||
};
|
||||
|
||||
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
|
||||
|
|
|
|||
|
|
@ -513,4 +513,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
|
|||
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options,
|
||||
_In_reads_(num_keys) const char* const* provider_options_keys,
|
||||
_In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
|
||||
|
||||
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
Loading…
Reference in a new issue