Remove extra debug messages

Try a more clean python shutdown through DllMain
This commit is contained in:
Ryan Hill 2021-05-04 20:10:11 -07:00
parent 076846190e
commit b2c202d930
2 changed files with 26 additions and 17 deletions

View file

@ -873,7 +873,6 @@ struct ProviderSharedLibrary {
Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost);
PProvider_SetHost(&provider_host_);
LOGS_DEFAULT(WARNING) << "(RyanHill) Initialized Provider Shared Library";
return true;
}
@ -914,7 +913,6 @@ struct ProviderLibrary {
return nullptr;
std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
LOGS_DEFAULT(WARNING) << "(RyanHill) Loading provider: " << full_path;
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
@ -925,13 +923,11 @@ struct ProviderLibrary {
Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider);
provider_ = PGetProvider();
LOGS_DEFAULT(WARNING) << "(RyanHill) Provider Loaded " << full_path;
return provider_;
}
void Unload() {
if (handle_) {
LOGS_DEFAULT(WARNING) << "(RyanHill) Shutting down provider " << filename_;
if (provider_)
provider_->Shutdown();
@ -940,7 +936,6 @@ struct ProviderLibrary {
#endif
handle_ = nullptr;
provider_ = nullptr;
LOGS_DEFAULT(WARNING) << "(RyanHill) Shut down successful for " << filename_;
}
}
@ -958,15 +953,11 @@ static ProviderLibrary s_library_openvino(LIBRARY_PREFIX "onnxruntime_providers_
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
void UnloadSharedProviders() {
LOGS_DEFAULT(WARNING) << "Unloading shared providers... (RyanHill)";
s_library_dnnl.Unload();
s_library_openvino.Unload();
s_library_tensorrt.Unload();
s_library_cuda.Unload();
s_library_shared.Unload();
LOGS_DEFAULT(WARNING) << "Finished Unloading shared providers (RyanHill)";
}
// Used by test code
@ -981,7 +972,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Cuda(c
if (auto provider = s_library_cuda.Get())
return provider->CreateExecutionProviderFactory(provider_options);
LOGS_DEFAULT(WARNING) << "FAILED TO LOAD CUDA PROVIDER (RyanHill)";
return nullptr;
}

View file

@ -484,16 +484,17 @@ static bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int
return true;
}
static std::unordered_map<OrtDevice::DeviceId, AllocatorPtr> cuda_id_to_allocator_map;
static AllocatorPtr GetCudaAllocator(OrtDevice::DeviceId id) {
// Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make
// multi-threaded CUDA allocation work we need to maintain a per-thread CUDA allocator
static std::unordered_map<OrtDevice::DeviceId, AllocatorPtr>* id_to_allocator_map = new std::unordered_map<OrtDevice::DeviceId, AllocatorPtr>();
if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) {
id_to_allocator_map->insert({id, GetProviderInfo_CUDA()->CreateCudaAllocator(id, gpu_mem_limit, arena_extend_strategy, external_allocator_info)});
if (cuda_id_to_allocator_map.find(id) == cuda_id_to_allocator_map.end()) {
cuda_id_to_allocator_map.insert({id, GetProviderInfo_CUDA()->CreateCudaAllocator(id, gpu_mem_limit, arena_extend_strategy, external_allocator_info)});
}
return (*id_to_allocator_map)[id];
return cuda_id_to_allocator_map[id];
}
static void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes) {
@ -2097,9 +2098,9 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
LOGS(default_logger, WARNING) << "Init provider bridge failed.";
}
atexit([] {
UnloadSharedProviders();
});
// atexit([] {
// UnloadSharedProviders();
// });
#endif
#ifdef ENABLE_TRAINING
@ -2147,3 +2148,21 @@ onnxruntime::Environment& GetEnv() {
} // namespace python
} // namespace onnxruntime
#ifdef _WIN32
BOOL WINAPI DllMain(HINSTANCE hinstance, DWORD reason, void* reserved) {
switch (reason) {
case DLL_PROCESS_ATTACH:
break;
case DLL_PROCESS_DETACH:
onnxruntime::python::cuda_id_to_allocator_map.clear();
onnxruntime::UnloadSharedProviders();
break;
case DLL_THREAD_ATTACH:
break;
case DLL_THREAD_DETACH:
break;
}
return TRUE;
}
#endif