diff --git a/onnxruntime/core/framework/ex_lib_loader.cc b/onnxruntime/core/framework/ex_lib_loader.cc index 7070cbf9f1..66f53db3db 100644 --- a/onnxruntime/core/framework/ex_lib_loader.cc +++ b/onnxruntime/core/framework/ex_lib_loader.cc @@ -38,7 +38,8 @@ common::Status ExLibLoader::LoadExternalLib(const std::string& dso_file_path, } void* lib_handle = nullptr; - ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(dso_file_path, false, &lib_handle)); + auto path_str = ToPathString(dso_file_path); + ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(path_str, false, &lib_handle)); dso_name_data_map_[dso_file_path] = lib_handle; *handle = lib_handle; return Status::OK(); diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 369e492292..14c82096b6 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -230,7 +230,7 @@ class Env { // OK from the function. // Otherwise returns nullptr in "*handle" and an error status from the // function. - virtual common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const = 0; + virtual common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const = 0; virtual common::Status UnloadDynamicLibrary(void* handle) const = 0; @@ -238,7 +238,7 @@ class Env { // // Used to help load other shared libraries that live in the same folder as the core code, for example // The DNNL provider shared library. Without this path, the module won't be found on windows in all cases. - virtual std::string GetRuntimePath() const { return ""; } + virtual PathString GetRuntimePath() const { return PathString(); } // \brief Get a pointer to a symbol from a dynamic library. // diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 0fe754c173..d055e78af3 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -529,7 +529,7 @@ class PosixEnv : public Env { return Status::OK(); } - common::Status LoadDynamicLibrary(const std::string& library_filename, bool global_symbols, void** handle) const override { + common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override { dlerror(); // clear any old error_str *handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? RTLD_GLOBAL : RTLD_LOCAL)); char* error_str = dlerror(); diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index eacb1e86fa..2eaf3fb107 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -653,22 +653,26 @@ common::Status WindowsEnv::GetCanonicalPath( // Return the path of the executable/shared library for the current running code. This is to make it // possible to load other shared libraries installed next to our core runtime code. -std::string WindowsEnv::GetRuntimePath() const { - char buffer[MAX_PATH]; - if (!GetModuleFileNameA(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) - return ""; +PathString WindowsEnv::GetRuntimePath() const { + wchar_t buffer[MAX_PATH]; + if (!GetModuleFileNameW(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) { + return PathString(); + } // Remove the filename at the end, but keep the trailing slash - std::string path(buffer); - auto slash_index = path.find_last_of('\\'); - if (slash_index == std::string::npos) - return ""; - + PathString path(buffer); + auto slash_index = path.find_last_of(ORT_TSTR('\\')); + if (slash_index == std::string::npos) { + // Windows supports forward slashes + slash_index = path.find_last_of(ORT_TSTR('/')); + if (slash_index == std::string::npos) { + return PathString(); + } + } return path.substr(0, slash_index + 1); } -Status WindowsEnv::LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const { - const std::wstring& wlibrary_filename = ToWideString(library_filename); +Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const { #if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP *handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0); #else diff --git a/onnxruntime/core/platform/windows/env.h b/onnxruntime/core/platform/windows/env.h index 5dd79007a8..c54a552b6f 100644 --- a/onnxruntime/core/platform/windows/env.h +++ b/onnxruntime/core/platform/windows/env.h @@ -76,8 +76,8 @@ class WindowsEnv : public Env { common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override; common::Status FileClose(int fd) const override; common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; - std::string GetRuntimePath() const override; - Status LoadDynamicLibrary(const std::string& library_filename, bool /*global_symbols*/, void** handle) const override; + PathString GetRuntimePath() const override; + Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override; Status UnloadDynamicLibrary(void* handle) const override; Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override; std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9a3411bb15..7a9666be54 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -600,7 +600,8 @@ ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _Inout_ OrtSessionOptions* optio ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle) { API_IMPL_BEGIN - ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(library_path, false, library_handle)); + auto path_str = ToPathString(library_path); + ORT_API_RETURN_IF_STATUS_NOT_OK(Env::Default().LoadDynamicLibrary(path_str, false, library_handle)); if (!*library_handle) return OrtApis::CreateStatus(ORT_FAIL, "RegisterCustomOpsLibrary: Failed to load library"); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ae7cdf757f..db0777fe88 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -90,7 +90,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; // The filename extension for a shared library is different per platform #ifdef _WIN32 #define LIBRARY_PREFIX -#define LIBRARY_EXTENSION ".dll" +#define LIBRARY_EXTENSION ORT_TSTR(".dll") #elif defined(__APPLE__) #define LIBRARY_PREFIX "lib" #define LIBRARY_EXTENSION ".dylib" @@ -1049,7 +1049,8 @@ struct ProviderSharedLibrary { if (handle_) return; - std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION); + auto full_path = Env::Default().GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_shared") LIBRARY_EXTENSION); ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true /*shared_globals on unix*/, &handle_)); void (*PProvider_SetHost)(void*); @@ -1089,7 +1090,7 @@ bool InitProvidersSharedLibrary() try { } struct ProviderLibrary { - ProviderLibrary(const char* filename, bool unload = true) : filename_{filename}, unload_{unload} {} + ProviderLibrary(const ORTCHAR_T* filename, bool unload = true) : filename_{filename}, unload_{unload} {} ~ProviderLibrary() { // assert(!handle_); // We should already be unloaded at this point (disabled until Python shuts down deterministically) } @@ -1100,7 +1101,7 @@ struct ProviderLibrary { if (!provider_) { s_library_shared.Ensure(); - std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_); + auto full_path = Env::Default().GetRuntimePath() + filename_; ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, false, &handle_)); Provider* (*PGetProvider)(); @@ -1139,7 +1140,7 @@ struct ProviderLibrary { private: std::mutex mutex_; - const char* filename_; + const ORTCHAR_T* filename_; bool unload_; Provider* provider_{}; void* handle_{}; @@ -1147,28 +1148,28 @@ struct ProviderLibrary { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary); }; -static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION +static ProviderLibrary s_library_cuda(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cuda") LIBRARY_EXTENSION #ifndef _WIN32 , false /* unload - On Linux if we unload the cuda shared provider we crash */ #endif ); -static ProviderLibrary s_library_cann(LIBRARY_PREFIX "onnxruntime_providers_cann" LIBRARY_EXTENSION +static ProviderLibrary s_library_cann(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_cann") LIBRARY_EXTENSION #ifndef _WIN32 , false /* unload - On Linux if we unload the cann shared provider we crash */ #endif ); -static ProviderLibrary s_library_rocm(LIBRARY_PREFIX "onnxruntime_providers_rocm" LIBRARY_EXTENSION +static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_rocm") LIBRARY_EXTENSION #ifndef _WIN32 , false /* unload - On Linux if we unload the rocm shared provider we crash */ #endif ); -static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION); -static ProviderLibrary s_library_openvino(LIBRARY_PREFIX "onnxruntime_providers_openvino" LIBRARY_EXTENSION); -static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION); -static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX "onnxruntime_providers_migraphx" LIBRARY_EXTENSION); +static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); +static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION); +static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION); +static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_migraphx") LIBRARY_EXTENSION); void UnloadSharedProviders() { s_library_dnnl.Unload(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 27bbfde16a..995977c5f7 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -13,6 +13,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" #include "core/common/optional.h" +#include "core/common/path_string.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" @@ -77,7 +78,8 @@ static Env& platform_env = Env::Default(); CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& ort_so) { { - OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, false, &library_handle_)); + const auto path_str = ToPathString(library_path); + OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(path_str, false, &library_handle_)); OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api); @@ -311,7 +313,8 @@ static std::unique_ptr LoadExecutionProvider( const ProviderOptions& provider_options = {}, const std::string& entry_symbol_name = "GetProvider") { void* handle; - auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle); + const auto path_str = ToPathString(ep_shared_lib_path); + auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle); if (!error.IsOK()) { throw std::runtime_error(error.ErrorMessage()); } diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 8ba48cd14e..0049932826 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -6,6 +6,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/common/path_string.h" #include "core/providers/get_execution_providers.h" #include "core/session/provider_bridge_ort.h" @@ -49,7 +50,8 @@ bool GetDyanmicExecutionProviderHash( size_t& hash, const std::string& entry_symbol_name = "ProviderHashFunc") { void* handle; - auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle); + const auto path_str = ToPathString(ep_shared_lib_path); + auto error = Env::Default().LoadDynamicLibrary(path_str, false, &handle); if (!error.IsOK()) { throw std::runtime_error(error.ErrorMessage()); }