mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
### Description This PR refactors the ExecutionProvider API for memory management: allocators move from the EP level to the SessionState level and are indexed by OrtDevice. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Because allocators now live at the SessionState level and are indexed by OrtDevice, the EP level is relieved of the burden of maintaining allocators, which makes the API friendlier for EP developers. --------- Co-authored-by: Lei Cao <leca@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
79 lines
2.9 KiB
C++
79 lines
2.9 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "adapter/pch.h"
|
|
|
|
#include "winml_adapter_c_api.h"
|
|
#include "core/session/ort_apis.h"
|
|
#include "winml_adapter_apis.h"
|
|
#include "core/framework/error_code_helper.h"
|
|
#include "core/framework/execution_provider.h"
|
|
|
|
namespace winmla = Windows::AI::MachineLearning::Adapter;
|
|
|
|
struct OrtAllocatorWrapper : public OrtAllocator {
|
|
public:
|
|
OrtAllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(impl) {
|
|
version = ORT_API_VERSION;
|
|
Alloc = AllocImpl;
|
|
Free = FreeImpl;
|
|
Info = InfoImpl;
|
|
}
|
|
|
|
static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) {
|
|
return static_cast<OrtAllocatorWrapper*>(this_)->impl_->Alloc(size);
|
|
}
|
|
static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) {
|
|
return static_cast<OrtAllocatorWrapper*>(this_)->impl_->Free(p);
|
|
}
|
|
static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) {
|
|
return &(static_cast<const OrtAllocatorWrapper*>(this_)->impl_->Info());
|
|
}
|
|
|
|
private:
|
|
onnxruntime::AllocatorPtr impl_;
|
|
};
|
|
|
|
// Calls Sync() on the execution provider behind the opaque `provider` handle.
// Returns nullptr when Sync() reports OK; otherwise the failing status is returned
// through ORT_API_RETURN_IF_STATUS_NOT_OK.
ORT_API_STATUS_IMPL(winmla::ExecutionProviderSync, _In_ OrtExecutionProvider* provider) {
  API_IMPL_BEGIN
  // The opaque handle is reinterpreted as the IExecutionProvider it stands for.
  auto* const ep = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);
  ORT_API_RETURN_IF_STATUS_NOT_OK(ep->Sync());
  return nullptr;
  API_IMPL_END
}
|
|
|
|
// Looks up the session's allocator for the device that `provider` reports for
// OrtMemTypeDefault and returns it wrapped in a heap-allocated OrtAllocatorWrapper.
//
// session   - opaque handle reinterpreted as an onnxruntime::InferenceSession.
// provider  - opaque handle reinterpreted as an onnxruntime::IExecutionProvider.
// allocator - receives the new wrapper on success and nullptr on failure. The wrapper
//             is created with `new`; winmla::FreeProviderAllocator deletes it.
//
// Returns nullptr on success, or an ORT_FAIL status when no allocator is found or the
// wrapper cannot be allocated.
ORT_API_STATUS_IMPL(winmla::GetProviderAllocator, _In_ OrtSession* session, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator) {
  API_IMPL_BEGIN
  // Leave the out-parameter in a defined state on every failure path.
  *allocator = nullptr;

  auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session);
  const auto execution_provider = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);

  // The memory info is only a lookup key here: empty name, device-allocator type, and
  // the provider's default-memory-type device.
  OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, execution_provider->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault));
  auto allocator_ptr = inference_session->GetAllocator(mem_info);  // TODO(leca): REVIEW

  // Wrapping an empty AllocatorPtr would defer the failure to a null dereference in the
  // wrapper's first Alloc/Free/Info call, so report it here instead.
  if (allocator_ptr == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "No allocator found for the execution provider's default device");
  }

  *allocator = new (std::nothrow) OrtAllocatorWrapper(allocator_ptr);
  if (*allocator == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "Out of memory");
  }
  return nullptr;
  API_IMPL_END
}
|
|
|
|
// Creates a heap-allocated OrtMemoryInfo for the device that `provider` reports for
// OrtMemTypeDefault and stores it in `*memory_info`.
// The object is created with `new` (nothrow); this function does not release it.
// Returns nullptr on success, or an ORT_FAIL status if the allocation fails, in which
// case `*memory_info` is nullptr.
ORT_API_STATUS_IMPL(winmla::GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info) {
  API_IMPL_BEGIN
  auto* const ep = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);
  auto default_device = ep->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault);

  // TODO(leca): REVIEW
  // Empty name, device-allocator type, and the provider's default device.
  OrtMemoryInfo* const created =
      new (std::nothrow) OrtMemoryInfo("", ::OrtAllocatorType::OrtDeviceAllocator, default_device);
  *memory_info = created;

  if (created != nullptr) {
    return nullptr;
  }
  return OrtApis::CreateStatus(ORT_FAIL, "Out of memory");
  API_IMPL_END
}
|
|
|
|
// Destroys an allocator wrapper previously handed out as an OrtAllocator*.
// `allocator` is downcast to OrtAllocatorWrapper and deleted, so it must have been
// created as an OrtAllocatorWrapper with `new` (as winmla::GetProviderAllocator does).
// Always returns nullptr after the delete.
ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator) {
  API_IMPL_BEGIN
  auto* const wrapper = static_cast<OrtAllocatorWrapper*>(allocator);
  delete wrapper;
  return nullptr;
  API_IMPL_END
}
|