mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
winml/ was previously excluded from lintrunner config. This change includes the directory and adds the clang-format config file specific to winml/ that fits existing style. --------- Signed-off-by: Justin Chu <justinchu@microsoft.com>
262 lines
9.4 KiB
C++
262 lines
9.4 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "lib/Api.Ort/pch.h"
|
|
#include "OnnxruntimeEnvironment.h"
|
|
#include "OnnxruntimeErrors.h"
|
|
#include "core/platform/windows/TraceLoggingConfig.h"
|
|
#include <evntrace.h>
|
|
#include <windows.h>
|
|
#include <winrt/Windows.ApplicationModel.h>
|
|
#include <winrt/Windows.ApplicationModel.Core.h>
|
|
|
|
using namespace _winml;
|
|
|
|
static bool debug_output_ = false;
|
|
|
|
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
|
|
|
|
static std::wstring CurrentModulePath() {
|
|
WCHAR path[MAX_PATH];
|
|
FAIL_FAST_IF(0 == GetModuleFileNameW((HINSTANCE)&__ImageBase, path, _countof(path)));
|
|
|
|
WCHAR absolute_path[MAX_PATH];
|
|
WCHAR* name;
|
|
FAIL_FAST_IF(0 == GetFullPathNameW(path, _countof(path), absolute_path, &name));
|
|
|
|
auto idx = std::distance(absolute_path, name);
|
|
auto out_path = std::wstring(absolute_path);
|
|
out_path.resize(idx);
|
|
|
|
return out_path;
|
|
}
|
|
|
|
// Loads onnxruntime.dll and hands back its module handle through `module`.
// Store (WINAPI_FAMILY_PC_APP) builds resolve it from the app package with LoadPackagedLibrary;
// every other build loads it from the directory that contains this module.
// Returns S_OK on success, otherwise the HRESULT form of GetLastError(); `module` is only
// written on success.
// NOTE(review): each successful call takes another reference on the DLL that is never released.
static HRESULT GetOnnxruntimeLibrary(HMODULE& module) {
#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP
  // Store + Redist (note that this is never built into the inbox dll)
  HMODULE const loaded = LoadPackagedLibrary(L"onnxruntime.dll", 0);
#else
  std::wstring const dll_path = CurrentModulePath() + L"\\onnxruntime.dll";
  HMODULE const loaded = LoadLibraryExW(dll_path.c_str(), nullptr, 0);
#endif

  if (loaded != nullptr) {
    module = loaded;
    return S_OK;
  }
  return HRESULT_FROM_WIN32(GetLastError());
}
|
|
|
|
// Loads onnxruntime.dll, resolves its exported OrtGetApiBase entry point and returns the
// OrtApi table for the ORT_API_VERSION this code was compiled against.
// Fails fast if the DLL cannot be loaded or does not export OrtGetApiBase.
// NOTE(review): GetApi's result is returned unchecked; confirm what callers expect if the
// loaded DLL does not support ORT_API_VERSION.
const OrtApi* _winml::GetVersionedOrtApi() {
  HMODULE ort_module;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  using ApiBaseEntryPoint = decltype(&OrtGetApiBase);
  auto const get_api_base = reinterpret_cast<ApiBaseEntryPoint>(GetProcAddress(ort_module, "OrtGetApiBase"));
  if (!get_api_base) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  return get_api_base()->GetApi(ORT_API_VERSION);
}
|
|
|
|
// Loads onnxruntime.dll and returns the WinML adapter API table obtained from its exported
// OrtGetWinMLAdapter entry point for ORT_API_VERSION.
// Fails fast if the DLL cannot be loaded or does not export OrtGetWinMLAdapter.
// `ort_api` is accepted for signature compatibility but is not used.
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) {
  UNREFERENCED_PARAMETER(ort_api);

  HMODULE ort_module;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  using AdapterEntryPoint = decltype(&OrtGetWinMLAdapter);
  auto const get_adapter = reinterpret_cast<AdapterEntryPoint>(GetProcAddress(ort_module, "OrtGetWinMLAdapter"));
  if (!get_adapter) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  return get_adapter(ORT_API_VERSION);
}
|
|
|
|
// Convenience overload: fetches the versioned OrtApi first, then looks up the WinML
// adapter API table exported by onnxruntime.dll.
const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() {
  const OrtApi* const ort_api = GetVersionedOrtApi();
  return GetVersionedWinmlAdapterApi(ort_api);
}
|
|
|
|
// ORT logging sink registered by the OnnxruntimeEnvironment constructor.
// Forwards each ORT log message to the WinML TraceLogging provider as a "WinMLLogSink"
// event whose ETW level mirrors the ORT severity. FATAL and ERROR messages are
// additionally tagged as telemetry (privacy tag + MICROSOFT_KEYWORD_MEASURES).
// When debug_output_ is set, the message is also written to the debugger.
//
//   param         - opaque user pointer (unused; the constructor registers nullptr).
//   severity      - ORT log level; unknown values are treated as verbose.
//   category      - ORT log category string.
//   logger_id     - ORT logger id (unused).
//   code_location - source location string supplied by ORT.
//   message       - the log text; may be null, which is tolerated.
//
// The function is noexcept (it is invoked from ORT), so nothing here may throw.
static void __stdcall WinmlOrtLoggingCallback(
  void* param,
  OrtLoggingLevel severity,
  const char* category,
  const char* logger_id,
  const char* code_location,
  const char* message
) noexcept {
  UNREFERENCED_PARAMETER(param);
  UNREFERENCED_PARAMETER(logger_id);
  // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry.
  // TraceLoggingLevel must be a compile-time constant, which is why each severity
  // needs its own TraceLoggingWrite instead of a shared helper.
  switch (severity) {
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL:  // Telemetry
      TraceLoggingWrite(
        winml_trace_logging_provider,
        "WinMLLogSink",
        TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
        TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(category),
        TraceLoggingUInt32((UINT32)severity),
        TraceLoggingString(message),
        TraceLoggingString(code_location),
        TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)
      );
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR:  // Telemetry
      TraceLoggingWrite(
        winml_trace_logging_provider,
        "WinMLLogSink",
        TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
        TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(category),
        TraceLoggingUInt32((UINT32)severity),
        TraceLoggingString(message),
        TraceLoggingString(code_location),
        TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)
      );
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING:
      TraceLoggingWrite(
        winml_trace_logging_provider,
        "WinMLLogSink",
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
        TraceLoggingLevel(WINEVENT_LEVEL_WARNING),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(category),
        TraceLoggingUInt32((UINT32)severity),
        TraceLoggingString(message),
        TraceLoggingString(code_location)
      );
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO:
      TraceLoggingWrite(
        winml_trace_logging_provider,
        "WinMLLogSink",
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
        TraceLoggingLevel(WINEVENT_LEVEL_INFO),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(category),
        TraceLoggingUInt32((UINT32)severity),
        TraceLoggingString(message),
        TraceLoggingString(code_location)
      );
      break;
    case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE:
    default:  // Verbose, and any unrecognized severity, are logged at verbose level.
      TraceLoggingWrite(
        winml_trace_logging_provider,
        "WinMLLogSink",
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
        TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(category),
        TraceLoggingUInt32((UINT32)severity),
        TraceLoggingString(message),
        TraceLoggingString(code_location)
      );
      break;
  }

  // Constructing std::string from a null pointer is undefined behavior, and an allocation
  // failure inside this noexcept function would call std::terminate; debug mirroring is
  // strictly best effort, so guard against both.
  if (debug_output_ && message != nullptr) {
    try {
      OutputDebugStringA((std::string(message) + "\r\n").c_str());
    } catch (...) {
      // Intentionally ignored: losing a debugger line is preferable to terminating.
    }
  }
}
|
|
|
|
// ORT profiler sink registered by the OnnxruntimeEnvironment constructor.
// Emits each profiler record as an "OnnxRuntimeProfiling" TraceLogging event under
// WINML_PROVIDER_KEYWORD_LOTUS_PROFILING at verbose level. NODE_EVENT records also carry
// the operator name and execution provider; all other categories omit those two fields.
static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept {
  const OrtProfilerEventRecord& record = *profiler_record;

  if (record.category_ != OrtProfilerEventCategory::NODE_EVENT) {
    TraceLoggingWrite(
      winml_trace_logging_provider,
      "OnnxRuntimeProfiling",
      TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
      TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
      TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
      TraceLoggingString(record.category_name_, "Category"),
      TraceLoggingInt64(record.duration_, "Duration (us)"),
      TraceLoggingInt64(record.time_span_, "Time Stamp (us)"),
      TraceLoggingString(record.event_name_, "Event Name"),
      TraceLoggingInt32(record.process_id_, "Process ID"),
      TraceLoggingInt32(record.thread_id_, "Thread ID")
    );
    return;
  }

  // Node events: same fields plus the operator and the execution provider that ran it.
  TraceLoggingWrite(
    winml_trace_logging_provider,
    "OnnxRuntimeProfiling",
    TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
    TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
    TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
    TraceLoggingString(record.category_name_, "Category"),
    TraceLoggingInt64(record.duration_, "Duration (us)"),
    TraceLoggingInt64(record.time_span_, "Time Stamp (us)"),
    TraceLoggingString(record.event_name_, "Event Name"),
    TraceLoggingInt32(record.process_id_, "Process ID"),
    TraceLoggingInt32(record.thread_id_, "Thread ID"),
    TraceLoggingString(record.op_name_, "Operator Name"),
    TraceLoggingString(record.execution_provider_, "Execution Provider")
  );
}
|
|
|
|
// CoreApplication::Suspending handler installed by RegisterSuspendHandler.
// Records that WinML was suspended; neither the sender nor the event arguments are needed.
static void OnSuspending(
  winrt::Windows::Foundation::IInspectable const& /*sender*/,
  winrt::Windows::ApplicationModel::SuspendingEventArgs const& /*args*/
) {
  telemetry_helper.LogWinMLSuspended();
}
|
|
|
|
void OnnxruntimeEnvironment::RegisterSuspendHandler() {
|
|
try {
|
|
auto suspend_event_handler =
|
|
winrt::Windows::Foundation::EventHandler<winrt::Windows::ApplicationModel::SuspendingEventArgs>(&OnSuspending);
|
|
suspend_token_ = winrt::Windows::ApplicationModel::Core::CoreApplication::Suspending(suspend_event_handler);
|
|
} catch (...) {
|
|
} //Catch in case CoreApplication cannot be found for non-UWP executions
|
|
}
|
|
|
|
// Creates the process ORT environment used by WinML:
//   1. creates an OrtEnv (verbose, log id "Default") and takes ownership of it in ort_env_;
//   2. marks it as used by the WinML language projection;
//   3. routes ORT logging/profiling to WinmlOrtLoggingCallback / WinmlOrtProfileEventCallback;
//   4. applies the WinML adapter's schema override;
//   5. registers the UWP suspend handler (best effort).
// Throws (via THROW_IF_NOT_OK_MSG) if any ORT call fails. ort_env_ takes ownership right after
// creation, so the env is released by the member's destructor if a later step throws.
OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) {
  OrtEnv* ort_env = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), ort_api);
  // Take ownership before any further call can throw; previously a failure in
  // SetLanguageProjection leaked the freshly created OrtEnv.
  ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);
  THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env_.get(), OrtLanguageProjection::ORT_PROJECTION_WINML), ort_api);

  // Configure the environment with the winml logger
  auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api);
  THROW_IF_NOT_OK_MSG(
    winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(
      ort_env_.get(),
      &WinmlOrtLoggingCallback,
      &WinmlOrtProfileEventCallback,
      nullptr,
      OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
      "Default",
      &ort_env  // out-parameter; the value written here is not used afterwards
    ),
    ort_api
  );

  THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api);

  // Register suspend handler for UWP applications
  RegisterSuspendHandler();
}
|
|
|
|
// Revokes the CoreApplication::Suspending registration made by RegisterSuspendHandler, if
// one was made. Destructors must not throw, so a failed revocation is ignored.
OnnxruntimeEnvironment::~OnnxruntimeEnvironment() {
  if (suspend_token_) {
    try {
      winrt::Windows::ApplicationModel::Core::CoreApplication::Suspending(suspend_token_);
    } catch (...) {
      // Best effort: the static revoke goes through the CoreApplication activation factory,
      // which can presumably fail during teardown; an exception escaping a destructor
      // would terminate the process.
    }
  }
}
|
|
|
|
// Hands out the OrtEnv owned by this object without transferring ownership; the pointer
// stays valid only as long as this OnnxruntimeEnvironment. Always returns S_OK.
HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) {
  OrtEnv* const owned_env = ort_env_.get();
  *ort_env = owned_env;
  return S_OK;
}
|
|
|
|
// Turns mirroring of ORT log messages to the debugger (OutputDebugStringA) on or off.
// The flag is a file-scope static, so the setting is process-wide rather than per instance.
// Always returns S_OK.
HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) {
  const bool mirror_to_debugger = is_enabled;
  debug_output_ = mirror_to_debugger;  // consulted by WinmlOrtLoggingCallback on each message
  return S_OK;
}
|