Support loading widechar paths on windows (#14066)

### Description
Make GetRuntimePath() and LoadDynamicLibrary() operate on platform
specific paths

### Motivation and Context
This addresses https://github.com/microsoft/onnxruntime/issues/14063
This commit is contained in:
Dmitri Smirnov 2022-12-30 16:30:11 -08:00 committed by GitHub
parent b85878953f
commit 5d729839b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 45 additions and 33 deletions

View file

@ -38,7 +38,8 @@ common::Status ExLibLoader::LoadExternalLib(const std::string& dso_file_path,
}
void* lib_handle = nullptr;
ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(dso_file_path, false, &lib_handle));
auto path_str = ToPathString(dso_file_path);
ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(path_str, false, &lib_handle));
dso_name_data_map_[dso_file_path] = lib_handle;
*handle = lib_handle;
return Status::OK();

View file

@ -230,7 +230,7 @@ class Env {
// OK from the function.
// Otherwise returns nullptr in "*handle" and an error status from the
// function.
virtual common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const = 0;
virtual common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const = 0;
virtual common::Status UnloadDynamicLibrary(void* handle) const = 0;
@ -238,7 +238,7 @@ class Env {
//
// Used to help load other shared libraries that live in the same folder as the core code, for example
// The DNNL provider shared library. Without this path, the module won't be found on windows in all cases.
virtual std::string GetRuntimePath() const { return ""; }
virtual PathString GetRuntimePath() const { return PathString(); }
// \brief Get a pointer to a symbol from a dynamic library.
//

View file

@ -529,7 +529,7 @@ class PosixEnv : public Env {
return Status::OK();
}
common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const override {
common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override {
dlerror(); // clear any old error_str
*handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? RTLD_GLOBAL : RTLD_LOCAL));
char* error_str = dlerror();

View file

@ -653,22 +653,26 @@ common::Status WindowsEnv::GetCanonicalPath(
// Return the path of the executable/shared library for the current running code. This is to make it
// possible to load other shared libraries installed next to our core runtime code.
std::string WindowsEnv::GetRuntimePath() const {
char buffer[MAX_PATH];
if (!GetModuleFileNameA(reinterpret_cast<HINSTANCE>(&__ImageBase), buffer, _countof(buffer)))
return "";
PathString WindowsEnv::GetRuntimePath() const {
wchar_t buffer[MAX_PATH];
if (!GetModuleFileNameW(reinterpret_cast<HINSTANCE>(&__ImageBase), buffer, _countof(buffer))) {
return PathString();
}
// Remove the filename at the end, but keep the trailing slash
std::string path(buffer);
auto slash_index = path.find_last_of('\\');
if (slash_index == std::string::npos)
return "";
PathString path(buffer);
auto slash_index = path.find_last_of(ORT_TSTR('\\'));
if (slash_index == std::string::npos) {
// Windows supports forward slashes
slash_index = path.find_last_of(ORT_TSTR('/'));
if (slash_index == std::string::npos) {
return PathString();
}
}
return path.substr(0, slash_index + 1);
}
Status WindowsEnv::LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const {
const std::wstring& wlibrary_filename = ToWideString(library_filename);
Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const {
#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP
*handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0);
#else

View file

@ -76,8 +76,8 @@ class WindowsEnv : public Env {
common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override;
common::Status FileClose(int fd) const override;
common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override;
std::string GetRuntimePath() const override;
Status LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const override;
PathString GetRuntimePath() const override;
Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override;
Status UnloadDynamicLibrary(void* handle) const override;
Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override;
std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override;

View file

@ -600,7 +600,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _Inout_ OrtSessionOptions* optio
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle) {
API_IMPL_BEGIN
ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(library_path, false, library_handle));
auto path_str = ToPathString(library_path);
ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(path_str, false, library_handle));
if (!*library_handle)
return OrtApis::CreateStatus(ORT_FAIL, "RegisterCustomOpsLibrary: Failed to load library");

View file

@ -90,7 +90,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
// The filename extension for a shared library is different per platform
#ifdef _WIN32
#define LIBRARY_PREFIX
#define LIBRARY_EXTENSION ".dll"
#define LIBRARY_EXTENSION ORT_TSTR(".dll")
#elif defined(__APPLE__)
#define LIBRARY_PREFIX "lib"
#define LIBRARY_EXTENSION ".dylib"
@ -1049,7 +1049,8 @@ struct ProviderSharedLibrary {
if (handle_)
return;
std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
auto full_path = Env::Default().GetRuntimePath() +
PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_shared") LIBRARY_EXTENSION);
ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true /*shared_globals on unix*/, &handle_));
void (*PProvider_SetHost)(void*);
@ -1089,7 +1090,7 @@ bool InitProvidersSharedLibrary() try {
}
struct ProviderLibrary {
ProviderLibrary(const char* filename, bool unload = true) : filename_{filename}, unload_{unload} {}
ProviderLibrary(const ORTCHAR_T* filename, bool unload = true) : filename_{filename}, unload_{unload} {}
~ProviderLibrary() {
// assert(!handle_); // We should already be unloaded at this point (disabled until Python shuts down deterministically)
}
@ -1100,7 +1101,7 @@ struct ProviderLibrary {
if (!provider_) {
s_library_shared.Ensure();
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
auto full_path = Env::Default().GetRuntimePath() + filename_;
ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, false, &handle_));
Provider* (*PGetProvider)();
@ -1139,7 +1140,7 @@ struct ProviderLibrary {
private:
std::mutex mutex_;
const char* filename_;
const ORTCHAR_T* filename_;
bool unload_;
Provider* provider_{};
void* handle_{};
@ -1147,28 +1148,28 @@ struct ProviderLibrary {
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
};
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cuda") LIBRARY_EXTENSION
#ifndef _WIN32
,
false /* unload - On Linux if we unload the cuda shared provider we crash */
#endif
);
static ProviderLibrary s_library_cann(LIBRARY_PREFIX "onnxruntime_providers_cann" LIBRARY_EXTENSION
static ProviderLibrary s_library_cann(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cann") LIBRARY_EXTENSION
#ifndef _WIN32
,
false /* unload - On Linux if we unload the cann shared provider we crash */
#endif
);
static ProviderLibrary s_library_rocm(LIBRARY_PREFIX "onnxruntime_providers_rocm" LIBRARY_EXTENSION
static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_rocm") LIBRARY_EXTENSION
#ifndef _WIN32
,
false /* unload - On Linux if we unload the rocm shared provider we crash */
#endif
);
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
static ProviderLibrary s_library_openvino(LIBRARY_PREFIX "onnxruntime_providers_openvino" LIBRARY_EXTENSION);
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX "onnxruntime_providers_migraphx" LIBRARY_EXTENSION);
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION);
static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION);
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION);
static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_migraphx") LIBRARY_EXTENSION);
void UnloadSharedProviders() {
s_library_dnnl.Unload();

View file

@ -13,6 +13,7 @@
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/common/optional.h"
#include "core/common/path_string.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
@ -77,7 +78,8 @@ static Env& platform_env = Env::Default();
CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) {
{
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, false, &library_handle_));
const auto path_str = ToPathString(library_path);
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(path_str, false, &library_handle_));
OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api);
@ -311,7 +313,8 @@ static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
const ProviderOptions& provider_options = {},
const std::string& entry_symbol_name = "GetProvider") {
void* handle;
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
const auto path_str = ToPathString(ep_shared_lib_path);
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
if (!error.IsOK()) {
throw std::runtime_error(error.ErrorMessage());
}

View file

@ -6,6 +6,7 @@
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/common/path_string.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/provider_bridge_ort.h"
@ -49,7 +50,8 @@ bool GetDyanmicExecutionProviderHash(
size_t& hash,
const std::string& entry_symbol_name = "ProviderHashFunc") {
void* handle;
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
const auto path_str = ToPathString(ep_shared_lib_path);
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
if (!error.IsOK()) {
throw std::runtime_error(error.ErrorMessage());
}