diff --git a/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp b/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp index 7b7002a2a3..0e691aabce 100644 --- a/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp +++ b/winml/lib/Api.Core/CpuOrtSessionBuilder.cpp @@ -21,6 +21,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" +#include "core/session/abi_session_options_impl.h" using namespace Windows::AI::MachineLearning; @@ -32,26 +33,29 @@ CpuOrtSessionBuilder::CpuOrtSessionBuilder() { HRESULT CpuOrtSessionBuilder::CreateSessionOptions( - ISessionOptions** p_options) { - RETURN_HR_IF_NULL(E_POINTER, p_options); + OrtSessionOptions** options) { + RETURN_HR_IF_NULL(E_POINTER, options); - auto options = wil::MakeOrThrow(); - options.CopyTo(__uuidof(ISessionOptions), (void**)p_options); + Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); + std::unique_ptr session_options = std::make_unique(*options); - (*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3; + // set the graph optimization level to all (used to be called level 3) + session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); // Onnxruntime will use half the number of concurrent threads supported on the system // by default. This causes MLAS to not exercise every logical core. // We force the thread pool size to be maxxed out to ensure that WinML always // runs the fastest. - (*p_options)->get().intra_op_num_threads = std::thread::hardware_concurrency(); + session_options->SetIntraOpNumThreads(std::thread::hardware_concurrency()); + // all done with the smart ptr + session_options.release(); return S_OK; } HRESULT CpuOrtSessionBuilder::CreateSession( - ISessionOptions* options, + OrtSessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) { RETURN_HR_IF_NULL(E_POINTER, p_session); @@ -59,7 +63,7 @@ CpuOrtSessionBuilder::CreateSession( RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); // Create the inference session - auto session = std::make_unique(options->get()); + auto session = std::make_unique(options->value); // Create the cpu execution provider onnxruntime::CPUExecutionProviderInfo xpInfo; diff --git a/winml/lib/Api.Core/CpuOrtSessionBuilder.h b/winml/lib/Api.Core/CpuOrtSessionBuilder.h index 5612656970..4151867638 100644 --- a/winml/lib/Api.Core/CpuOrtSessionBuilder.h +++ b/winml/lib/Api.Core/CpuOrtSessionBuilder.h @@ -15,10 +15,10 @@ class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < CpuOrtSessionBuilder(); HRESULT STDMETHODCALLTYPE CreateSessionOptions( - ISessionOptions** p_options) override; + OrtSessionOptions** options) override; HRESULT STDMETHODCALLTYPE CreateSession( - ISessionOptions* options, + OrtSessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) override; diff --git a/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp b/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp index b9143fb2a8..daf58d1609 100644 --- a/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp +++ b/winml/lib/Api.Core/DmlOrtSessionBuilder.cpp @@ -25,6 +25,7 @@ #include "core/framework/op_node_proto_helper.h" #include "core/framework/customRegistry.h" #include "core/framework/data_transfer.h" +#include "core/session/abi_session_options_impl.h" using namespace Windows::AI::MachineLearning; @@ -39,17 +40,20 @@ DmlOrtSessionBuilder::DmlOrtSessionBuilder( HRESULT DmlOrtSessionBuilder::CreateSessionOptions( - ISessionOptions** p_options) { - RETURN_HR_IF_NULL(E_POINTER, p_options); + OrtSessionOptions** options) { + RETURN_HR_IF_NULL(E_POINTER, options); - auto options = wil::MakeOrThrow(); - options.CopyTo(__uuidof(ISessionOptions), (void**)p_options); + Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); + std::unique_ptr session_options = std::make_unique(*options); - (*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3; + // set the graph optimization level to all (used to be called level 3) + session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated. - (*p_options)->get().enable_mem_pattern = false; + session_options->DisableMemPattern(); + // all done with the smart ptr + session_options.release(); return S_OK; } @@ -101,7 +105,7 @@ Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { } HRESULT DmlOrtSessionBuilder::CreateSession( - ISessionOptions* options, + OrtSessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) { RETURN_HR_IF_NULL(E_POINTER, p_session); @@ -114,7 +118,7 @@ HRESULT DmlOrtSessionBuilder::CreateSession( Microsoft::WRL::ComPtr dmlDevice = CreateDmlDevice(p_d3d_device); std::unique_ptr gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue); - auto session = std::make_unique(options->get()); + auto session = std::make_unique(options->value); // Cache the provider's raw pointer *pp_provider = gpu_provider.get(); diff --git a/winml/lib/Api.Core/DmlOrtSessionBuilder.h b/winml/lib/Api.Core/DmlOrtSessionBuilder.h index 2788a41f7b..ae3e04e04f 100644 --- a/winml/lib/Api.Core/DmlOrtSessionBuilder.h +++ b/winml/lib/Api.Core/DmlOrtSessionBuilder.h @@ -15,10 +15,10 @@ class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue); HRESULT STDMETHODCALLTYPE CreateSessionOptions( - ISessionOptions** p_options) override; + OrtSessionOptions** options) override; HRESULT STDMETHODCALLTYPE CreateSession( - ISessionOptions* options, + OrtSessionOptions* options, _winmla::IInferenceSession** p_session, onnxruntime::IExecutionProvider** pp_provider) override; diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index 43f6b44df6..524bfbf8ac 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -3,6 +3,8 @@ #pragma once +#include "core/session/onnxruntime_c_api.h" + namespace Windows::AI::MachineLearning::Adapter { MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ @@ -70,22 +72,15 @@ MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnkn virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0; }; -MIDL_INTERFACE("55a956a7-c20e-440d-b2d2-a77acf35de10") ISessionOptions : IUnknown{ - // this returns a weak ref - virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() = 0; - // end - virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) = 0; -}; - // The IOrtSessionBuilder offers an abstraction over the creation of // InferenceSession, that enables the creation of the session based on a device (CPU/DML). MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown { virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( - ISessionOptions** options) = 0; + OrtSessionOptions ** options) = 0; virtual HRESULT STDMETHODCALLTYPE CreateSession( - ISessionOptions* options, + OrtSessionOptions * options, IInferenceSession** session, onnxruntime::IExecutionProvider** provider) = 0; @@ -191,23 +186,6 @@ private: std::shared_ptr session_; }; -class AbiSafeSessionOptions : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - ISessionOptions> { -private: - onnxruntime::SessionOptions options_; -public: - virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() override { - return options_; - } - virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) override { - onnxruntime::FreeDimensionOverride overrideOption = {}; - overrideOption.dimension_denotation = onnx::DATA_BATCH; - overrideOption.dimension_override = batch_size; - options_.free_dimension_overrides.emplace_back(overrideOption); - } -}; - // header only code to enable smart pointers on abstract ort objects template class OrtObject { diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index fbedef5210..3d0befe8a9 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -109,18 +109,22 @@ void LearningModelSession::Initialize() { device_impl->GetDeviceQueue(), session_builder.put())); - com_ptr<_winmla::ISessionOptions> options; - WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put())); + OrtSessionOptions* options_ptr; + WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(&options_ptr)); + std::unique_ptr options = std::make_unique(options_ptr); // Make onnxruntime apply the batch size override, if any if (session_options_ && session_options_.BatchSizeOverride() != 0) { - options->SetBatchOverride(session_options_.BatchSizeOverride()); + Ort::ThrowOnError(Ort::GetApi().AddFreeDimensionOverride( + *(options.get()), + onnx::DATA_BATCH, + session_options_.BatchSizeOverride())); } com_ptr<_winmla::IInferenceSession> session; WINML_THROW_IF_FAILED(session_builder->CreateSession( - options.get(), session.put(), &cached_execution_provider_)); + *(options.get()), session.put(), &cached_execution_provider_)); // Register the custom operator registry auto model = model_.as(); diff --git a/winml/lib/Common/inc/onnx.h b/winml/lib/Common/inc/onnx.h index cf32ee4824..0b190033fb 100644 --- a/winml/lib/Common/inc/onnx.h +++ b/winml/lib/Common/inc/onnx.h @@ -13,6 +13,9 @@ // Restore ERROR define #define ERROR 0 +// the C++ ort api +#include "core/session/onnxruntime_cxx_api.h" + #include #include "core/framework/customregistry.h"