Fix shared provider unload crash (#5553)

This commit is contained in:
Ryan Hill 2020-10-21 13:01:21 -07:00 committed by GitHub
parent 4291c57322
commit 82c7a9756e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 130 additions and 82 deletions

View file

@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
void UnloadSharedProviders();
}

View file

@ -10,6 +10,7 @@
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/provider_shutdown.h"
#include "core/graph/model.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
@ -328,7 +329,7 @@ struct ProviderHostImpl : ProviderHost {
return onnxruntime::make_unique<logging::Capture>(logger, severity, category, dataType, location);
}
void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; }
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }
// Provider_TypeProto_Tensor
int32_t Provider_TypeProto_Tensor__elem_type(const Provider_TypeProto_Tensor* p) override { return p->elem_type(); }
@ -609,62 +610,97 @@ struct ProviderHostImpl : ProviderHost {
} provider_host_;
struct ProviderSharedLibrary {
ProviderSharedLibrary() {
bool Ensure() {
if (handle_)
return true;
std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return false;
}
void (*PProvider_SetHost)(void*);
Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost);
PProvider_SetHost(&provider_host_);
return true;
}
~ProviderSharedLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
void Unload() {
if (handle_) {
Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
}
}
ProviderSharedLibrary() = default;
~ProviderSharedLibrary() { /*assert(!handle_);*/
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)
private:
void* handle_{};
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderSharedLibrary);
};
bool EnsureSharedProviderLibrary() {
static ProviderSharedLibrary shared_library;
return shared_library.handle_;
}
static ProviderSharedLibrary s_library_shared;
struct ProviderLibrary {
ProviderLibrary(const char* filename) {
if (!EnsureSharedProviderLibrary())
return;
ProviderLibrary(const char* filename) : filename_{filename} {}
~ProviderLibrary() { /*assert(!handle_);*/
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename);
Provider* Get() {
if (provider_)
return provider_;
if (!s_library_shared.Ensure())
return nullptr;
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return nullptr;
}
Provider* (*PGetProvider)();
Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider);
provider_ = PGetProvider();
return provider_;
}
~ProviderLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
void Unload() {
if (handle_) {
if (provider_)
provider_->Shutdown();
Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
provider_ = nullptr;
}
}
private:
const char* filename_;
Provider* provider_{};
void* handle_{};
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
};
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
void UnloadSharedProviders() {
s_library_dnnl.Unload();
s_library_tensorrt.Unload();
s_library_shared.Unload();
}
// This class translates the IExecutionProviderFactory interface to work with the interface providers implement
struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
IExecutionProviderFactory_Translator(std::shared_ptr<Provider_IExecutionProviderFactory> p) : p_{p} {}
@ -677,22 +713,18 @@ struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
std::shared_ptr<Provider_IExecutionProviderFactory> p_;
};
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena) {
if (auto provider = s_library_dnnl.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(use_arena));
//return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
//TODO: This is apparently a bug. The constructor parameter is create-arena-flag, not the device-id
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
if (auto provider = s_library_tensorrt.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(device_id));
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}
} // namespace onnxruntime
@ -700,7 +732,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensor
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena);
if (!factory) {
LOGS_DEFAULT(ERROR) << "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library";
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library");
}

View file

@ -11,16 +11,6 @@
#include "dnnl_execution_provider.h"
#include "dnnl_fwd.h"
namespace {
struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
Status st;
};
} // namespace
namespace onnxruntime {
constexpr const char* DNNL = "Dnnl";
@ -62,18 +52,24 @@ Status RegisterDNNLKernels(Provider_KernelRegistry& kernel_registry) {
return Status::OK();
}
KernelRegistryAndStatus GetDnnlKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterDNNLKernels(*ret.kernel_registry);
return ret;
}
} // namespace ort_dnnl
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;
void Shutdown_DeleteRegistry() {
s_kernel_registry.reset();
}
std::shared_ptr<Provider_KernelRegistry> DNNLExecutionProvider::Provider_GetKernelRegistry() const {
static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}
return s_kernel_registry;
}
bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::Provider_GraphViewer& graph_viewer) const {

View file

@ -11,6 +11,8 @@ using namespace onnxruntime;
namespace onnxruntime {
void Shutdown_DeleteRegistry();
struct DnnlProviderFactory : Provider_IExecutionProviderFactory {
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {}
~DnnlProviderFactory() override {}
@ -47,9 +49,10 @@ struct Dnnl_Provider : Provider {
return std::make_shared<DnnlProviderFactory>(use_arena != 0);
}
void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}
} g_provider;
} // namespace onnxruntime

View file

@ -84,7 +84,7 @@ struct Provider_NodeAttributes;
struct Provider_OpKernelContext;
struct Provider_OpKernelInfo;
struct Provider_Tensor;
}
} // namespace onnxruntime
#include "provider_interfaces.h"
@ -127,8 +127,6 @@ enum OperatorStatus : int {
namespace onnxruntime {
void SetProviderHost(ProviderHost& host);
// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own
// Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example.
void RunOnUnload(std::function<void()> function);

View file

@ -221,6 +221,7 @@ struct Provider_IExecutionProvider {
struct Provider {
virtual std::shared_ptr<Provider_IExecutionProviderFactory> CreateExecutionProviderFactory(int device_id) = 0;
virtual void Shutdown() = 0;
};
// There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function)
@ -543,35 +544,35 @@ struct CPUIDInfo {
bool HasAVX2() const { return g_host->CPUIDInfo__HasAVX2(this); }
bool HasAVX512f() const { return g_host->CPUIDInfo__HasAVX512f(this); }
PROVIDER_DISALLOW_ALL(CPUIDInfo)
PROVIDER_DISALLOW_ALL(CPUIDInfo)
};
namespace logging {
struct Logger {
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }
PROVIDER_DISALLOW_ALL(Logger)
PROVIDER_DISALLOW_ALL(Logger)
};
struct LoggingManager {
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }
PROVIDER_DISALLOW_ALL(LoggingManager)
};
struct Capture {
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }
std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }
std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }
Capture() = delete;
Capture(const Capture&) = delete;
void operator=(const Capture&) = delete;
Capture() = delete;
Capture(const Capture&) = delete;
void operator=(const Capture&) = delete;
};
}
} // namespace logging
struct Provider_TypeProto_Tensor {
int32_t elem_type() const { return g_host->Provider_TypeProto_Tensor__elem_type(this); }

View file

@ -31,11 +31,6 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace fs = std::experimental::filesystem;
namespace {
struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
Status st;
};
std::string GetEnginePath(const ::std::string& root, const std::string& name) {
if (root.empty()) {
return name + ".engine";
@ -151,17 +146,22 @@ static Status RegisterTensorrtKernels(Provider_KernelRegistry& kernel_registry)
return Status::OK();
}
KernelRegistryAndStatus GetTensorrtKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterTensorrtKernels(*ret.kernel_registry);
return ret;
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;
void Shutdown_DeleteRegistry() {
s_kernel_registry.reset();
}
std::shared_ptr<Provider_KernelRegistry> TensorrtExecutionProvider::Provider_GetKernelRegistry() const {
static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = RegisterTensorrtKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}
return s_kernel_registry;
}
// Per TensorRT documentation, logger needs to be a singleton.

View file

@ -10,6 +10,8 @@ using namespace onnxruntime;
namespace onnxruntime {
void Shutdown_DeleteRegistry();
struct TensorrtProviderFactory : Provider_IExecutionProviderFactory {
TensorrtProviderFactory(int device_id) : device_id_(device_id) {}
~TensorrtProviderFactory() override {}
@ -37,9 +39,10 @@ struct Tensorrt_Provider : Provider {
return std::make_shared<TensorrtProviderFactory>(device_id);
}
void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}
} g_provider;
} // namespace onnxruntime

View file

@ -10,6 +10,7 @@
#include "core/session/environment.h"
#include "core/session/allocator_impl.h"
#include "core/common/logging/logging.h"
#include "core/framework/provider_shutdown.h"
#ifdef __ANDROID__
#include "core/platform/android/logging/android_log_sink.h"
#else
@ -38,6 +39,13 @@ OrtEnv::OrtEnv(std::unique_ptr<onnxruntime::Environment> value1)
: value_(std::move(value1)) {
}
OrtEnv::~OrtEnv() {
// We don't support any shared providers in the minimal build yet
#if !defined(ORT_MINIMAL_BUILD)
UnloadSharedProviders();
#endif
}
OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,
onnxruntime::common::Status& status,
const OrtThreadingOptions* tp_options) {
@ -106,4 +114,4 @@ void OrtEnv::SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingMana
onnxruntime::Status OrtEnv::RegisterAllocator(AllocatorPtr allocator) {
auto status = value_->RegisterAllocator(allocator);
return status;
}
}

View file

@ -70,7 +70,7 @@ struct OrtEnv {
std::unique_ptr<onnxruntime::Environment> value_;
OrtEnv(std::unique_ptr<onnxruntime::Environment> value1);
~OrtEnv() = default;
~OrtEnv();
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv);
};
};