mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
Moved SessionOptions over to the abi
This commit is contained in:
parent
94fc7bccff
commit
44a4fc0cc2
7 changed files with 43 additions and 50 deletions
|
|
@ -21,6 +21,7 @@
|
|||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
using namespace Windows::AI::MachineLearning;
|
||||
|
||||
|
|
@ -32,26 +33,29 @@ CpuOrtSessionBuilder::CpuOrtSessionBuilder() {
|
|||
|
||||
HRESULT
|
||||
CpuOrtSessionBuilder::CreateSessionOptions(
|
||||
ISessionOptions** p_options) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_options);
|
||||
OrtSessionOptions** options) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, options);
|
||||
|
||||
auto options = wil::MakeOrThrow<AbiSafeSessionOptions>();
|
||||
options.CopyTo(__uuidof(ISessionOptions), (void**)p_options);
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
|
||||
std::unique_ptr<Ort::SessionOptions> session_options = std::make_unique<Ort::SessionOptions>(*options);
|
||||
|
||||
(*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3;
|
||||
// set the graph optimization level to all (used to be called level 3)
|
||||
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||
|
||||
// Onnxruntime will use half the number of concurrent threads supported on the system
|
||||
// by default. This causes MLAS to not exercise every logical core.
|
||||
// We force the thread pool size to be maxed out to ensure that WinML always
|
||||
// runs the fastest.
|
||||
(*p_options)->get().intra_op_num_threads = std::thread::hardware_concurrency();
|
||||
session_options->SetIntraOpNumThreads(std::thread::hardware_concurrency());
|
||||
|
||||
// all done with the smart ptr
|
||||
session_options.release();
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT
|
||||
CpuOrtSessionBuilder::CreateSession(
|
||||
ISessionOptions* options,
|
||||
OrtSessionOptions* options,
|
||||
_winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_session);
|
||||
|
|
@ -59,7 +63,7 @@ CpuOrtSessionBuilder::CreateSession(
|
|||
RETURN_HR_IF(E_POINTER, *pp_provider != nullptr);
|
||||
|
||||
// Create the inference session
|
||||
auto session = std::make_unique<onnxruntime::InferenceSession>(options->get());
|
||||
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);
|
||||
|
||||
// Create the cpu execution provider
|
||||
onnxruntime::CPUExecutionProviderInfo xpInfo;
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
|
|||
CpuOrtSessionBuilder();
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateSessionOptions(
|
||||
ISessionOptions** p_options) override;
|
||||
OrtSessionOptions** options) override;
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateSession(
|
||||
ISessionOptions* options,
|
||||
OrtSessionOptions* options,
|
||||
_winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
#include "core/framework/op_node_proto_helper.h"
|
||||
#include "core/framework/customRegistry.h"
|
||||
#include "core/framework/data_transfer.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
|
||||
using namespace Windows::AI::MachineLearning;
|
||||
|
||||
|
|
@ -39,17 +40,20 @@ DmlOrtSessionBuilder::DmlOrtSessionBuilder(
|
|||
|
||||
HRESULT
|
||||
DmlOrtSessionBuilder::CreateSessionOptions(
|
||||
ISessionOptions** p_options) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_options);
|
||||
OrtSessionOptions** options) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, options);
|
||||
|
||||
auto options = wil::MakeOrThrow<AbiSafeSessionOptions>();
|
||||
options.CopyTo(__uuidof(ISessionOptions), (void**)p_options);
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
|
||||
std::unique_ptr<Ort::SessionOptions> session_options = std::make_unique<Ort::SessionOptions>(*options);
|
||||
|
||||
(*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3;
|
||||
// set the graph optimization level to all (used to be called level 3)
|
||||
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
|
||||
|
||||
// Disable the mem pattern session option for DML. It will cause problems with how memory is allocated.
|
||||
(*p_options)->get().enable_mem_pattern = false;
|
||||
session_options->DisableMemPattern();
|
||||
|
||||
// all done with the smart ptr
|
||||
session_options.release();
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
|
|
@ -101,7 +105,7 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
|
|||
}
|
||||
|
||||
HRESULT DmlOrtSessionBuilder::CreateSession(
|
||||
ISessionOptions* options,
|
||||
OrtSessionOptions* options,
|
||||
_winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_session);
|
||||
|
|
@ -114,7 +118,7 @@ HRESULT DmlOrtSessionBuilder::CreateSession(
|
|||
Microsoft::WRL::ComPtr<IDMLDevice> dmlDevice = CreateDmlDevice(p_d3d_device);
|
||||
|
||||
std::unique_ptr<onnxruntime::IExecutionProvider> gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue);
|
||||
auto session = std::make_unique<onnxruntime::InferenceSession>(options->get());
|
||||
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);
|
||||
|
||||
// Cache the provider's raw pointer
|
||||
*pp_provider = gpu_provider.get();
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
|
|||
DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue);
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateSessionOptions(
|
||||
ISessionOptions** p_options) override;
|
||||
OrtSessionOptions** options) override;
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateSession(
|
||||
ISessionOptions* options,
|
||||
OrtSessionOptions* options,
|
||||
_winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
namespace Windows::AI::MachineLearning::Adapter {
|
||||
|
||||
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{
|
||||
|
|
@ -70,22 +72,15 @@ MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnkn
|
|||
virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0;
|
||||
};
|
||||
|
||||
MIDL_INTERFACE("55a956a7-c20e-440d-b2d2-a77acf35de10") ISessionOptions : IUnknown{
|
||||
// this returns a weak ref
|
||||
virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() = 0;
|
||||
// end
|
||||
virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) = 0;
|
||||
};
|
||||
|
||||
// The IOrtSessionBuilder offers an abstraction over the creation of
|
||||
// InferenceSession, that enables the creation of the session based on a device (CPU/DML).
|
||||
MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown {
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions(
|
||||
ISessionOptions** options) = 0;
|
||||
OrtSessionOptions ** options) = 0;
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateSession(
|
||||
ISessionOptions* options,
|
||||
OrtSessionOptions * options,
|
||||
IInferenceSession** session,
|
||||
onnxruntime::IExecutionProvider** provider) = 0;
|
||||
|
||||
|
|
@ -191,23 +186,6 @@ private:
|
|||
std::shared_ptr<onnxruntime::InferenceSession> session_;
|
||||
};
|
||||
|
||||
class AbiSafeSessionOptions : public Microsoft::WRL::RuntimeClass <
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
||||
ISessionOptions> {
|
||||
private:
|
||||
onnxruntime::SessionOptions options_;
|
||||
public:
|
||||
virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() override {
|
||||
return options_;
|
||||
}
|
||||
virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) override {
|
||||
onnxruntime::FreeDimensionOverride overrideOption = {};
|
||||
overrideOption.dimension_denotation = onnx::DATA_BATCH;
|
||||
overrideOption.dimension_override = batch_size;
|
||||
options_.free_dimension_overrides.emplace_back(overrideOption);
|
||||
}
|
||||
};
|
||||
|
||||
// header only code to enable smart pointers on abstract ort objects
|
||||
template <typename T>
|
||||
class OrtObject {
|
||||
|
|
|
|||
|
|
@ -109,18 +109,22 @@ void LearningModelSession::Initialize() {
|
|||
device_impl->GetDeviceQueue(),
|
||||
session_builder.put()));
|
||||
|
||||
com_ptr<_winmla::ISessionOptions> options;
|
||||
WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put()));
|
||||
OrtSessionOptions* options_ptr;
|
||||
WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(&options_ptr));
|
||||
std::unique_ptr<Ort::SessionOptions> options = std::make_unique<Ort::SessionOptions>(options_ptr);
|
||||
|
||||
// Make onnxruntime apply the batch size override, if any
|
||||
if (session_options_ && session_options_.BatchSizeOverride() != 0)
|
||||
{
|
||||
options->SetBatchOverride(session_options_.BatchSizeOverride());
|
||||
Ort::ThrowOnError(Ort::GetApi().AddFreeDimensionOverride(
|
||||
*(options.get()),
|
||||
onnx::DATA_BATCH,
|
||||
session_options_.BatchSizeOverride()));
|
||||
}
|
||||
|
||||
com_ptr<_winmla::IInferenceSession> session;
|
||||
WINML_THROW_IF_FAILED(session_builder->CreateSession(
|
||||
options.get(), session.put(), &cached_execution_provider_));
|
||||
*(options.get()), session.put(), &cached_execution_provider_));
|
||||
|
||||
// Register the custom operator registry
|
||||
auto model = model_.as<winmlp::LearningModel>();
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
// Restore ERROR define
|
||||
#define ERROR 0
|
||||
|
||||
// the C++ ort api
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
|
||||
#include <DirectML.h>
|
||||
|
||||
#include "core/framework/customregistry.h"
|
||||
|
|
|
|||
Loading…
Reference in a new issue