mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-20 21:40:57 +00:00
Return windows error code for library loading and unloading failure (#5036)
This commit is contained in:
parent
b4e9e98cee
commit
a9db287bd7
2 changed files with 17 additions and 16 deletions
|
|
@ -473,22 +473,29 @@ class WindowsEnv : public Env {
|
|||
|
||||
virtual Status LoadDynamicLibrary(const std::string& library_filename, void** handle) const override {
|
||||
*handle = ::LoadLibraryExA(library_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
|
||||
if (!*handle)
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to load library");
|
||||
return common::Status::OK();
|
||||
if (!*handle) {
|
||||
const auto error_code = GetLastError();
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load library, error code: ", error_code);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual common::Status UnloadDynamicLibrary(void* handle) const override {
|
||||
if (::FreeLibrary(reinterpret_cast<HMODULE>(handle)) == 0)
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to unload library");
|
||||
return common::Status::OK();
|
||||
virtual Status UnloadDynamicLibrary(void* handle) const override {
|
||||
if (::FreeLibrary(reinterpret_cast<HMODULE>(handle)) == 0) {
|
||||
const auto error_code = GetLastError();
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unload library, error code: ", error_code);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override {
|
||||
*symbol = ::GetProcAddress(reinterpret_cast<HMODULE>(handle), symbol_name.c_str());
|
||||
if (!*symbol)
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to find symbol in library");
|
||||
return common::Status::OK();
|
||||
if (!*symbol) {
|
||||
const auto error_code = GetLastError();
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to find symbol in library, error code: ",
|
||||
error_code);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override {
|
||||
|
|
|
|||
|
|
@ -202,16 +202,10 @@ CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& or
|
|||
{
|
||||
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, &library_handle_));
|
||||
|
||||
if (!library_handle_)
|
||||
throw std::runtime_error("RegisterCustomOpsLibrary: Failed to load library");
|
||||
|
||||
OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api);
|
||||
|
||||
OrtPybindThrowIfError(platform_env.GetSymbolFromLibrary(library_handle_, "RegisterCustomOps", (void**)&RegisterCustomOps));
|
||||
|
||||
if (!RegisterCustomOps)
|
||||
throw std::runtime_error("RegisterCustomOpsLibrary: Entry point RegisterCustomOps not found in library");
|
||||
|
||||
auto* status_raw = RegisterCustomOps(&ort_so, OrtGetApiBase());
|
||||
// Manage the raw Status pointer using a smart pointer
|
||||
auto status = std::unique_ptr<OrtStatus>(status_raw);
|
||||
|
|
|
|||
Loading…
Reference in a new issue