onnxruntime/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp
Tiago Koji Castro Shibata f7c3e4fa99
Store/containerized apps support (#4651)
* Initial containerized/Store build

* Remove unsupported APIs

* Remove usage of STL ifstream

* Revert CMake changes

* Link to app runtime

* WCOS/Store cmake

* Update CMakeSettings.json

* Fix winapi family support

* Fix downlevel

* Downlevel build

* Remove downlevel workaround

* pep8 compliance

* Workaround WinRT headers bug

https://github.com/microsoft/cppwinrt/issues/584 in older SDK

* Always cross compile to avoid warnings as errors

* PR feedback

* More CI fixes

* PR feedback

* aiinfra build fix

* Win8 store
2020-09-09 14:36:35 -07:00

217 lines
8.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "pch.h"
#include "OnnxruntimeEnvironment.h"
#include "OnnxruntimeErrors.h"
#include "core/platform/windows/TraceLoggingConfig.h"
#include <evntrace.h>
#include <windows.h>
using namespace _winml;
// File-local switch: when true, WinmlOrtLoggingCallback additionally echoes each
// message through OutputDebugStringA. Toggled by OnnxruntimeEnvironment::EnableDebugOutput.
static bool debug_output_ = false;
// Its address is cast to HINSTANCE and handed to GetModuleFileNameA in
// IsCurrentModuleInSystem32 to identify the module containing this code.
// NOTE(review): presumably the linker-provided image-base symbol - confirm.
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
static bool IsCurrentModuleInSystem32() {
std::string current_module_path;
current_module_path.reserve(MAX_PATH);
auto size_module_path = GetModuleFileNameA((HINSTANCE)&__ImageBase, current_module_path.data(), MAX_PATH);
FAIL_FAST_IF(size_module_path == 0);
std::string system32_path;
system32_path.reserve(MAX_PATH);
auto size_system32_path = GetSystemDirectoryA(system32_path.data(), MAX_PATH);
FAIL_FAST_IF(size_system32_path == 0);
return _strnicmp(system32_path.c_str(), current_module_path.c_str(), size_system32_path) == 0;
}
// Loads onnxruntime.dll and returns its module handle through `module`.
//
// Packaged (WINAPI_FAMILY_PC_APP) builds go through LoadPackagedLibrary; all
// other builds use LoadLibraryExA. In BUILD_INBOX builds the search is limited
// to System32 when this module itself lives there.
//
// Returns S_OK on success. On failure `module` is left untouched and the last
// Win32 error is returned as an HRESULT.
static HRESULT GetOnnxruntimeLibrary(HMODULE& module) {
  HMODULE loaded_library = nullptr;
#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP
  loaded_library = LoadPackagedLibrary(L"onnxruntime.dll", 0);
#else
#ifdef BUILD_INBOX
  const DWORD load_flags = IsCurrentModuleInSystem32() ? LOAD_LIBRARY_SEARCH_SYSTEM32 : 0;
#else
  const DWORD load_flags = 0;
#endif
  loaded_library = LoadLibraryExA("onnxruntime.dll", nullptr, load_flags);
#endif

  if (loaded_library != nullptr) {
    module = loaded_library;
    return S_OK;
  }
  return HRESULT_FROM_WIN32(GetLastError());
}
// Resolves the OrtGetApiBase export from onnxruntime.dll and returns the
// OrtApi table for API version 2.
//
// Fails fast if the library cannot be loaded or the export cannot be found.
const OrtApi* _winml::GetVersionedOrtApi() {
  HMODULE ort_module = nullptr;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  using GetApiBaseFn = decltype(OrtGetApiBase);
  const auto get_api_base = reinterpret_cast<GetApiBaseFn*>(GetProcAddress(ort_module, "OrtGetApiBase"));
  if (!get_api_base) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  static const uint32_t requested_version = 2;
  const auto api_base = get_api_base();
  return api_base->GetApi(requested_version);
}
// Resolves the OrtGetWinMLAdapter export from onnxruntime.dll and calls it
// with `ort_api`, returning whatever adapter table the export hands back.
//
// Fails fast if the library cannot be loaded or the export cannot be found.
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) {
  HMODULE ort_module = nullptr;
  FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(ort_module));

  using GetAdapterFn = decltype(OrtGetWinMLAdapter);
  const auto get_adapter = reinterpret_cast<GetAdapterFn*>(GetProcAddress(ort_module, "OrtGetWinMLAdapter"));
  if (!get_adapter) {
    FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
  }

  return get_adapter(ort_api);
}
// Convenience overload: obtains the versioned OrtApi first, then asks
// onnxruntime.dll for the matching WinML adapter API.
const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() {
  const OrtApi* ort_api = GetVersionedOrtApi();
  return GetVersionedWinmlAdapterApi(ort_api);
}
// Logging sink registered with onnxruntime (see the OnnxruntimeEnvironment constructor).
// Forwards each log message to the winml_trace_logging_provider as a "WinMLLogSink"
// event whose event level is chosen from `severity`.
//
// param          - unused.
// severity       - ORT logging level; selects the WINEVENT_LEVEL_* of the event and is
//                  also written into the event as a UINT32 field.
// category       - written as a string field.
// logger_id      - unused.
// code_location  - written as a string field.
// message        - written as a string field; also echoed via OutputDebugStringA when
//                  debug output has been enabled.
static void __stdcall WinmlOrtLoggingCallback(void* param, OrtLoggingLevel severity, const char* category,
const char* logger_id, const char* code_location, const char* message) noexcept {
UNREFERENCED_PARAMETER(param);
UNREFERENCED_PARAMETER(logger_id);
// ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry.
// Each severity needs its own TraceLoggingWrite call site because the level and
// keyword arguments differ per case.
switch (severity) {
case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL: // Telemetry: carries the privacy data tag and MICROSOFT_KEYWORD_MEASURES.
TraceLoggingWrite(
winml_trace_logging_provider,
"WinMLLogSink",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL),
TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
TraceLoggingString(category),
TraceLoggingUInt32((UINT32)severity),
TraceLoggingString(message),
TraceLoggingString(code_location),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
break;
case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR: // Telemetry: same fields as FATAL, at error level.
TraceLoggingWrite(
winml_trace_logging_provider,
"WinMLLogSink",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
TraceLoggingString(category),
TraceLoggingUInt32((UINT32)severity),
TraceLoggingString(message),
TraceLoggingString(code_location),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
break;
case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING: // Non-telemetry.
TraceLoggingWrite(
winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_WARNING),
TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
TraceLoggingString(category),
TraceLoggingUInt32((UINT32)severity),
TraceLoggingString(message),
TraceLoggingString(code_location));
break;
case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO: // Non-telemetry.
TraceLoggingWrite(
winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
TraceLoggingString(category),
TraceLoggingUInt32((UINT32)severity),
TraceLoggingString(message),
TraceLoggingString(code_location));
break;
case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE:
__fallthrough; // Unrecognized severities are logged at verbose level too.
default:
TraceLoggingWrite(
winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
TraceLoggingString(category),
TraceLoggingUInt32((UINT32)severity),
TraceLoggingString(message),
TraceLoggingString(code_location));
}
// Optional mirror to the debugger output, controlled by EnableDebugOutput.
// NOTE(review): std::string construction can throw inside this noexcept function
// and requires `message` to be non-null - confirm callers guarantee that.
if (debug_output_) {
OutputDebugStringA((std::string(message) + "\r\n").c_str());
}
}
// Profiling sink registered with onnxruntime. Emits every profiler record as an
// "OnnxRuntimeProfiling" verbose event; node events additionally carry the
// operator name and execution provider.
static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept {
  const bool is_node_event = profiler_record->category_ == OrtProfilerEventCategory::NODE_EVENT;

  if (!is_node_event) {
    // Non-node records: common fields only.
    TraceLoggingWrite(
        winml_trace_logging_provider,
        "OnnxRuntimeProfiling",
        TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
        TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
        TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
        TraceLoggingString(profiler_record->category_name_, "Category"),
        TraceLoggingInt64(profiler_record->duration_, "Duration (us)"),
        TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"),
        TraceLoggingString(profiler_record->event_name_, "Event Name"),
        TraceLoggingInt32(profiler_record->process_id_, "Process ID"),
        TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"));
    return;
  }

  // Node records: common fields plus operator name and execution provider.
  TraceLoggingWrite(
      winml_trace_logging_provider,
      "OnnxRuntimeProfiling",
      TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
      TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
      TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO),
      TraceLoggingString(profiler_record->category_name_, "Category"),
      TraceLoggingInt64(profiler_record->duration_, "Duration (us)"),
      TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"),
      TraceLoggingString(profiler_record->event_name_, "Event Name"),
      TraceLoggingInt32(profiler_record->process_id_, "Process ID"),
      TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"),
      TraceLoggingString(profiler_record->op_name_, "Operator Name"),
      TraceLoggingString(profiler_record->execution_provider_, "Execution Provider"));
}
// Creates the ORT environment used by WinML and wires it up:
//   1. creates an OrtEnv with verbose logging and the log id "Default",
//   2. marks the environment's language projection as WinML,
//   3. installs the WinML logging and profiling callbacks via the adapter API,
//   4. asks the adapter to override the schema.
//
// ort_api - the OrtApi table used for every ORT call and for releasing the env.
//
// Any failing step is reported through THROW_IF_NOT_OK_MSG.
OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) {
  OrtEnv* ort_env = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
                      ort_api);

  // Hand the environment to its owner right away. Previously ownership was only
  // taken after SetLanguageProjection, so a failure there left the freshly
  // created OrtEnv without anyone to call ReleaseEnv on it.
  ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);

  THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env_.get(), OrtLanguageProjection::ORT_PROJECTION_WINML),
                      ort_api);

  // Configure the environment with the winml logger
  auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api);
  // NOTE(review): the final argument receives an OrtEnv* that is not used
  // afterwards - confirm the adapter does not return a new environment here.
  THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(),
                                                                             &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr,
                                                                             OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
                      ort_api);
  THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api);
}
// Writes the raw OrtEnv pointer held by this object to `*ort_env`.
// Ownership stays with this object. Always returns S_OK.
HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) {
  OrtEnv* const borrowed_env = ort_env_.get();
  *ort_env = borrowed_env;
  return S_OK;
}
// Turns the OutputDebugStringA echo in the logging callback on or off.
// The flag is a file-level static, so it applies to all logging from this file.
// Always returns S_OK.
HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) {
  debug_output_ = is_enabled;

  return S_OK;
}