Return windows error code for library loading and unloading failure (#5036)

This commit is contained in:
Hariharan Seshadri 2020-09-02 18:07:36 -07:00 committed by GitHub
parent b4e9e98cee
commit a9db287bd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 16 deletions

View file

@ -473,22 +473,29 @@ class WindowsEnv : public Env {
virtual Status LoadDynamicLibrary(const std::string& library_filename, void** handle) const override {
*handle = ::LoadLibraryExA(library_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
if (!*handle)
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to load library");
return common::Status::OK();
if (!*handle) {
const auto error_code = GetLastError();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load library, error code: ", error_code);
}
return Status::OK();
}
virtual common::Status UnloadDynamicLibrary(void* handle) const override {
if (::FreeLibrary(reinterpret_cast<HMODULE>(handle)) == 0)
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to unload library");
return common::Status::OK();
virtual Status UnloadDynamicLibrary(void* handle) const override {
if (::FreeLibrary(reinterpret_cast<HMODULE>(handle)) == 0) {
const auto error_code = GetLastError();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unload library, error code: ", error_code);
}
return Status::OK();
}
virtual Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override {
*symbol = ::GetProcAddress(reinterpret_cast<HMODULE>(handle), symbol_name.c_str());
if (!*symbol)
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to find symbol in library");
return common::Status::OK();
if (!*symbol) {
const auto error_code = GetLastError();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to find symbol in library, error code: ",
error_code);
}
return Status::OK();
}
virtual std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override {

View file

@ -202,16 +202,10 @@ CustomOpLibrary::CustomOpLibrary(const char* library_path, OrtSessionOptions& or
{
OrtPybindThrowIfError(platform_env.LoadDynamicLibrary(library_path, &library_handle_));
if (!library_handle_)
throw std::runtime_error("RegisterCustomOpsLibrary: Failed to load library");
OrtStatus*(ORT_API_CALL * RegisterCustomOps)(OrtSessionOptions * options, const OrtApiBase* api);
OrtPybindThrowIfError(platform_env.GetSymbolFromLibrary(library_handle_, "RegisterCustomOps", (void**)&RegisterCustomOps));
if (!RegisterCustomOps)
throw std::runtime_error("RegisterCustomOpsLibrary: Entry point RegisterCustomOps not found in library");
auto* status_raw = RegisterCustomOps(&ort_so, OrtGetApiBase());
// Manage the raw Status pointer using a smart pointer
auto status = std::unique_ptr<OrtStatus>(status_raw);