mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
* Initial containerized/Store build * Remove unsupported APIs * Remove usage of STL ifstream * Revert CMake changes * Link to app runtime * WCOS/Store cmake * Update CMakeSettings.json * Fix winapi family support * Fix downlevel * Downlevel build * Remove downlevel workaround * pep8 compliance * Workaround WinRT headers bug https://github.com/microsoft/cppwinrt/issues/584 in older SDK * Always cross compile to avoid warnings as errors * PR feedback * More CI fixes * PR feedback * aiinfra build fix * Win8 store
217 lines
8.8 KiB
C++
217 lines
8.8 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "pch.h"
|
|
#include "OnnxruntimeEnvironment.h"
|
|
#include "OnnxruntimeErrors.h"
|
|
#include "core/platform/windows/TraceLoggingConfig.h"
|
|
#include <evntrace.h>
|
|
|
|
#include <windows.h>
|
|
|
|
using namespace _winml;
|
|
|
|
static bool debug_output_ = false;
|
|
|
|
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
|
|
|
|
static bool IsCurrentModuleInSystem32() {
|
|
std::string current_module_path;
|
|
current_module_path.reserve(MAX_PATH);
|
|
auto size_module_path = GetModuleFileNameA((HINSTANCE)&__ImageBase, current_module_path.data(), MAX_PATH);
|
|
FAIL_FAST_IF(size_module_path == 0);
|
|
|
|
std::string system32_path;
|
|
system32_path.reserve(MAX_PATH);
|
|
auto size_system32_path = GetSystemDirectoryA(system32_path.data(), MAX_PATH);
|
|
FAIL_FAST_IF(size_system32_path == 0);
|
|
|
|
return _strnicmp(system32_path.c_str(), current_module_path.c_str(), size_system32_path) == 0;
|
|
}
|
|
|
|
// Loads onnxruntime.dll and stores its handle in `module`.
//  - Packaged (WINAPI_FAMILY_PC_APP) builds use LoadPackagedLibrary.
//  - Other builds use LoadLibraryExA.
//  - In BUILD_INBOX builds, the search is restricted to System32 when this module itself was
//    loaded from System32.
// Returns S_OK on success. On failure, returns the HRESULT for GetLastError() and leaves
// `module` untouched. No FreeLibrary call appears in this file, so each successful load
// stays referenced.
static HRESULT GetOnnxruntimeLibrary(HMODULE& module) {
#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP
  HMODULE loaded_library = LoadPackagedLibrary(L"onnxruntime.dll", 0);
#else
  DWORD load_flags = 0;
#ifdef BUILD_INBOX
  if (IsCurrentModuleInSystem32()) {
    load_flags |= LOAD_LIBRARY_SEARCH_SYSTEM32;
  }
#endif
  HMODULE loaded_library = LoadLibraryExA("onnxruntime.dll", nullptr, load_flags);
#endif

  if (loaded_library != nullptr) {
    module = loaded_library;
    return S_OK;
  }
  return HRESULT_FROM_WIN32(GetLastError());
}
|
|
|
|
// Resolves OrtGetApiBase from onnxruntime.dll and returns the ORT C API table for API
// version 2. Fails fast if the DLL cannot be loaded or does not export OrtGetApiBase.
// NOTE(review): the result of GetApi is returned unchecked; confirm callers tolerate a
// null table if the loaded DLL does not support version 2.
const OrtApi* _winml::GetVersionedOrtApi() {
  // ORT C API version requested from the loaded DLL.
  static const uint32_t ort_version = 2;

  HMODULE ort_module;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  auto raw_entry_point = GetProcAddress(ort_module, "OrtGetApiBase");
  if (!raw_entry_point) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  auto get_api_base = reinterpret_cast<decltype(OrtGetApiBase)*>(raw_entry_point);
  const auto api_base = get_api_base();
  return api_base->GetApi(ort_version);
}
|
|
|
|
// Resolves OrtGetWinMLAdapter from onnxruntime.dll and returns the adapter table it produces
// for `ort_api`. Fails fast if the DLL cannot be loaded or does not export OrtGetWinMLAdapter.
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) {
  HMODULE ort_module;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  auto raw_entry_point = GetProcAddress(ort_module, "OrtGetWinMLAdapter");
  if (!raw_entry_point) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  auto get_winml_adapter = reinterpret_cast<decltype(OrtGetWinMLAdapter)*>(raw_entry_point);
  return get_winml_adapter(ort_api);
}
|
|
|
|
// Convenience overload: returns the WinML adapter table bound to the API table from
// GetVersionedOrtApi(). The one-argument helper defined above in this file does the lookup.
const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() {
  const OrtApi* ort_api = GetVersionedOrtApi();
  return GetVersionedWinmlAdapterApi(ort_api);
}
|
|
|
|
// ORT logging callback that the OnnxruntimeEnvironment constructor registers through
// EnvConfigureCustomLoggerAndProfiler. Every record is written as a "WinMLLogSink" event on
// winml_trace_logging_provider, at the ETW level matching its severity:
//  - FATAL and ERROR also carry a telemetry privacy tag and MICROSOFT_KEYWORD_MEASURES.
//  - WARNING and INFO are plain events.
//  - VERBOSE and any unrecognized severity use WINEVENT_LEVEL_VERBOSE.
//
// Parameters:
//   param         - context pointer supplied at registration; unused.
//   severity      - ORT severity of the record; selects the ETW level.
//   category      - ORT log category, logged as a string field.
//   logger_id     - ORT logger id; unused.
//   code_location - source-location string, logged as a string field.
//   message       - log text, logged as a string field.
//
// When debug_output_ is set, the message is also sent to the debugger. The function is
// noexcept, so this path now avoids building a temporary std::string: its allocation could
// throw and terminate the process. A null message is skipped rather than passed to the
// std::string constructor, which would be undefined behavior.
static void __stdcall WinmlOrtLoggingCallback(void* param, OrtLoggingLevel severity, const char* category,
                                               const char* logger_id, const char* code_location,
                                               const char* message) noexcept {
  UNREFERENCED_PARAMETER(param);
  UNREFERENCED_PARAMETER(logger_id);
  // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry.
  switch (severity) {
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL:  // Telemetry
      TraceLoggingWrite(
          winml_trace_logging_provider,
          "WinMLLogSink",
          TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
          TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
          TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL),
          TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
          TraceLoggingString(category),
          TraceLoggingUInt32((UINT32)severity),
          TraceLoggingString(message),
          TraceLoggingString(code_location),
          TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR:  // Telemetry
      TraceLoggingWrite(
          winml_trace_logging_provider,
          "WinMLLogSink",
          TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
          TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
          TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
          TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
          TraceLoggingString(category),
          TraceLoggingUInt32((UINT32)severity),
          TraceLoggingString(message),
          TraceLoggingString(code_location),
          TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING:
      TraceLoggingWrite(
          winml_trace_logging_provider,
          "WinMLLogSink",
          TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
          TraceLoggingLevel(WINEVENT_LEVEL_WARNING),
          TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
          TraceLoggingString(category),
          TraceLoggingUInt32((UINT32)severity),
          TraceLoggingString(message),
          TraceLoggingString(code_location));
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO:
      TraceLoggingWrite(
          winml_trace_logging_provider,
          "WinMLLogSink",
          TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
          TraceLoggingLevel(WINEVENT_LEVEL_INFO),
          TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
          TraceLoggingString(category),
          TraceLoggingUInt32((UINT32)severity),
          TraceLoggingString(message),
          TraceLoggingString(code_location));
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE:
      __fallthrough;  // Default is Verbose too.
    default:
      TraceLoggingWrite(
          winml_trace_logging_provider,
          "WinMLLogSink",
          TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
          TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
          TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
          TraceLoggingString(category),
          TraceLoggingUInt32((UINT32)severity),
          TraceLoggingString(message),
          TraceLoggingString(code_location));
      break;
  }

  if (debug_output_ && message != nullptr) {
    // Two calls instead of one concatenated std::string: no heap allocation in a noexcept callback.
    OutputDebugStringA(message);
    OutputDebugStringA("\r\n");
  }
}
|
|
|
|
// ORT profiling callback that the OnnxruntimeEnvironment constructor registers through
// EnvConfigureCustomLoggerAndProfiler. Each record becomes an "OnnxRuntimeProfiling" event
// under WINML_PROVIDER_KEYWORD_LOTUS_PROFILING at verbose level, with these fields:
//  - always: category, duration, timestamp, event name, process id and thread id;
//  - NODE_EVENT records only: operator name and execution provider.
static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept {
  const bool is_node_event = profiler_record->category_ == OrtProfilerEventCategory::NODE_EVENT;

  if (!is_node_event) {
    // Non-node categories have no operator / execution-provider fields to report.
    TraceLoggingWrite(
        winml_trace_logging_provider,
        "OnnxRuntimeProfiling",
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
        TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(profiler_record->category_name_, "Category"),
        TraceLoggingInt64(profiler_record->duration_, "Duration (us)"),
        TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"),
        TraceLoggingString(profiler_record->event_name_, "Event Name"),
        TraceLoggingInt32(profiler_record->process_id_, "Process ID"),
        TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"));
    return;
  }

  TraceLoggingWrite(
      winml_trace_logging_provider,
      "OnnxRuntimeProfiling",
      TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
      TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
      TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
      TraceLoggingString(profiler_record->category_name_, "Category"),
      TraceLoggingInt64(profiler_record->duration_, "Duration (us)"),
      TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"),
      TraceLoggingString(profiler_record->event_name_, "Event Name"),
      TraceLoggingInt32(profiler_record->process_id_, "Process ID"),
      TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"),
      TraceLoggingString(profiler_record->op_name_, "Operator Name"),
      TraceLoggingString(profiler_record->execution_provider_, "Execution Provider"));
}
|
|
|
|
// Builds the WinML ORT environment in five steps; each ORT call is checked with
// THROW_IF_NOT_OK_MSG:
//  1. Create an OrtEnv (verbose level, log id "Default") and store it in ort_env_ with
//     ort_api->ReleaseEnv as the deleter.
//  2. Mark the env with the WinML language projection.
//  3. Install WinmlOrtLoggingCallback and WinmlOrtProfileEventCallback via the WinML adapter.
//  4. Apply the adapter's schema override (OverrideSchema).
//
// Ownership now moves into ort_env_ immediately after CreateEnv succeeds. Previously it moved
// only after SetLanguageProjection, so a failure there threw while the OrtEnv was held by a
// raw local pointer, and the env leaked. ort_env_ is a fully constructed member, so if the
// constructor body throws it is destroyed and its deleter releases the env.
OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) {
  OrtEnv* ort_env = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
                      ort_api);
  ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);

  THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env_.get(), OrtLanguageProjection::ORT_PROJECTION_WINML),
                      ort_api);

  // Configure the environment with the winml logger.
  // The adapter's output env pointer is written into the local ort_env and not used afterwards.
  auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api);
  THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(
                          ort_env_.get(), &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr,
                          OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
                      ort_api);

  THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api);
}
|
|
|
|
// Writes the environment's OrtEnv pointer to *ort_env and returns S_OK.
// The pointer is borrowed: ownership stays with ort_env_.
HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) {
  OrtEnv* const borrowed_env = ort_env_.get();
  *ort_env = borrowed_env;
  return S_OK;
}
|
|
|
|
// Turns on or off the mirroring of ORT log messages to the debugger, which
// WinmlOrtLoggingCallback performs when the file-scope debug_output_ flag is set.
// Always returns S_OK.
HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) {
  const bool mirror_to_debugger = is_enabled;
  debug_output_ = mirror_to_debugger;
  return S_OK;
}
|