diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 26bb827911..66d12b3be5 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -190,6 +190,9 @@ target_compile_options(winml_lib_ort PRIVATE /GR- /await /wd4238) target_compile_definitions(winml_lib_ort PRIVATE WINML_ROOT_NS=${winml_root_ns}) target_compile_definitions(winml_lib_ort PRIVATE PLATFORM_WINDOWS) target_compile_definitions(winml_lib_ort PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators +if (onnxruntime_WINML_NAMESPACE_OVERRIDE STREQUAL "Windows") + target_compile_definitions(winml_lib_ort PRIVATE "BUILD_INBOX=1") +endif() # Specify the usage of a precompiled header target_precompiled_header(winml_lib_ort pch.h) @@ -622,7 +625,6 @@ add_dependencies(winml_dll winml_api_native) add_dependencies(winml_dll winml_api_native_internal) # Link libraries -target_link_libraries(winml_dll PRIVATE onnxruntime) target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) target_link_libraries(winml_dll PRIVATE winml_lib_api) diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index 4a71e40d89..7b011ec795 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -13,16 +13,6 @@ using namespace _winml; -static const OrtApi* GetVersionedOrtApi() { - static const uint32_t ort_version = 2; - const auto ort_api_base = OrtGetApiBase(); - return ort_api_base->GetApi(ort_version); -} - -static const WinmlAdapterApi* GetVersionedWinmlAdapterApi() { - return OrtGetWinMLAdapter(GetVersionedOrtApi()); -} - static ONNXTensorElementDataType ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) { switch (kind) { diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp index af0188e827..8940f9642d 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -7,10 +7,59 @@ #include "core/platform/windows/TraceLoggingConfig.h" #include +#include + using namespace _winml; static bool debug_output_ = false; +static HRESULT GetOnnxruntimeLibrary(HMODULE& module) { + DWORD flags = 0; +#ifdef BUILD_INBOX + flags = LOAD_LIBRARY_SEARCH_SYSTEM32; +#endif + + auto out_module = LoadLibraryExA("onnxruntime.dll", nullptr, flags); + if (out_module == nullptr) { + return HRESULT_FROM_WIN32(GetLastError()); + } + module = out_module; + return S_OK; +} + +const OrtApi* _winml::GetVersionedOrtApi() { + HMODULE onnxruntime_dll; + FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(onnxruntime_dll)); + + using OrtGetApiBaseSignature = decltype(OrtGetApiBase); + auto ort_get_api_base_fn = reinterpret_cast(GetProcAddress(onnxruntime_dll, "OrtGetApiBase")); + if (ort_get_api_base_fn == nullptr) { + FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError())); + } + + const auto ort_api_base = ort_get_api_base_fn(); + + static const uint32_t ort_version = 2; + return ort_api_base->GetApi(ort_version); +} + +static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) { + HMODULE onnxruntime_dll; + FAIL_FAST_IF_FAILED(GetOnnxruntimeLibrary(onnxruntime_dll)); + + using OrtGetWinMLAdapterSignature = decltype(OrtGetWinMLAdapter); + auto ort_get_winml_adapter_fn = reinterpret_cast(GetProcAddress(onnxruntime_dll, "OrtGetWinMLAdapter")); + if (ort_get_winml_adapter_fn == nullptr) { + FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError())); + } + + return ort_get_winml_adapter_fn(ort_api); +} + +const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() { + return GetVersionedWinmlAdapterApi(GetVersionedOrtApi()); +} + static void __stdcall WinmlOrtLoggingCallback(void* param, OrtLoggingLevel severity, const char* category, const char* logger_id, const char* code_location, const char* message) noexcept { UNREFERENCED_PARAMETER(param); @@ -128,7 +177,7 @@ OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_ ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv); // Configure the environment with the winml logger - auto winml_adapter_api = OrtGetWinMLAdapter(ort_api); + auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api); THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(), &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h index 5d47266cf9..10df3aec6c 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h @@ -19,6 +19,9 @@ class OnnxruntimeEnvironment { UniqueOrtEnv ort_env_; }; +const OrtApi* GetVersionedOrtApi(); +const WinmlAdapterApi* GetVersionedWinmlAdapterApi(); + } // namespace _winml #pragma warning(pop) \ No newline at end of file