mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Fix shared provider unload crash (#5553)
This commit is contained in:
parent
4291c57322
commit
82c7a9756e
10 changed files with 130 additions and 82 deletions
8
include/onnxruntime/core/framework/provider_shutdown.h
Normal file
8
include/onnxruntime/core/framework/provider_shutdown.h
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
void UnloadSharedProviders();
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@
|
|||
#include "core/framework/data_transfer_manager.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
#include "core/framework/provider_shutdown.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/common.h"
|
||||
|
|
@ -328,7 +329,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
return onnxruntime::make_unique<logging::Capture>(logger, severity, category, dataType, location);
|
||||
}
|
||||
void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; }
|
||||
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }
|
||||
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }
|
||||
|
||||
// Provider_TypeProto_Tensor
|
||||
int32_t Provider_TypeProto_Tensor__elem_type(const Provider_TypeProto_Tensor* p) override { return p->elem_type(); }
|
||||
|
|
@ -609,62 +610,97 @@ struct ProviderHostImpl : ProviderHost {
|
|||
} provider_host_;
|
||||
|
||||
struct ProviderSharedLibrary {
|
||||
ProviderSharedLibrary() {
|
||||
bool Ensure() {
|
||||
if (handle_)
|
||||
return true;
|
||||
|
||||
std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
|
||||
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
|
||||
if (!error.IsOK()) {
|
||||
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
void (*PProvider_SetHost)(void*);
|
||||
Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost);
|
||||
|
||||
PProvider_SetHost(&provider_host_);
|
||||
return true;
|
||||
}
|
||||
|
||||
~ProviderSharedLibrary() {
|
||||
Env::Default().UnloadDynamicLibrary(handle_);
|
||||
void Unload() {
|
||||
if (handle_) {
|
||||
Env::Default().UnloadDynamicLibrary(handle_);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ProviderSharedLibrary() = default;
|
||||
~ProviderSharedLibrary() { /*assert(!handle_);*/
|
||||
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)
|
||||
|
||||
private:
|
||||
void* handle_{};
|
||||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderSharedLibrary);
|
||||
};
|
||||
|
||||
bool EnsureSharedProviderLibrary() {
|
||||
static ProviderSharedLibrary shared_library;
|
||||
return shared_library.handle_;
|
||||
}
|
||||
static ProviderSharedLibrary s_library_shared;
|
||||
|
||||
struct ProviderLibrary {
|
||||
ProviderLibrary(const char* filename) {
|
||||
if (!EnsureSharedProviderLibrary())
|
||||
return;
|
||||
ProviderLibrary(const char* filename) : filename_{filename} {}
|
||||
~ProviderLibrary() { /*assert(!handle_);*/
|
||||
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)
|
||||
|
||||
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename);
|
||||
Provider* Get() {
|
||||
if (provider_)
|
||||
return provider_;
|
||||
|
||||
if (!s_library_shared.Ensure())
|
||||
return nullptr;
|
||||
|
||||
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
|
||||
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
|
||||
if (!error.IsOK()) {
|
||||
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
|
||||
return;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Provider* (*PGetProvider)();
|
||||
Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider);
|
||||
|
||||
provider_ = PGetProvider();
|
||||
return provider_;
|
||||
}
|
||||
|
||||
~ProviderLibrary() {
|
||||
Env::Default().UnloadDynamicLibrary(handle_);
|
||||
void Unload() {
|
||||
if (handle_) {
|
||||
if (provider_)
|
||||
provider_->Shutdown();
|
||||
|
||||
Env::Default().UnloadDynamicLibrary(handle_);
|
||||
handle_ = nullptr;
|
||||
provider_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const char* filename_;
|
||||
Provider* provider_{};
|
||||
void* handle_{};
|
||||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
|
||||
};
|
||||
|
||||
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
|
||||
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
|
||||
|
||||
void UnloadSharedProviders() {
|
||||
s_library_dnnl.Unload();
|
||||
s_library_tensorrt.Unload();
|
||||
s_library_shared.Unload();
|
||||
}
|
||||
|
||||
// This class translates the IExecutionProviderFactory interface to work with the interface providers implement
|
||||
struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
|
||||
IExecutionProviderFactory_Translator(std::shared_ptr<Provider_IExecutionProviderFactory> p) : p_{p} {}
|
||||
|
|
@ -677,22 +713,18 @@ struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
|
|||
std::shared_ptr<Provider_IExecutionProviderFactory> p_;
|
||||
};
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int device_id) {
|
||||
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
|
||||
if (!library.provider_)
|
||||
return nullptr;
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena) {
|
||||
if (auto provider = s_library_dnnl.Get())
|
||||
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(use_arena));
|
||||
|
||||
//return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
|
||||
//TODO: This is apparently a bug. The constructor parameter is create-arena-flag, not the device-id
|
||||
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id) {
|
||||
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
|
||||
if (!library.provider_)
|
||||
return nullptr;
|
||||
if (auto provider = s_library_tensorrt.Get())
|
||||
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(device_id));
|
||||
|
||||
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -700,7 +732,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensor
|
|||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
|
||||
auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena);
|
||||
if (!factory) {
|
||||
LOGS_DEFAULT(ERROR) << "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library";
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,16 +11,6 @@
|
|||
#include "dnnl_execution_provider.h"
|
||||
#include "dnnl_fwd.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct KernelRegistryAndStatus {
|
||||
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
|
||||
|
||||
Status st;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* DNNL = "Dnnl";
|
||||
|
|
@ -62,18 +52,24 @@ Status RegisterDNNLKernels(Provider_KernelRegistry& kernel_registry) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
KernelRegistryAndStatus GetDnnlKernelRegistry() {
|
||||
KernelRegistryAndStatus ret;
|
||||
ret.st = RegisterDNNLKernels(*ret.kernel_registry);
|
||||
return ret;
|
||||
}
|
||||
} // namespace ort_dnnl
|
||||
|
||||
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;
|
||||
|
||||
void Shutdown_DeleteRegistry() {
|
||||
s_kernel_registry.reset();
|
||||
}
|
||||
|
||||
std::shared_ptr<Provider_KernelRegistry> DNNLExecutionProvider::Provider_GetKernelRegistry() const {
|
||||
static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry();
|
||||
// throw if the registry failed to initialize
|
||||
ORT_THROW_IF_ERROR(k.st);
|
||||
return k.kernel_registry;
|
||||
if (!s_kernel_registry) {
|
||||
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
|
||||
auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry);
|
||||
if (!status.IsOK())
|
||||
s_kernel_registry.reset();
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
}
|
||||
|
||||
return s_kernel_registry;
|
||||
}
|
||||
|
||||
bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::Provider_GraphViewer& graph_viewer) const {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ using namespace onnxruntime;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void Shutdown_DeleteRegistry();
|
||||
|
||||
struct DnnlProviderFactory : Provider_IExecutionProviderFactory {
|
||||
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {}
|
||||
~DnnlProviderFactory() override {}
|
||||
|
|
@ -47,9 +49,10 @@ struct Dnnl_Provider : Provider {
|
|||
return std::make_shared<DnnlProviderFactory>(use_arena != 0);
|
||||
}
|
||||
|
||||
void SetProviderHost(ProviderHost& host) {
|
||||
onnxruntime::SetProviderHost(host);
|
||||
void Shutdown() override {
|
||||
Shutdown_DeleteRegistry();
|
||||
}
|
||||
|
||||
} g_provider;
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ struct Provider_NodeAttributes;
|
|||
struct Provider_OpKernelContext;
|
||||
struct Provider_OpKernelInfo;
|
||||
struct Provider_Tensor;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
#include "provider_interfaces.h"
|
||||
|
||||
|
|
@ -127,8 +127,6 @@ enum OperatorStatus : int {
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void SetProviderHost(ProviderHost& host);
|
||||
|
||||
// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own
|
||||
// Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example.
|
||||
void RunOnUnload(std::function<void()> function);
|
||||
|
|
|
|||
|
|
@ -221,6 +221,7 @@ struct Provider_IExecutionProvider {
|
|||
|
||||
struct Provider {
|
||||
virtual std::shared_ptr<Provider_IExecutionProviderFactory> CreateExecutionProviderFactory(int device_id) = 0;
|
||||
virtual void Shutdown() = 0;
|
||||
};
|
||||
|
||||
// There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function)
|
||||
|
|
@ -543,35 +544,35 @@ struct CPUIDInfo {
|
|||
bool HasAVX2() const { return g_host->CPUIDInfo__HasAVX2(this); }
|
||||
bool HasAVX512f() const { return g_host->CPUIDInfo__HasAVX512f(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(CPUIDInfo)
|
||||
PROVIDER_DISALLOW_ALL(CPUIDInfo)
|
||||
};
|
||||
|
||||
namespace logging {
|
||||
|
||||
struct Logger {
|
||||
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }
|
||||
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(Logger)
|
||||
PROVIDER_DISALLOW_ALL(Logger)
|
||||
};
|
||||
|
||||
struct LoggingManager {
|
||||
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }
|
||||
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(LoggingManager)
|
||||
};
|
||||
|
||||
struct Capture {
|
||||
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
|
||||
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
|
||||
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }
|
||||
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
|
||||
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
|
||||
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }
|
||||
|
||||
std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }
|
||||
std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }
|
||||
|
||||
Capture() = delete;
|
||||
Capture(const Capture&) = delete;
|
||||
void operator=(const Capture&) = delete;
|
||||
Capture() = delete;
|
||||
Capture(const Capture&) = delete;
|
||||
void operator=(const Capture&) = delete;
|
||||
};
|
||||
}
|
||||
} // namespace logging
|
||||
|
||||
struct Provider_TypeProto_Tensor {
|
||||
int32_t elem_type() const { return g_host->Provider_TypeProto_Tensor__elem_type(this); }
|
||||
|
|
|
|||
|
|
@ -31,11 +31,6 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::logging;
|
||||
namespace fs = std::experimental::filesystem;
|
||||
namespace {
|
||||
struct KernelRegistryAndStatus {
|
||||
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
|
||||
Status st;
|
||||
};
|
||||
|
||||
std::string GetEnginePath(const ::std::string& root, const std::string& name) {
|
||||
if (root.empty()) {
|
||||
return name + ".engine";
|
||||
|
|
@ -151,17 +146,22 @@ static Status RegisterTensorrtKernels(Provider_KernelRegistry& kernel_registry)
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
KernelRegistryAndStatus GetTensorrtKernelRegistry() {
|
||||
KernelRegistryAndStatus ret;
|
||||
ret.st = RegisterTensorrtKernels(*ret.kernel_registry);
|
||||
return ret;
|
||||
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;
|
||||
|
||||
void Shutdown_DeleteRegistry() {
|
||||
s_kernel_registry.reset();
|
||||
}
|
||||
|
||||
std::shared_ptr<Provider_KernelRegistry> TensorrtExecutionProvider::Provider_GetKernelRegistry() const {
|
||||
static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry();
|
||||
// throw if the registry failed to initialize
|
||||
ORT_THROW_IF_ERROR(k.st);
|
||||
return k.kernel_registry;
|
||||
if (!s_kernel_registry) {
|
||||
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
|
||||
auto status = RegisterTensorrtKernels(*s_kernel_registry);
|
||||
if (!status.IsOK())
|
||||
s_kernel_registry.reset();
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
}
|
||||
|
||||
return s_kernel_registry;
|
||||
}
|
||||
|
||||
// Per TensorRT documentation, logger needs to be a singleton.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ using namespace onnxruntime;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
void Shutdown_DeleteRegistry();
|
||||
|
||||
struct TensorrtProviderFactory : Provider_IExecutionProviderFactory {
|
||||
TensorrtProviderFactory(int device_id) : device_id_(device_id) {}
|
||||
~TensorrtProviderFactory() override {}
|
||||
|
|
@ -37,9 +39,10 @@ struct Tensorrt_Provider : Provider {
|
|||
return std::make_shared<TensorrtProviderFactory>(device_id);
|
||||
}
|
||||
|
||||
void SetProviderHost(ProviderHost& host) {
|
||||
onnxruntime::SetProviderHost(host);
|
||||
void Shutdown() override {
|
||||
Shutdown_DeleteRegistry();
|
||||
}
|
||||
|
||||
} g_provider;
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "core/session/environment.h"
|
||||
#include "core/session/allocator_impl.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/provider_shutdown.h"
|
||||
#ifdef __ANDROID__
|
||||
#include "core/platform/android/logging/android_log_sink.h"
|
||||
#else
|
||||
|
|
@ -38,6 +39,13 @@ OrtEnv::OrtEnv(std::unique_ptr<onnxruntime::Environment> value1)
|
|||
: value_(std::move(value1)) {
|
||||
}
|
||||
|
||||
OrtEnv::~OrtEnv() {
|
||||
// We don't support any shared providers in the minimal build yet
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
UnloadSharedProviders();
|
||||
#endif
|
||||
}
|
||||
|
||||
OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,
|
||||
onnxruntime::common::Status& status,
|
||||
const OrtThreadingOptions* tp_options) {
|
||||
|
|
@ -106,4 +114,4 @@ void OrtEnv::SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingMana
|
|||
onnxruntime::Status OrtEnv::RegisterAllocator(AllocatorPtr allocator) {
|
||||
auto status = value_->RegisterAllocator(allocator);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ struct OrtEnv {
|
|||
std::unique_ptr<onnxruntime::Environment> value_;
|
||||
|
||||
OrtEnv(std::unique_ptr<onnxruntime::Environment> value1);
|
||||
~OrtEnv() = default;
|
||||
~OrtEnv();
|
||||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv);
|
||||
};
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue