diff --git a/include/onnxruntime/core/framework/provider_shutdown.h b/include/onnxruntime/core/framework/provider_shutdown.h new file mode 100644 index 0000000000..7488f12a0f --- /dev/null +++ b/include/onnxruntime/core/framework/provider_shutdown.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +void UnloadSharedProviders(); +} diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index 9c1c4e2d46..cbc9aaa850 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -10,6 +10,7 @@ #include "core/framework/data_transfer_manager.h" #include "core/framework/execution_provider.h" #include "core/framework/kernel_registry.h" +#include "core/framework/provider_shutdown.h" #include "core/graph/model.h" #include "core/platform/env.h" #include "core/providers/common.h" @@ -328,7 +329,7 @@ struct ProviderHostImpl : ProviderHost { return onnxruntime::make_unique(logger, severity, category, dataType, location); } void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } - std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } // Provider_TypeProto_Tensor int32_t Provider_TypeProto_Tensor__elem_type(const Provider_TypeProto_Tensor* p) override { return p->elem_type(); } @@ -609,62 +610,97 @@ struct ProviderHostImpl : ProviderHost { } provider_host_; struct ProviderSharedLibrary { - ProviderSharedLibrary() { + bool Ensure() { + if (handle_) + return true; + std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION); auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_); if (!error.IsOK()) { LOGS_DEFAULT(ERROR) << error.ErrorMessage(); - return; + return false; } void (*PProvider_SetHost)(void*); Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost); PProvider_SetHost(&provider_host_); + return true; } - ~ProviderSharedLibrary() { - Env::Default().UnloadDynamicLibrary(handle_); + void Unload() { + if (handle_) { + Env::Default().UnloadDynamicLibrary(handle_); + handle_ = nullptr; + } } + ProviderSharedLibrary() = default; + ~ProviderSharedLibrary() { /*assert(!handle_);*/ + } // We should already be unloaded at this point (disabled until Python shuts down deterministically) + + private: void* handle_{}; ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderSharedLibrary); }; -bool EnsureSharedProviderLibrary() { - static ProviderSharedLibrary shared_library; - return shared_library.handle_; -} +static ProviderSharedLibrary s_library_shared; struct ProviderLibrary { - ProviderLibrary(const char* filename) { - if (!EnsureSharedProviderLibrary()) - return; + ProviderLibrary(const char* filename) : filename_{filename} {} + ~ProviderLibrary() { /*assert(!handle_);*/ + } // We should already be unloaded at this point (disabled until Python shuts down deterministically) - std::string full_path = Env::Default().GetRuntimePath() + std::string(filename); + Provider* Get() { + if (provider_) + return provider_; + + if (!s_library_shared.Ensure()) + return nullptr; + + std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_); auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_); if (!error.IsOK()) { LOGS_DEFAULT(ERROR) << error.ErrorMessage(); - return; + return nullptr; } Provider* (*PGetProvider)(); Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider); provider_ = PGetProvider(); + return provider_; } - ~ProviderLibrary() { - Env::Default().UnloadDynamicLibrary(handle_); + void Unload() { + if (handle_) { + if (provider_) + provider_->Shutdown(); + + Env::Default().UnloadDynamicLibrary(handle_); + handle_ = nullptr; + provider_ = nullptr; + } } + private: + const char* filename_; Provider* provider_{}; void* handle_{}; ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary); }; +static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION); +static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION); + +void UnloadSharedProviders() { + s_library_dnnl.Unload(); + s_library_tensorrt.Unload(); + s_library_shared.Unload(); +} + // This class translates the IExecutionProviderFactory interface to work with the interface providers implement struct IExecutionProviderFactory_Translator : IExecutionProviderFactory { IExecutionProviderFactory_Translator(std::shared_ptr p) : p_{p} {} @@ -677,22 +713,18 @@ struct IExecutionProviderFactory_Translator : IExecutionProviderFactory { std::shared_ptr p_; }; -std::shared_ptr CreateExecutionProviderFactory_Dnnl(int device_id) { - static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION); - if (!library.provider_) - return nullptr; +std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena) { + if (auto provider = s_library_dnnl.Get()) + return std::make_shared(provider->CreateExecutionProviderFactory(use_arena)); - //return std::make_shared(device_id); - //TODO: This is apparently a bug. The constructor parameter is create-arena-flag, not the device-id - return std::make_shared(library.provider_->CreateExecutionProviderFactory(device_id)); + return nullptr; } std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id) { - static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION); - if (!library.provider_) - return nullptr; + if (auto provider = s_library_tensorrt.Get()) + return std::make_shared(provider->CreateExecutionProviderFactory(device_id)); - return std::make_shared(library.provider_->CreateExecutionProviderFactory(device_id)); + return nullptr; } } // namespace onnxruntime @@ -700,7 +732,6 @@ std::shared_ptr CreateExecutionProviderFactory_Tensor ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena); if (!factory) { - LOGS_DEFAULT(ERROR) << "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library"; return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library"); } diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index b03c05ad65..67b24f5f20 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -11,16 +11,6 @@ #include "dnnl_execution_provider.h" #include "dnnl_fwd.h" -namespace { - -struct KernelRegistryAndStatus { - std::shared_ptr kernel_registry{onnxruntime::Provider_KernelRegistry::Create()}; - - Status st; -}; - -} // namespace - namespace onnxruntime { constexpr const char* DNNL = "Dnnl"; @@ -62,18 +52,24 @@ Status RegisterDNNLKernels(Provider_KernelRegistry& kernel_registry) { return Status::OK(); } -KernelRegistryAndStatus GetDnnlKernelRegistry() { - KernelRegistryAndStatus ret; - ret.st = RegisterDNNLKernels(*ret.kernel_registry); - return ret; -} } // namespace ort_dnnl +static std::shared_ptr s_kernel_registry; + +void Shutdown_DeleteRegistry() { + s_kernel_registry.reset(); +} + std::shared_ptr DNNLExecutionProvider::Provider_GetKernelRegistry() const { - static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry(); - // throw if the registry failed to initialize - ORT_THROW_IF_ERROR(k.st); - return k.kernel_registry; + if (!s_kernel_registry) { + s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create(); + auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry); + if (!status.IsOK()) + s_kernel_registry.reset(); + ORT_THROW_IF_ERROR(status); + } + + return s_kernel_registry; } bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::Provider_GraphViewer& graph_viewer) const { diff --git a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc index 9471fcdb6b..63425b720a 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc @@ -11,6 +11,8 @@ using namespace onnxruntime; namespace onnxruntime { +void Shutdown_DeleteRegistry(); + struct DnnlProviderFactory : Provider_IExecutionProviderFactory { DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {} ~DnnlProviderFactory() override {} @@ -47,9 +49,10 @@ struct Dnnl_Provider : Provider { return std::make_shared(use_arena != 0); } - void SetProviderHost(ProviderHost& host) { - onnxruntime::SetProviderHost(host); + void Shutdown() override { + Shutdown_DeleteRegistry(); } + } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 14e2305a56..fb4d0a8d91 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -84,7 +84,7 @@ struct Provider_NodeAttributes; struct Provider_OpKernelContext; struct Provider_OpKernelInfo; struct Provider_Tensor; -} +} // namespace onnxruntime #include "provider_interfaces.h" @@ -127,8 +127,6 @@ enum OperatorStatus : int { namespace onnxruntime { -void SetProviderHost(ProviderHost& host); - // The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own // Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example. void RunOnUnload(std::function function); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index fa094d1d30..8ff9cebe73 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -221,6 +221,7 @@ struct Provider_IExecutionProvider { struct Provider { virtual std::shared_ptr CreateExecutionProviderFactory(int device_id) = 0; + virtual void Shutdown() = 0; }; // There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function) @@ -543,35 +544,35 @@ struct CPUIDInfo { bool HasAVX2() const { return g_host->CPUIDInfo__HasAVX2(this); } bool HasAVX512f() const { return g_host->CPUIDInfo__HasAVX512f(this); } -PROVIDER_DISALLOW_ALL(CPUIDInfo) + PROVIDER_DISALLOW_ALL(CPUIDInfo) }; namespace logging { struct Logger { - bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); } + bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); } -PROVIDER_DISALLOW_ALL(Logger) + PROVIDER_DISALLOW_ALL(Logger) }; struct LoggingManager { - static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); } + static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); } PROVIDER_DISALLOW_ALL(LoggingManager) }; struct Capture { - static std::unique_ptr Create(const Logger& logger, logging::Severity severity, const char* category, - logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); } - static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast(p)); } + static std::unique_ptr Create(const Logger& logger, logging::Severity severity, const char* category, + logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); } + static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast(p)); } - std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); } + std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); } - Capture() = delete; - Capture(const Capture&) = delete; - void operator=(const Capture&) = delete; + Capture() = delete; + Capture(const Capture&) = delete; + void operator=(const Capture&) = delete; }; -} +} // namespace logging struct Provider_TypeProto_Tensor { int32_t elem_type() const { return g_host->Provider_TypeProto_Tensor__elem_type(this); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 964b37a72c..1dc0d5f863 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -31,11 +31,6 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; namespace fs = std::experimental::filesystem; namespace { -struct KernelRegistryAndStatus { - std::shared_ptr kernel_registry{onnxruntime::Provider_KernelRegistry::Create()}; - Status st; -}; - std::string GetEnginePath(const ::std::string& root, const std::string& name) { if (root.empty()) { return name + ".engine"; @@ -151,17 +146,22 @@ static Status RegisterTensorrtKernels(Provider_KernelRegistry& kernel_registry) return Status::OK(); } -KernelRegistryAndStatus GetTensorrtKernelRegistry() { - KernelRegistryAndStatus ret; - ret.st = RegisterTensorrtKernels(*ret.kernel_registry); - return ret; +static std::shared_ptr s_kernel_registry; + +void Shutdown_DeleteRegistry() { + s_kernel_registry.reset(); } std::shared_ptr TensorrtExecutionProvider::Provider_GetKernelRegistry() const { - static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry(); - // throw if the registry failed to initialize - ORT_THROW_IF_ERROR(k.st); - return k.kernel_registry; + if (!s_kernel_registry) { + s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create(); + auto status = RegisterTensorrtKernels(*s_kernel_registry); + if (!status.IsOK()) + s_kernel_registry.reset(); + ORT_THROW_IF_ERROR(status); + } + + return s_kernel_registry; } // Per TensorRT documentation, logger needs to be a singleton. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 77bd24204f..8f5f80a672 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -10,6 +10,8 @@ using namespace onnxruntime; namespace onnxruntime { +void Shutdown_DeleteRegistry(); + struct TensorrtProviderFactory : Provider_IExecutionProviderFactory { TensorrtProviderFactory(int device_id) : device_id_(device_id) {} ~TensorrtProviderFactory() override {} @@ -37,9 +39,10 @@ struct Tensorrt_Provider : Provider { return std::make_shared(device_id); } - void SetProviderHost(ProviderHost& host) { - onnxruntime::SetProviderHost(host); + void Shutdown() override { + Shutdown_DeleteRegistry(); } + } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index f7f460eec9..6577aa8ae1 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -10,6 +10,7 @@ #include "core/session/environment.h" #include "core/session/allocator_impl.h" #include "core/common/logging/logging.h" +#include "core/framework/provider_shutdown.h" #ifdef __ANDROID__ #include "core/platform/android/logging/android_log_sink.h" #else @@ -38,6 +39,13 @@ OrtEnv::OrtEnv(std::unique_ptr value1) : value_(std::move(value1)) { } +OrtEnv::~OrtEnv() { +// We don't support any shared providers in the minimal build yet +#if !defined(ORT_MINIMAL_BUILD) + UnloadSharedProviders(); +#endif +} + OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status, const OrtThreadingOptions* tp_options) { @@ -106,4 +114,4 @@ void OrtEnv::SetLoggingManager(std::unique_ptrRegisterAllocator(allocator); return status; -} \ No newline at end of file +} diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 041b072a93..800214c480 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -70,7 +70,7 @@ struct OrtEnv { std::unique_ptr value_; OrtEnv(std::unique_ptr value1); - ~OrtEnv() = default; + ~OrtEnv(); ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); -}; \ No newline at end of file +};