Moved SessionOptions over to the ABI

This commit is contained in:
Paul McDaniel 2019-11-19 18:15:47 -08:00
parent 94fc7bccff
commit 44a4fc0cc2
7 changed files with 43 additions and 50 deletions

View file

@ -21,6 +21,7 @@
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/session/abi_session_options_impl.h"
using namespace Windows::AI::MachineLearning;
@ -32,26 +33,29 @@ CpuOrtSessionBuilder::CpuOrtSessionBuilder() {
HRESULT
CpuOrtSessionBuilder::CreateSessionOptions(
ISessionOptions** p_options) {
RETURN_HR_IF_NULL(E_POINTER, p_options);
OrtSessionOptions** options) {
RETURN_HR_IF_NULL(E_POINTER, options);
auto options = wil::MakeOrThrow<AbiSafeSessionOptions>();
options.CopyTo(__uuidof(ISessionOptions), (void**)p_options);
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
std::unique_ptr<Ort::SessionOptions> session_options = std::make_unique<Ort::SessionOptions>(*options);
(*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3;
// set the graph optimization level to all (used to be called level 3)
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
// Onnxruntime will use half the number of concurrent threads supported on the system
// by default. This causes MLAS to not exercise every logical core.
// We force the thread pool size to be maxed out to ensure that WinML always
// runs the fastest.
(*p_options)->get().intra_op_num_threads = std::thread::hardware_concurrency();
session_options->SetIntraOpNumThreads(std::thread::hardware_concurrency());
// all done with the smart ptr
session_options.release();
return S_OK;
}
HRESULT
CpuOrtSessionBuilder::CreateSession(
ISessionOptions* options,
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) {
RETURN_HR_IF_NULL(E_POINTER, p_session);
@ -59,7 +63,7 @@ CpuOrtSessionBuilder::CreateSession(
RETURN_HR_IF(E_POINTER, *pp_provider != nullptr);
// Create the inference session
auto session = std::make_unique<onnxruntime::InferenceSession>(options->get());
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);
// Create the cpu execution provider
onnxruntime::CPUExecutionProviderInfo xpInfo;

View file

@ -15,10 +15,10 @@ class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
CpuOrtSessionBuilder();
HRESULT STDMETHODCALLTYPE CreateSessionOptions(
ISessionOptions** p_options) override;
OrtSessionOptions** options) override;
HRESULT STDMETHODCALLTYPE CreateSession(
ISessionOptions* options,
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) override;

View file

@ -25,6 +25,7 @@
#include "core/framework/op_node_proto_helper.h"
#include "core/framework/customRegistry.h"
#include "core/framework/data_transfer.h"
#include "core/session/abi_session_options_impl.h"
using namespace Windows::AI::MachineLearning;
@ -39,17 +40,20 @@ DmlOrtSessionBuilder::DmlOrtSessionBuilder(
HRESULT
DmlOrtSessionBuilder::CreateSessionOptions(
ISessionOptions** p_options) {
RETURN_HR_IF_NULL(E_POINTER, p_options);
OrtSessionOptions** options) {
RETURN_HR_IF_NULL(E_POINTER, options);
auto options = wil::MakeOrThrow<AbiSafeSessionOptions>();
options.CopyTo(__uuidof(ISessionOptions), (void**)p_options);
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
std::unique_ptr<Ort::SessionOptions> session_options = std::make_unique<Ort::SessionOptions>(*options);
(*p_options)->get().graph_optimization_level = onnxruntime::TransformerLevel::Level3;
// set the graph optimization level to all (used to be called level 3)
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
// Disable the mem pattern session option for DML. It will cause problems with how memory is allocated.
(*p_options)->get().enable_mem_pattern = false;
session_options->DisableMemPattern();
// all done with the smart ptr
session_options.release();
return S_OK;
}
@ -101,7 +105,7 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
}
HRESULT DmlOrtSessionBuilder::CreateSession(
ISessionOptions* options,
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) {
RETURN_HR_IF_NULL(E_POINTER, p_session);
@ -114,7 +118,7 @@ HRESULT DmlOrtSessionBuilder::CreateSession(
Microsoft::WRL::ComPtr<IDMLDevice> dmlDevice = CreateDmlDevice(p_d3d_device);
std::unique_ptr<onnxruntime::IExecutionProvider> gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue);
auto session = std::make_unique<onnxruntime::InferenceSession>(options->get());
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);
// Cache the provider's raw pointer
*pp_provider = gpu_provider.get();

View file

@ -15,10 +15,10 @@ class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue);
HRESULT STDMETHODCALLTYPE CreateSessionOptions(
ISessionOptions** p_options) override;
OrtSessionOptions** options) override;
HRESULT STDMETHODCALLTYPE CreateSession(
ISessionOptions* options,
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) override;

View file

@ -3,6 +3,8 @@
#pragma once
#include "core/session/onnxruntime_c_api.h"
namespace Windows::AI::MachineLearning::Adapter {
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{
@ -70,22 +72,15 @@ MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnkn
virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0;
};
MIDL_INTERFACE("55a956a7-c20e-440d-b2d2-a77acf35de10") ISessionOptions : IUnknown{
// this returns a weak ref
virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() = 0;
// end
virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) = 0;
};
// The IOrtSessionBuilder offers an abstraction over the creation of
// InferenceSession, that enables the creation of the session based on a device (CPU/DML).
MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown {
virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions(
ISessionOptions** options) = 0;
OrtSessionOptions ** options) = 0;
virtual HRESULT STDMETHODCALLTYPE CreateSession(
ISessionOptions* options,
OrtSessionOptions * options,
IInferenceSession** session,
onnxruntime::IExecutionProvider** provider) = 0;
@ -191,23 +186,6 @@ private:
std::shared_ptr<onnxruntime::InferenceSession> session_;
};
class AbiSafeSessionOptions : public Microsoft::WRL::RuntimeClass <
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
ISessionOptions> {
private:
onnxruntime::SessionOptions options_;
public:
virtual onnxruntime::SessionOptions& STDMETHODCALLTYPE get() override {
return options_;
}
virtual void STDMETHODCALLTYPE SetBatchOverride(uint32_t batch_size) override {
onnxruntime::FreeDimensionOverride overrideOption = {};
overrideOption.dimension_denotation = onnx::DATA_BATCH;
overrideOption.dimension_override = batch_size;
options_.free_dimension_overrides.emplace_back(overrideOption);
}
};
// header only code to enable smart pointers on abstract ort objects
template <typename T>
class OrtObject {

View file

@ -109,18 +109,22 @@ void LearningModelSession::Initialize() {
device_impl->GetDeviceQueue(),
session_builder.put()));
com_ptr<_winmla::ISessionOptions> options;
WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put()));
OrtSessionOptions* options_ptr;
WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(&options_ptr));
std::unique_ptr<Ort::SessionOptions> options = std::make_unique<Ort::SessionOptions>(options_ptr);
// Make onnxruntime apply the batch size override, if any
if (session_options_ && session_options_.BatchSizeOverride() != 0)
{
options->SetBatchOverride(session_options_.BatchSizeOverride());
Ort::ThrowOnError(Ort::GetApi().AddFreeDimensionOverride(
*(options.get()),
onnx::DATA_BATCH,
session_options_.BatchSizeOverride()));
}
com_ptr<_winmla::IInferenceSession> session;
WINML_THROW_IF_FAILED(session_builder->CreateSession(
options.get(), session.put(), &cached_execution_provider_));
*(options.get()), session.put(), &cached_execution_provider_));
// Register the custom operator registry
auto model = model_.as<winmlp::LearningModel>();

View file

@ -13,6 +13,9 @@
// Restore ERROR define
#define ERROR 0
// the C++ ort api
#include "core/session/onnxruntime_cxx_api.h"
#include <DirectML.h>
#include "core/framework/customregistry.h"