mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
* Register ILearningModelSessionOptionsNate interface * Threading options exposed * Add interrogator for Session options * Add test * Polish test * PR comments * Set intra op threads * Add adapter api to grab intra op threads * Add adapter test for getting intraop num threads * Make ILearningModelSessionNative and update winml api test * Make it required when building engine to set the intraop num threads * Make test more pretty * Change naming of idl function * Revert "Change naming of idl function" This reverts commit c06916aa5bf94e3bf233ed281e508b935fc8638d. * PR comment on naming * Skip the test because it's influenced if it's built with openmp Co-authored-by: Ryan Lai <ryalai96@gamil.com>
87 lines
No EOL
3.1 KiB
C++
87 lines
No EOL
3.1 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "pch.h"
|
|
|
|
#include "OnnxruntimeEngine.h"
|
|
#include "OnnxruntimeEngineBuilder.h"
|
|
#include "OnnxruntimeCpuSessionBuilder.h"
|
|
|
|
#ifdef USE_DML
|
|
#include "OnnxruntimeDmlSessionBuilder.h"
|
|
#endif
|
|
|
|
#include "OnnxruntimeErrors.h"
|
|
using namespace _winml;
|
|
|
|
// WRL initialization hook: remembers the factory that every engine built by
// this builder will be created from. The factory is later used both to obtain
// the ORT API table and as the parent passed to the session builders/engine.
// Always succeeds.
HRESULT OnnxruntimeEngineBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) {
  this->engine_factory_ = engine_factory;
  return S_OK;
}
|
|
|
|
// Builds an IEngine from the settings accumulated on this builder.
//
// If no D3D12 device has been supplied (device_ == nullptr) a CPU session
// builder is used; otherwise a DirectML session builder is created from the
// stored device, command queue and metacommand flag. The session options are
// then populated with the optional "DATA_BATCH" free-dimension override and
// the intra-op thread count, an ORT session is created, and the session is
// moved into a new OnnxruntimeEngine.
//
// out: receives the new engine (via ComPtr::CopyTo).
// Returns S_OK on success, otherwise the first failing HRESULT (ORT statuses
// are converted by RETURN_HR_IF_NOT_OK_MSG). Returns E_NOTIMPL when a D3D12
// device was supplied but this binary was built without USE_DML.
STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(_winml::IEngine** out) {
  auto ort_api = engine_factory_->UseOrtApi();

  Microsoft::WRL::ComPtr<IOrtSessionBuilder> onnxruntime_session_builder;

  if (device_ == nullptr) {
    RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeCpuSessionBuilder>(&onnxruntime_session_builder, engine_factory_.Get()));
  } else {
#ifdef USE_DML
    RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeDmlSessionBuilder>(&onnxruntime_session_builder, engine_factory_.Get(), device_.Get(), queue_.Get(), metacommands_enabled_));
#else
    // No builder in this build can honor a D3D12 device. Fail explicitly
    // rather than dereferencing a null session builder below.
    return E_NOTIMPL;
#endif
  }

  OrtSessionOptions* ort_options = nullptr;
  RETURN_IF_FAILED(onnxruntime_session_builder->CreateSessionOptions(&ort_options));
  // Hand the raw options to an owner paired with ReleaseSessionOptions right
  // away, so they are released on the early-return paths below.
  auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);

  if (batch_size_override_.has_value()) {
    constexpr const char* DATA_BATCH = "DATA_BATCH";
    RETURN_HR_IF_NOT_OK_MSG(ort_api->AddFreeDimensionOverride(session_options.get(), DATA_BATCH, batch_size_override_.value()),
                            ort_api);
  }

  RETURN_HR_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), intra_op_num_threads_override_), ort_api);

  OrtSession* ort_session = nullptr;
  // Propagate failure: otherwise the engine below would wrap a null session.
  RETURN_IF_FAILED(onnxruntime_session_builder->CreateSession(session_options.get(), &ort_session));
  auto session = UniqueOrtSession(ort_session, ort_api->ReleaseSession);

  Microsoft::WRL::ComPtr<OnnxruntimeEngine> onnxruntime_engine;
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeEngine>(&onnxruntime_engine,
                                                                        engine_factory_.Get(), std::move(session), onnxruntime_session_builder.Get()));
  RETURN_IF_FAILED(onnxruntime_engine.CopyTo(out));
  return S_OK;
}
|
|
|
|
// Reports the D3D12 device previously stored by SetD3D12Resources (null if
// none was set). No AddRef is performed in this function: the pointer is
// obtained via device_.Get(), so the caller receives a borrowed reference.
// Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::GetD3D12Device(ID3D12Device** device) {
  ID3D12Device* const current_device = device_.Get();
  *device = current_device;
  return S_OK;
}
|
|
|
|
// Stores the D3D12 device and command queue for later engine creation. A
// non-null device makes CreateEngine take the non-CPU (DirectML) path.
// Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::SetD3D12Resources(ID3D12Device* device, ID3D12CommandQueue* queue) {
  this->device_ = device;
  this->queue_ = queue;
  return S_OK;
}
|
|
|
|
// Records whether metacommands should be enabled; any non-zero value counts
// as true. The flag is forwarded to the DML session builder in CreateEngine.
// Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::SetMetacommandsEnabled(int enabled) {
  const bool is_enabled = (enabled != 0);
  metacommands_enabled_ = is_enabled;
  return S_OK;
}
|
|
|
|
// Reports the command queue previously stored by SetD3D12Resources (null if
// none was set). No AddRef is performed in this function: the pointer is
// obtained via queue_.Get(), so the caller receives a borrowed reference.
// Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::GetID3D12CommandQueue(ID3D12CommandQueue** queue) {
  ID3D12CommandQueue* const current_queue = queue_.Get();
  *queue = current_queue;
  return S_OK;
}
|
|
|
|
// Requests a fixed value for the "DATA_BATCH" free dimension. Once set,
// CreateEngine applies it through AddFreeDimensionOverride on the session
// options. Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::SetBatchSizeOverride(uint32_t batch_size_override) {
  this->batch_size_override_ = batch_size_override;
  return S_OK;
}
|
|
|
|
// Stores the intra-op thread count that CreateEngine passes to
// OrtApi::SetIntraOpNumThreads for every session it builds. Always succeeds.
STDMETHODIMP OnnxruntimeEngineBuilder::SetIntraOpNumThreadsOverride(uint32_t intra_op_num_threads) {
  this->intra_op_num_threads_override_ = intra_op_num_threads;
  return S_OK;
}