mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
* Share thread pools between devices * make tests reuse device * Change cpu thread pool options for dml sessions to use 1 thread with no spinning * fix test failure * Update missing type constraints for dft * Add comment and rename inference session parameter * default missing causing inconsistent test behavior Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
331 lines
13 KiB
C++
331 lines
13 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "testPch.h"
|
|
|
|
#include <functional>
|
|
|
|
#include "cppwinrt_onnx.h"
|
|
|
|
#include "AdapterSessionTest.h"
|
|
#include "ILotusValueProviderPrivate.h"
|
|
#include "onnxruntime_c_api.h"
|
|
#include "OnnxruntimeEngine.h"
|
|
#include "OnnxruntimeErrors.h"
|
|
#include "OnnxruntimeModel.h"
|
|
#include "core/common/logging/isink.h"
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/session/abi_session_options_impl.h"
|
|
#include "core/session/ort_env.h"
|
|
#include "core/providers/dml/dml_provider_factory.h"
|
|
|
|
using namespace _winml;
|
|
using namespace winrt::Windows::Foundation::Collections;
|
|
using namespace winrt::Windows::Graphics::Imaging;
|
|
using namespace winrt::Windows::Media;
|
|
using namespace winrt::Windows::Storage;
|
|
using namespace winrt::Windows::Storage::Streams;
|
|
|
|
namespace {
|
|
winrt::com_ptr<_winml::OnnxruntimeEngineFactory> engine_factory;
|
|
const OrtApi* ort_api;
|
|
const WinmlAdapterApi* winml_adapter_api;
|
|
OrtEnv* ort_env;
|
|
|
|
// Per-suite setup: initialises the WinRT apartment, creates the engine factory
// and caches the ORT environment plus the two API tables in the file-scope
// globals used by the tests below.
void AdapterSessionTestSetup() {
  winrt::init_apartment();
#ifdef BUILD_INBOX
  winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
  WINML_EXPECT_HRESULT_SUCCEEDED(
      Microsoft::WRL::MakeAndInitialize<_winml::OnnxruntimeEngineFactory>(engine_factory.put()));
  WINML_EXPECT_HRESULT_SUCCEEDED(engine_factory->GetOrtEnvironment(&ort_env));

  // Cache each API table first, then verify that the factory actually supplied it.
  winml_adapter_api = engine_factory->UseWinmlAdapterApi();
  WINML_EXPECT_NOT_EQUAL(nullptr, winml_adapter_api);

  ort_api = engine_factory->UseOrtApi();
  WINML_EXPECT_NOT_EQUAL(nullptr, ort_api);
}
|
|
|
|
// Per-suite teardown: drops this file's reference to the engine factory.
// The cached ort_api / winml_adapter_api / ort_env pointers are left as they are.
void AdapterSessionTestTeardown() {
  engine_factory = nullptr;
}
|
|
|
|
// Creates a fresh OrtSessionOptions through the ORT C API and wraps it in a
// UniqueOrtSessionOptions whose deleter is ort_api->ReleaseSessionOptions.
// A failing CreateSessionOptions call is reported through THROW_IF_NOT_OK_MSG.
UniqueOrtSessionOptions CreateUniqueOrtSessionOptions() {
  OrtSessionOptions* raw_options = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&raw_options), ort_api);
  return UniqueOrtSessionOptions(raw_options, ort_api->ReleaseSessionOptions);
}
|
|
|
|
void AppendExecutionProvider_CPU() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), true), ort_api);
|
|
}
|
|
|
|
// Creates a D3D12 device on the default adapter (nullptr) at feature level 11_0.
//
// D3D12CreateDevice reports failure through its HRESULT rather than by
// throwing, so the result is checked with WINML_EXPECT_HRESULT_SUCCEEDED;
// a NO_THROW expectation would silently accept a failed creation.
//
// Returns the created device; the com_ptr is empty if creation failed.
winrt::com_ptr<ID3D12Device> CreateD3DDevice() {
  winrt::com_ptr<ID3D12Device> device;
  WINML_EXPECT_HRESULT_SUCCEEDED(
      D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device.put())));
  return device;
}
|
|
|
|
// Creates a direct (D3D12_COMMAND_LIST_TYPE_DIRECT) command queue on `device`.
//
// device: the D3D12 device to create the queue on; it is dereferenced
//         unconditionally, so it must not be null.
//
// The HRESULT from CreateCommandQueue is verified so that a failed creation is
// reported here instead of surfacing later as an empty queue.
//
// Returns the created queue; the com_ptr is empty if creation failed.
winrt::com_ptr<ID3D12CommandQueue> CreateD3DQueue(ID3D12Device* device) {
  winrt::com_ptr<ID3D12CommandQueue> queue;
  D3D12_COMMAND_QUEUE_DESC command_queue_desc = {};
  command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
  WINML_EXPECT_HRESULT_SUCCEEDED(
      device->CreateCommandQueue(&command_queue_desc, IID_PPV_ARGS(queue.put())));
  return queue;
}
|
|
|
|
// Creates a session that has no model attached yet, via the adapter's
// CreateSessionWithoutModel (both optional pointer arguments are passed as
// nullptr), and wraps it so that ort_api->ReleaseSession runs on destruction.
UniqueOrtSession CreateUniqueOrtSession(const UniqueOrtSessionOptions& session_options) {
  OrtSession* raw_session = nullptr;
  THROW_IF_NOT_OK_MSG(
      winml_adapter_api->CreateSessionWithoutModel(ort_env, session_options.get(), nullptr, nullptr, &raw_session),
      ort_api);
  return UniqueOrtSession(raw_session, ort_api->ReleaseSession);
}
|
|
|
|
// Creates a session for the model file at `model_path`.
//
// model_path:      path of the .onnx file handed to ort_api->CreateSession.
// session_options: options used for the session. NOTE: the options object is
//                  modified here — one intra-op thread, ORT_ENABLE_BASIC graph
//                  optimisation, and the CPU execution provider appended.
//
// Every configuration call returns a status; all of them are checked with
// THROW_IF_NOT_OK_MSG so a rejected setting is reported rather than dropped.
//
// Returns the session wrapped with ort_api->ReleaseSession as its deleter.
UniqueOrtSession CreateUniqueOrtSession(const std::wstring& model_path, const UniqueOrtSessionOptions& session_options) {
  THROW_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), 1), ort_api);
  THROW_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), ORT_ENABLE_BASIC), ort_api);
  THROW_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), true), ort_api);

  OrtSession* session = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CreateSession(ort_env, model_path.c_str(), session_options.get(), &session), ort_api);
  return UniqueOrtSession(session, ort_api->ReleaseSession);
}
|
|
|
|
void AppendExecutionProvider_DML() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
|
|
const auto device = CreateD3DDevice();
|
|
const auto queue = CreateD3DQueue(device.get());
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device.get(), queue.get(), true), ort_api);
|
|
}
|
|
|
|
void CreateWithoutModel() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
CreateUniqueOrtSession(session_options);
|
|
}
|
|
|
|
void GetExecutionProvider() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
const auto model_path = FileHelpers::GetModulePath() + L"fns-candy.onnx";
|
|
auto session = CreateUniqueOrtSession(model_path, session_options);
|
|
|
|
OrtExecutionProvider* ort_provider;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
|
}
|
|
|
|
void GetExecutionProvider_DML() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
THROW_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), ort_api);
|
|
const auto device = CreateD3DDevice();
|
|
const auto queue = CreateD3DQueue(device.get());
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device.get(), queue.get(), true), ort_api);
|
|
|
|
const auto model_path = FileHelpers::GetModulePath() + L"fns-candy.onnx";
|
|
auto session = CreateUniqueOrtSession(model_path, session_options);
|
|
|
|
OrtExecutionProvider* ort_provider;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
|
// Test if DML EP method can be called
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), ort_api);
|
|
}
|
|
|
|
void RegisterGraphTransformers() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
winml_adapter_api->SessionRegisterGraphTransformers(session.get());
|
|
}
|
|
|
|
void RegisterGraphTransformers_DML() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
winml_adapter_api->SessionRegisterGraphTransformers(session.get());
|
|
}
|
|
|
|
void RegisterCustomRegistry() {
|
|
IMLOperatorRegistry* registry;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->CreateCustomRegistry(®istry), ort_api);
|
|
if (registry) {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterCustomRegistry(session.get(), registry), ort_api);
|
|
}
|
|
}
|
|
|
|
void RegisterCustomRegistry_DML() {
|
|
IMLOperatorRegistry* registry;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->CreateCustomRegistry(®istry), ort_api);
|
|
WINML_EXPECT_NOT_EQUAL(nullptr, registry);
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterCustomRegistry(session.get(), registry), ort_api);
|
|
}
|
|
|
|
// Loads the model at `model_path` through the engine factory, detaches the
// underlying OrtModel from it and hands that model to `session` via
// SessionLoadAndPurloinModel.
//
// session:    session that receives the model.
// model_path: path of the .onnx file to load.
//
// The second CreateModel argument is the length of the path string. The
// previous `sizeof(model_path)` passed the size of the std::string object
// itself, which is unrelated to the number of characters in the path.
// NOTE(review): assumes CreateModel expects the character count without the
// terminating NUL — confirm against OnnxruntimeEngineFactory::CreateModel.
void LoadAndPurloinModel(const UniqueOrtSession& session, const std::string& model_path) {
  winrt::com_ptr<_winml::IModel> model;
  WINML_THROW_IF_FAILED(engine_factory->CreateModel(model_path.c_str(), model_path.size(), model.put()));

  winrt::com_ptr<_winml::IOnnxruntimeModel> onnxruntime_model;
  WINML_EXPECT_NO_THROW(onnxruntime_model = model.as<_winml::IOnnxruntimeModel>());

  OrtModel* ort_model = nullptr;
  WINML_EXPECT_HRESULT_SUCCEEDED(onnxruntime_model->DetachOrtModel(&ort_model));
  THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionLoadAndPurloinModel(session.get(), ort_model), ort_api);
}
|
|
|
|
void LoadAndPurloinModel() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
LoadAndPurloinModel(session, "fns-candy.onnx");
|
|
}
|
|
|
|
void Initialize() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
|
|
winrt::com_ptr<_winml::IModel> model;
|
|
const auto model_path = "fns-candy.onnx";
|
|
WINML_THROW_IF_FAILED(engine_factory->CreateModel(model_path, sizeof(model_path), model.put()));
|
|
|
|
winrt::com_ptr<_winml::IOnnxruntimeModel> onnxruntime_model;
|
|
WINML_EXPECT_NO_THROW(onnxruntime_model = model.as<_winml::IOnnxruntimeModel>());
|
|
OrtModel* ort_model = nullptr;
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(onnxruntime_model->DetachOrtModel(&ort_model));
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionLoadAndPurloinModel(session.get(), ort_model), ort_api);
|
|
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session.get()), ort_api);
|
|
}
|
|
|
|
static bool logging_called = false, profile_called = false;
|
|
void Profiling() {
|
|
const auto logging_callback = [](void*, OrtLoggingLevel, const char*, const char*, const char*, const char*) {
|
|
logging_called = true;
|
|
};
|
|
const auto profile_callback = [](const OrtProfilerEventRecord*) {
|
|
profile_called = true;
|
|
};
|
|
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(
|
|
ort_env,
|
|
logging_callback,
|
|
profile_callback,
|
|
nullptr,
|
|
OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
|
|
"Default",
|
|
&ort_env),
|
|
ort_api);
|
|
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
|
|
winml_adapter_api->SessionStartProfiling(ort_env, session.get());
|
|
LoadAndPurloinModel(session, "fns-candy.onnx");
|
|
winml_adapter_api->SessionEndProfiling(session.get());
|
|
WINML_EXPECT_TRUE(logging_called);
|
|
WINML_EXPECT_TRUE(profile_called);
|
|
}
|
|
|
|
void CopyInputAcrossDevices() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
auto session = CreateUniqueOrtSession(L"fns-candy.onnx", session_options);
|
|
|
|
constexpr std::array<int64_t, 4> dimensions{1, 3, 720, 720};
|
|
constexpr size_t input_tensor_size = [&dimensions]() {
|
|
size_t size = 1;
|
|
for (auto dim : dimensions)
|
|
size *= static_cast<size_t>(dim);
|
|
return size;
|
|
} ();
|
|
|
|
OrtMemoryInfo* memory_info;
|
|
THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info), ort_api);
|
|
std::vector<float> input_tensor_values(input_tensor_size);
|
|
OrtValue* input_tensor;
|
|
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(memory_info, input_tensor_values.data(), input_tensor_size * sizeof(float), dimensions.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor), ort_api);
|
|
|
|
int is_tensor;
|
|
THROW_IF_NOT_OK_MSG(ort_api->IsTensor(input_tensor, &is_tensor), ort_api);
|
|
WINML_EXPECT_TRUE(is_tensor);
|
|
|
|
OrtValue* dest_ort_value;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionCopyOneInputAcrossDevices(session.get(), "inputImage", input_tensor, &dest_ort_value), ort_api);
|
|
|
|
ort_api->ReleaseValue(input_tensor);
|
|
ort_api->ReleaseValue(dest_ort_value);
|
|
ort_api->ReleaseMemoryInfo(memory_info);
|
|
}
|
|
|
|
void CopyInputAcrossDevices_DML() {
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
THROW_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), ort_api);
|
|
const auto device = CreateD3DDevice();
|
|
const auto queue = CreateD3DQueue(device.get());
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device.get(), queue.get(), true), ort_api);
|
|
auto session = CreateUniqueOrtSession(session_options);
|
|
|
|
LoadAndPurloinModel(session, "fns-candy.onnx");
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session.get()), ort_api);
|
|
constexpr std::array<int64_t, 4> dimensions{1, 3, 720, 720};
|
|
constexpr size_t input_tensor_size = [&dimensions]() {
|
|
size_t size = 1;
|
|
for (auto dim : dimensions)
|
|
size *= static_cast<size_t>(dim);
|
|
return size;
|
|
} ();
|
|
|
|
OrtMemoryInfo* memory_info;
|
|
THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info), ort_api);
|
|
std::vector<float> input_tensor_values(input_tensor_size);
|
|
OrtValue* input_tensor;
|
|
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(memory_info, input_tensor_values.data(), input_tensor_size * sizeof(float), dimensions.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor), ort_api);
|
|
|
|
int is_tensor;
|
|
THROW_IF_NOT_OK_MSG(ort_api->IsTensor(input_tensor, &is_tensor), ort_api);
|
|
WINML_EXPECT_TRUE(is_tensor);
|
|
|
|
OrtValue* dest_ort_value = nullptr;
|
|
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionCopyOneInputAcrossDevices(session.get(), "inputImage", input_tensor, &dest_ort_value), ort_api);
|
|
|
|
ort_api->ReleaseValue(input_tensor);
|
|
ort_api->ReleaseMemoryInfo(memory_info);
|
|
}
|
|
|
|
void GetNumberOfIntraOpThreads(){
|
|
const auto session_options = CreateUniqueOrtSessionOptions();
|
|
uint32_t desired_num_threads = std::thread::hardware_concurrency() / 2;
|
|
ort_api->SetIntraOpNumThreads(session_options.get(), desired_num_threads);
|
|
const auto session = CreateUniqueOrtSession(session_options);
|
|
uint32_t num_threads;
|
|
winml_adapter_api->SessionGetNumberOfIntraOpThreads(session.get(), &num_threads);
|
|
WINML_EXPECT_EQUAL(num_threads, desired_num_threads);
|
|
}
|
|
}
|
|
|
|
// Returns the table of test entry points for this suite.
//
// The table is a function-local static filled positionally, so the initialiser
// order below must match the member order of AdapterSessionTestAPI. On every
// call, entries are replaced with SkipTest when the corresponding skip
// predicate is true.
const AdapterSessionTestAPI& getapi() {
  static AdapterSessionTestAPI api = {
      AdapterSessionTestSetup,
      AdapterSessionTestTeardown,
      AppendExecutionProvider_CPU,
      AppendExecutionProvider_DML,
      CreateWithoutModel,
      GetExecutionProvider,
      GetExecutionProvider_DML,
      Initialize,
      RegisterGraphTransformers,
      RegisterGraphTransformers_DML,
      RegisterCustomRegistry,
      RegisterCustomRegistry_DML,
      LoadAndPurloinModel,
      Profiling,
      CopyInputAcrossDevices,
      CopyInputAcrossDevices_DML,
      GetNumberOfIntraOpThreads};

  // GPU-dependent tests are disabled as a group.
  if (SkipGpuTests()) {
    api.AppendExecutionProvider_DML = SkipTest;
    api.GetExecutionProvider_DML = SkipTest;
    api.RegisterGraphTransformers_DML = SkipTest;
    api.RegisterCustomRegistry_DML = SkipTest;
    api.CopyInputAcrossDevices_DML = SkipTest;
  }

  // The thread-count test is disabled when OpenMP-impacted tests are skipped.
  if (SkipTestsImpactedByOpenMP()) {
    api.GetNumberOfIntraOpThreads = SkipTest;
  }

  return api;
}
|