// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "lib/Api.Ort/pch.h" #include "OnnxruntimeEnvironment.h" #include "OnnxruntimeErrors.h" #include "core/platform/windows/TraceLoggingConfig.h" #include #include #include #include using namespace _winml; static bool debug_output_ = false; EXTERN_C IMAGE_DOS_HEADER __ImageBase; static std::wstring CurrentModulePath() { WCHAR path[MAX_PATH]; FAIL_FAST_IF(0 == GetModuleFileNameW((HINSTANCE)&__ImageBase, path, _countof(path))); WCHAR absolute_path[MAX_PATH]; WCHAR* name; FAIL_FAST_IF(0 == GetFullPathNameW(path, _countof(path), absolute_path, &name)); auto idx = std::distance(absolute_path, name); auto out_path = std::wstring(absolute_path); out_path.resize(idx); return out_path; } static HRESULT GetOnnxruntimeLibrary(HMODULE& module) { #if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP // Store + Redist (note that this is never built into the inbox dll) auto out_module = LoadPackagedLibrary(L"onnxruntime.dll", 0); #else auto onnxruntime_dll = CurrentModulePath() + L"\\onnxruntime.dll"; auto out_module = LoadLibraryExW(onnxruntime_dll.c_str(), nullptr, 0); #endif if (out_module == nullptr) { return HRESULT_FROM_WIN32(GetLastError()); } module = out_module; return S_OK; } const OrtApi* _winml::GetVersionedOrtApi() { HMODULE onnxruntime_dll; FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(onnxruntime_dll)); using OrtGetApiBaseSignature = decltype(OrtGetApiBase); auto ort_get_api_base_fn = reinterpret_cast(GetProcAddress(onnxruntime_dll, "OrtGetApiBase")); if (ort_get_api_base_fn == nullptr) { FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError())); } const auto ort_api_base = ort_get_api_base_fn(); return ort_api_base->GetApi(ORT_API_VERSION); } static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) { HMODULE onnxruntime_dll; FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(onnxruntime_dll)); using OrtGetWinMLAdapterSignature = decltype(OrtGetWinMLAdapter); auto ort_get_winml_adapter_fn = reinterpret_cast(GetProcAddress(onnxruntime_dll, "OrtGetWinMLAdapter")); if (ort_get_winml_adapter_fn == nullptr) { FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError())); } return ort_get_winml_adapter_fn(ORT_API_VERSION); } const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() { return GetVersionedWinmlAdapterApi(GetVersionedOrtApi()); } static void __stdcall WinmlOrtLoggingCallback( void* param, OrtLoggingLevel severity, const char* category, const char* logger_id, const char* code_location, const char* message ) noexcept { UNREFERENCED_PARAMETER(param); UNREFERENCED_PARAMETER(logger_id); // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry. switch (severity) { case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL: //Telemetry TraceLoggingWrite( winml_trace_logging_provider, "WinMLLogSink", TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(category), TraceLoggingUInt32((UINT32)severity), TraceLoggingString(message), TraceLoggingString(code_location), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES) ); break; case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR: //Telemetry TraceLoggingWrite( winml_trace_logging_provider, "WinMLLogSink", TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(category), TraceLoggingUInt32((UINT32)severity), TraceLoggingString(message), TraceLoggingString(code_location), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES) ); break; case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING: TraceLoggingWrite( winml_trace_logging_provider, "WinMLLogSink", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_WARNING), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(category), TraceLoggingUInt32((UINT32)severity), TraceLoggingString(message), TraceLoggingString(code_location) ); break; case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO: TraceLoggingWrite( winml_trace_logging_provider, "WinMLLogSink", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(category), TraceLoggingUInt32((UINT32)severity), TraceLoggingString(message), TraceLoggingString(code_location) ); break; case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE: __fallthrough; //Default is Verbose too. default: TraceLoggingWrite( winml_trace_logging_provider, "WinMLLogSink", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(category), TraceLoggingUInt32((UINT32)severity), TraceLoggingString(message), TraceLoggingString(code_location) ); } if (debug_output_) { OutputDebugStringA((std::string(message) + "\r\n").c_str()); } } static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept { if (profiler_record->category_ == OrtProfilerEventCategory::NODE_EVENT) { TraceLoggingWrite( winml_trace_logging_provider, "OnnxRuntimeProfiling", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(profiler_record->category_name_, "Category"), TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), TraceLoggingString(profiler_record->event_name_, "Event Name"), TraceLoggingInt32(profiler_record->process_id_, "Process ID"), TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"), TraceLoggingString(profiler_record->op_name_, "Operator Name"), TraceLoggingString(profiler_record->execution_provider_, "Execution Provider") ); } else { TraceLoggingWrite( winml_trace_logging_provider, "OnnxRuntimeProfiling", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), TraceLoggingString(profiler_record->category_name_, "Category"), TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), TraceLoggingString(profiler_record->event_name_, "Event Name"), TraceLoggingInt32(profiler_record->process_id_, "Process ID"), TraceLoggingInt32(profiler_record->thread_id_, "Thread ID") ); } } static void OnSuspending( winrt::Windows::Foundation::IInspectable const& sender, winrt::Windows::ApplicationModel::SuspendingEventArgs const& args ) { telemetry_helper.LogWinMLSuspended(); } void OnnxruntimeEnvironment::RegisterSuspendHandler() { try { auto suspend_event_handler = winrt::Windows::Foundation::EventHandler(&OnSuspending); suspend_token_ = winrt::Windows::ApplicationModel::Core::CoreApplication::Suspending(suspend_event_handler); } catch (...) { } //Catch in case CoreApplication cannot be found for non-UWP executions } OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) { OrtEnv* ort_env = nullptr; THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), ort_api); THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env, OrtLanguageProjection::ORT_PROJECTION_WINML), ort_api); ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv); // Configure the environment with the winml logger auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api); THROW_IF_NOT_OK_MSG( winml_adapter_api->EnvConfigureCustomLoggerAndProfiler( ort_env_.get(), &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env ), ort_api ); THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api); // Register suspend handler for UWP applications RegisterSuspendHandler(); } OnnxruntimeEnvironment::~OnnxruntimeEnvironment() { if (suspend_token_) { winrt::Windows::ApplicationModel::Core::CoreApplication::Suspending(suspend_token_); } } HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) { *ort_env = ort_env_.get(); return S_OK; } HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) { debug_output_ = is_enabled; return S_OK; }