onnxruntime/winml/adapter/winml_adapter_execution_provider.cpp
cao lei dd72192cf4
ExecutionProvider API refactor - move allocator from EP level to SessionState level and indexed by OrtDevice (#15833)
### Description
This PR refactors the ExecutionProvider API for memory management by
moving allocators from the EP level to the SessionState level, where they
are indexed by OrtDevice.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This PR refactors the ExecutionProvider API for memory management by
moving allocators from the EP level to the SessionState level, where they
are indexed by OrtDevice. With this change, EPs no longer carry the burden
of maintaining allocators, which makes the API friendlier for EP developers.

---------

Co-authored-by: Lei Cao <leca@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-06-19 17:44:45 -07:00

79 lines
2.9 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "adapter/pch.h"
#include <utility>
#include "winml_adapter_c_api.h"
#include "core/session/ort_apis.h"
#include "winml_adapter_apis.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
namespace winmla = Windows::AI::MachineLearning::Adapter;
// Adapts an onnxruntime allocator (held through onnxruntime::AllocatorPtr) to the C-style
// OrtAllocator interface. The constructor fills in the OrtAllocator function-pointer table
// with static trampolines that cast `this_` back to the wrapper and forward to `impl_`.
//
// Instances are created with `new` in GetProviderAllocator. They are destroyed in
// FreeProviderAllocator, which deletes through OrtAllocatorWrapper* so that `impl_` is
// released as well.
struct OrtAllocatorWrapper : public OrtAllocator {
 public:
  // Keeps `impl` for the wrapper's lifetime. The argument is moved rather than copied;
  // AllocatorPtr is expected to be a shared_ptr, so this avoids a redundant atomic
  // refcount increment/decrement. `explicit` prevents accidental implicit conversion
  // from an AllocatorPtr to a wrapper.
  explicit OrtAllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(std::move(impl)) {
    version = ORT_API_VERSION;
    Alloc = AllocImpl;
    Free = FreeImpl;
    Info = InfoImpl;
  }

  // OrtAllocator::Alloc trampoline: forwards the request to the wrapped allocator.
  static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) {
    return static_cast<OrtAllocatorWrapper*>(this_)->impl_->Alloc(size);
  }

  // OrtAllocator::Free trampoline: returns `p` to the wrapped allocator.
  static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) {
    static_cast<OrtAllocatorWrapper*>(this_)->impl_->Free(p);
  }

  // OrtAllocator::Info trampoline: exposes the wrapped allocator's memory info. The
  // returned pointer refers to storage owned by the wrapped allocator.
  static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) {
    return &(static_cast<const OrtAllocatorWrapper*>(this_)->impl_->Info());
  }

 private:
  onnxruntime::AllocatorPtr impl_;
};
// Calls IExecutionProvider::Sync() on the EP behind the opaque `provider` handle.
// What "sync" means is up to each EP; presumably it waits for outstanding device work.
// Returns nullptr on success. A non-OK status from Sync() is converted into an OrtStatus
// and returned to the caller.
ORT_API_STATUS_IMPL(winmla::ExecutionProviderSync, _In_ OrtExecutionProvider* provider) {
  API_IMPL_BEGIN
  // The opaque C handle is the IExecutionProvider itself.
  auto* const ep = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);
  ORT_API_RETURN_IF_STATUS_NOT_OK(ep->Sync());
  return nullptr;
  API_IMPL_END
}
// Returns, through `allocator`, a newly created OrtAllocator that forwards to the
// session-owned allocator for `provider`'s default-memory device.
//
// Allocators live on the session and are looked up by device. The lookup therefore goes
// through InferenceSession::GetAllocator with an OrtMemoryInfo built from the EP's
// OrtMemTypeDefault device. The returned wrapper must be released with
// FreeProviderAllocator.
//
// Returns nullptr on success. Returns an ORT_FAIL status if the session has no allocator
// for that device, or if the wrapper cannot be allocated. On any failure, *allocator is
// left as nullptr.
ORT_API_STATUS_IMPL(winmla::GetProviderAllocator, _In_ OrtSession* session, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator) {
  API_IMPL_BEGIN
  // Give the out-param a defined value on every failure path, including exceptions
  // caught by API_IMPL_END.
  *allocator = nullptr;

  auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session);
  const auto execution_provider = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);
  OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, execution_provider->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault));
  auto allocator_ptr = inference_session->GetAllocator(mem_info);  // TODO(leca): REVIEW

  // The lookup can come back empty (e.g. nothing registered for this device). Wrapping a
  // null allocator would defer the failure to a null dereference on the first
  // Alloc/Free/Info call, so report it here instead.
  if (!allocator_ptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "No allocator is registered in the session for the execution provider's default device");
  }

  *allocator = new (std::nothrow) OrtAllocatorWrapper(std::move(allocator_ptr));
  if (*allocator == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "Out of memory");
  }
  return nullptr;
  API_IMPL_END
}
// Creates a heap-allocated OrtMemoryInfo describing `provider`'s default-memory device and
// returns it through `memory_info`. The info has an empty name and the OrtDeviceAllocator
// type. The caller is responsible for releasing the returned object.
// Returns nullptr on success, or an ORT_FAIL "Out of memory" status if allocation fails.
ORT_API_STATUS_IMPL(winmla::GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info) {
  API_IMPL_BEGIN
  auto* const ep = reinterpret_cast<onnxruntime::IExecutionProvider*>(provider);
  const auto default_device = ep->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault);
  auto* const info = new (std::nothrow) OrtMemoryInfo("", ::OrtAllocatorType::OrtDeviceAllocator, default_device);  // TODO(leca): REVIEW
  *memory_info = info;
  if (info == nullptr) {
    return OrtApis::CreateStatus(ORT_FAIL, "Out of memory");
  }
  return nullptr;
  API_IMPL_END
}
// Destroys an allocator previously handed out by GetProviderAllocator.
// The pointer is cast back to OrtAllocatorWrapper before deletion, so the wrapper's own
// members (including its reference to the underlying allocator) are destroyed.
// Deleting a null pointer is a no-op, so passing nullptr is harmless.
ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator) {
  API_IMPL_BEGIN
  auto* const wrapper = static_cast<OrtAllocatorWrapper*>(allocator);
  delete wrapper;
  return nullptr;
  API_IMPL_END
}