mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
Support loading widechar paths on windows (#14066)
### Description Make GetRuntimePath() and LoadDynamicLibrary() operate on platform specific paths ### Motivation and Context This addresses https://github.com/microsoft/onnxruntime/issues/14063
This commit is contained in:
parent
b85878953f
commit
5d729839b5
9 changed files with 45 additions and 33 deletions
|
|
@ -38,7 +38,8 @@ common::Status ExLibLoader::LoadExternalLib(const std::string& dso_file_path,
|
|||
}
|
||||
|
||||
void* lib_handle = nullptr;
|
||||
ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(dso_file_path, false, &lib_handle));
|
||||
auto path_str = ToPathString(dso_file_path);
|
||||
ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(path_str, false, &lib_handle));
|
||||
dso_name_data_map_[dso_file_path] = lib_handle;
|
||||
*handle = lib_handle;
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ class Env {
|
|||
// OK from the function.
|
||||
// Otherwise returns nullptr in "*handle" and an error status from the
|
||||
// function.
|
||||
virtual common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const = 0;
|
||||
virtual common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const = 0;
|
||||
|
||||
virtual common::Status UnloadDynamicLibrary(void* handle) const = 0;
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ class Env {
|
|||
//
|
||||
// Used to help load other shared libraries that live in the same folder as the core code, for example
|
||||
// The DNNL provider shared library. Without this path, the module won't be found on windows in all cases.
|
||||
virtual std::string GetRuntimePath() const { return ""; }
|
||||
virtual PathString GetRuntimePath() const { return PathString(); }
|
||||
|
||||
// \brief Get a pointer to a symbol from a dynamic library.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -529,7 +529,7 @@ class PosixEnv : public Env {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const override {
|
||||
common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override {
|
||||
dlerror(); // clear any old error_str
|
||||
*handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? RTLD_GLOBAL : RTLD_LOCAL));
|
||||
char* error_str = dlerror();
|
||||
|
|
|
|||
|
|
@ -653,22 +653,26 @@ common::Status WindowsEnv::GetCanonicalPath(
|
|||
|
||||
// Return the path of the executable/shared library for the current running code. This is to make it
|
||||
// possible to load other shared libraries installed next to our core runtime code.
|
||||
std::string WindowsEnv::GetRuntimePath() const {
|
||||
char buffer[MAX_PATH];
|
||||
if (!GetModuleFileNameA(reinterpret_cast<HINSTANCE>(&__ImageBase), buffer, _countof(buffer)))
|
||||
return "";
|
||||
PathString WindowsEnv::GetRuntimePath() const {
|
||||
wchar_t buffer[MAX_PATH];
|
||||
if (!GetModuleFileNameW(reinterpret_cast<HINSTANCE>(&__ImageBase), buffer, _countof(buffer))) {
|
||||
return PathString();
|
||||
}
|
||||
|
||||
// Remove the filename at the end, but keep the trailing slash
|
||||
std::string path(buffer);
|
||||
auto slash_index = path.find_last_of('\\');
|
||||
if (slash_index == std::string::npos)
|
||||
return "";
|
||||
|
||||
PathString path(buffer);
|
||||
auto slash_index = path.find_last_of(ORT_TSTR('\\'));
|
||||
if (slash_index == std::string::npos) {
|
||||
// Windows supports forward slashes
|
||||
slash_index = path.find_last_of(ORT_TSTR('/'));
|
||||
if (slash_index == std::string::npos) {
|
||||
return PathString();
|
||||
}
|
||||
}
|
||||
return path.substr(0, slash_index + 1);
|
||||
}
|
||||
|
||||
Status WindowsEnv::LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const {
|
||||
const std::wstring& wlibrary_filename = ToWideString(library_filename);
|
||||
Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const {
|
||||
#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP
|
||||
*handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0);
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -76,8 +76,8 @@ class WindowsEnv : public Env {
|
|||
common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override;
|
||||
common::Status FileClose(int fd) const override;
|
||||
common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override;
|
||||
std::string GetRuntimePath() const override;
|
||||
Status LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const override;
|
||||
PathString GetRuntimePath() const override;
|
||||
Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override;
|
||||
Status UnloadDynamicLibrary(void* handle) const override;
|
||||
Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override;
|
||||
std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override;
|
||||
|
|
|
|||
|
|
@ -600,7 +600,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _Inout_ OrtSessionOptions* optio
|
|||
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(library_path, false, library_handle));
|
||||
auto path_str = ToPathString(library_path);
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(path_str, false, library_handle));
|
||||
if (!*library_handle)
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "RegisterCustomOpsLibrary: Failed to load library");
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
|
|||
// The filename extension for a shared library is different per platform
|
||||
#ifdef _WIN32
|
||||
#define LIBRARY_PREFIX
|
||||
#define LIBRARY_EXTENSION ".dll"
|
||||
#define LIBRARY_EXTENSION ORT_TSTR(".dll")
|
||||
#elif defined(__APPLE__)
|
||||
#define LIBRARY_PREFIX "lib"
|
||||
#define LIBRARY_EXTENSION ".dylib"
|
||||
|
|
@ -1049,7 +1049,8 @@ struct ProviderSharedLibrary {
|
|||
if (handle_)
|
||||
return;
|
||||
|
||||
std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
|
||||
auto full_path = Env::Default().GetRuntimePath() +
|
||||
PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_shared") LIBRARY_EXTENSION);
|
||||
ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true /*shared_globals on unix*/, &handle_));
|
||||
|
||||
void (*PProvider_SetHost)(void*);
|
||||
|
|
@ -1089,7 +1090,7 @@ bool InitProvidersSharedLibrary() try {
|
|||
}
|
||||
|
||||
struct ProviderLibrary {
|
||||
ProviderLibrary(const char* filename, bool unload = true) : filename_{filename}, unload_{unload} {}
|
||||
ProviderLibrary(const ORTCHAR_T* filename, bool unload = true) : filename_{filename}, unload_{unload} {}
|
||||
~ProviderLibrary() {
|
||||
// assert(!handle_); // We should already be unloaded at this point (disabled until Python shuts down deterministically)
|
||||
}
|
||||
|
|
@ -1100,7 +1101,7 @@ struct ProviderLibrary {
|
|||
if (!provider_) {
|
||||
s_library_shared.Ensure();
|
||||
|
||||
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
|
||||
auto full_path = Env::Default().GetRuntimePath() + filename_;
|
||||
ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, false, &handle_));
|
||||
|
||||
Provider* (*PGetProvider)();
|
||||
|
|
@ -1139,7 +1140,7 @@ struct ProviderLibrary {
|
|||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
const char* filename_;
|
||||
const ORTCHAR_T* filename_;
|
||||
bool unload_;
|
||||
Provider* provider_{};
|
||||
void* handle_{};
|
||||
|
|
@ -1147,28 +1148,28 @@ struct ProviderLibrary {
|
|||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
|
||||
};
|
||||
|
||||
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION
|
||||
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cuda") LIBRARY_EXTENSION
|
||||
#ifndef _WIN32
|
||||
,
|
||||
false /* unload - On Linux if we unload the cuda shared provider we crash */
|
||||
#endif
|
||||
);
|
||||
static ProviderLibrary s_library_cann(LIBRARY_PREFIX "onnxruntime_providers_cann" LIBRARY_EXTENSION
|
||||
static ProviderLibrary s_library_cann(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cann") LIBRARY_EXTENSION
|
||||
#ifndef _WIN32
|
||||
,
|
||||
false /* unload - On Linux if we unload the cann shared provider we crash */
|
||||
#endif
|
||||
);
|
||||
static ProviderLibrary s_library_rocm(LIBRARY_PREFIX "onnxruntime_providers_rocm" LIBRARY_EXTENSION
|
||||
static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_rocm") LIBRARY_EXTENSION
|
||||
#ifndef _WIN32
|
||||
,
|
||||
false /* unload - On Linux if we unload the rocm shared provider we crash */
|
||||
#endif
|
||||
);
|
||||
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_openvino(LIBRARY_PREFIX "onnxruntime_providers_openvino" LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX "onnxruntime_providers_migraphx" LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_migraphx") LIBRARY_EXTENSION);
|
||||
|
||||
void UnloadSharedProviders() {
|
||||
s_library_dnnl.Unload();
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "core/common/optional.h"
|
||||
#include "core/common/path_string.h"
|
||||
#include "core/framework/arena_extend_strategy.h"
|
||||
#include "core/framework/data_transfer_utils.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
|
|
@ -77,7 +78,8 @@ static Env& platform_env = Env::Default();
|
|||
|
||||
CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) {
|
||||
{
|
||||
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, false, &library_handle_));
|
||||
const auto path_str = ToPathString(library_path);
|
||||
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(path_str, false, &library_handle_));
|
||||
|
||||
OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api);
|
||||
|
||||
|
|
@ -311,7 +313,8 @@ static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
|
|||
const ProviderOptions& provider_options = {},
|
||||
const std::string& entry_symbol_name = "GetProvider") {
|
||||
void* handle;
|
||||
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
|
||||
const auto path_str = ToPathString(ep_shared_lib_path);
|
||||
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
|
||||
if (!error.IsOK()) {
|
||||
throw std::runtime_error(error.ErrorMessage());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "core/common/path_string.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
#include "core/session/provider_bridge_ort.h"
|
||||
|
||||
|
|
@ -49,7 +50,8 @@ bool GetDyanmicExecutionProviderHash(
|
|||
size_t& hash,
|
||||
const std::string& entry_symbol_name = "ProviderHashFunc") {
|
||||
void* handle;
|
||||
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
|
||||
const auto path_str = ToPathString(ep_shared_lib_path);
|
||||
auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
|
||||
if (!error.IsOK()) {
|
||||
throw std::runtime_error(error.ErrorMessage());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue