From 4ffd022b0b1a62880d2498679f360878fdbbd796 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 24 Oct 2023 00:46:38 +0000 Subject: [PATCH] [TensorRT EP] Refactor of TRT plugins support (#17946) Make sure "trt.plugins" custom op domain only being registered once. The bottom line is "trt.plugins" custom op domain needs to be registered before model load. `CreateTensorRTCustomOpDomainList()` is TRT EP's function to create "trt.plugins" custom op domain. Following are places where this function will be called. (This function only fetches all the TRT plugins from TRT plugin registry but not yet registered them to ORT custom op registry. The real registration happens in AddCustomOpDomains()) C/C++ APIs: - `OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_XX`: This function will make session option object contain the "trt.plugins" custom op domain for ORT to register. So that later the session creation api can register the custom op domain accordingly and won't complain about invalid onnx node. - `InferenceSession::RegisterExecutionProvider`: In some cases, users might create the session object first and later call session_object.RegisterExecutionProvider(). This function will call p_exec_provider->GetCustomOpDomainList() which returns "trt.plugins" custom op domain. Otherwise, session_object.Load(model) will complain. Python APIs: - `RegisterTensorRTPluginsAsCustomOps`: Need to call this function so that session option object contains the "trt.plugins" custom op domain for ORT to register. Different language bindings have slightly different workflow of initializing the session. This might cause duplicate custom op domain in `session_option.custom_op_domains_` or `CreateTensorRTCustomOpDomainList()` being called more than once, but we put checks to make sure ep's custom op domain won't be registered twice. --- .../tensorrt/tensorrt_execution_provider.cc | 6 +++ .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 10 ----- onnxruntime/core/session/inference_session.cc | 30 ++++++++++++- .../core/session/provider_bridge_ort.cc | 43 +++++++++++-------- .../python/onnxruntime_pybind_state.cc | 15 ++++++- .../test/python/onnxruntime_test_python.py | 14 ++++++ 7 files changed, 88 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index d9238e41a2..ef1f0bf9f8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1210,6 +1210,12 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { + if (info_.custom_op_domain_list.empty()) { + common::Status status = CreateTensorRTCustomOpDomainList(info_); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + } custom_op_domain_list = info_.custom_op_domain_list; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 3bf6bc05a6..24c391ee11 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -197,7 +197,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { Status ReplayGraph() override; private: - TensorrtExecutionProviderInfo info_; + mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; int max_partition_iterations_ = 1000; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index b5dbe1ac45..d7e13df000 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -75,11 +75,6 @@ struct Tensorrt_Provider : Provider { info.device_id = device_id; info.has_trt_options = false; - common::Status status = CreateTensorRTCustomOpDomainList(info); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; - } - return std::make_shared(info); } @@ -121,11 +116,6 @@ struct Tensorrt_Provider : Provider { info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes; info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; - common::Status status = CreateTensorRTCustomOpDomainList(info); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; - } - return std::make_shared(info); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cad55afdf7..077b10ffc5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -613,9 +613,35 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Create Custom Op if EP requests it + // Register Custom Op if EP requests it std::vector custom_op_domains; - p_exec_provider->GetCustomOpDomainList(custom_op_domains); + std::vector candidate_custom_op_domains; + p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); + + auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); + + // Register the custom op domain only if it has not been registered before + if (registry_kernels.empty()) { + custom_op_domains = candidate_custom_op_domains; + } else { + for (auto candidate_custom_op_domain : candidate_custom_op_domains) { + for (auto registry_kernel : registry_kernels) { + const auto& kernel_map = registry_kernel->GetKernelCreateMap(); + bool need_register = true; + // If the kernel registry is the ep's custom op registry, we only need to check the first kernel, + // because all kernels in one kernel registry should have the same domain name. + for (auto iter = kernel_map.begin(); iter != kernel_map.end(); iter++) { + if (iter->second.kernel_def->Domain() == candidate_custom_op_domain->domain_) { + need_register = false; + break; + } + } + if (need_register) { + custom_op_domains.push_back(candidate_custom_op_domain); + } + } + } + } if (!custom_op_domains.empty()) { if (AddCustomOpDomains(custom_op_domains) != Status::OK()) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d950223f2d..d307f79c37 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1625,6 +1625,28 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op } // namespace onnxruntime +void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) { + auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + + std::vector custom_op_domains; + onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); + provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { + if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { + options->custom_op_domains_.push_back(ptr); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; + } + } +} + ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { API_IMPL_BEGIN auto factory = onnxruntime::DnnlProviderFactoryCreator::Create(use_arena); @@ -1646,13 +1668,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS options->provider_factories.push_back(factory); - std::vector custom_op_domains; std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); return nullptr; API_IMPL_END @@ -1679,12 +1696,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In options->provider_factories.push_back(factory); - std::vector custom_op_domains; - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, ""); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, ""); return nullptr; API_IMPL_END @@ -1788,13 +1800,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, options->provider_factories.push_back(factory); - std::vector custom_op_domains; std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); return nullptr; API_IMPL_END diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 35e03bf9ea..a72f563601 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -433,6 +433,15 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* #ifdef USE_TENSORRT void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) { if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) { + auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + std::string trt_extra_plugin_lib_paths = ""; const auto it = options.find("trt_extra_plugin_lib_paths"); if (it != options.end()) { @@ -441,7 +450,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti std::vector domain_list; tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); for (auto ptr : domain_list) { - so.custom_op_domains_.push_back(ptr); + if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { + so.custom_op_domains_.push_back(ptr); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; + } } } else { ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported."); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1d954fe437..d8628c4288 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -298,6 +298,20 @@ class TestInferenceSession(unittest.TestCase): self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path)) self.assertEqual(option["trt_force_sequential_engine_build"], "1") + from onnxruntime.capi import _pybind_state as C + + session_options = C.get_default_session_options() + + # TRT plugins registered as custom op domain should only be added once in session option regardless of number of session creation + sess1 = onnxrt.InferenceSession( + get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"] + ) + sess2 = onnxrt.InferenceSession( + get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"] + ) + self.assertIn("TensorrtExecutionProvider", sess1.get_providers()) + self.assertIn("TensorrtExecutionProvider", sess2.get_providers()) + # We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability """