diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 4cc7b14433..a65e5af723 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -29,6 +29,7 @@ class Node; #include "core/framework/allocatormgr.h" #include "core/framework/func_api.h" #include "core/framework/provider_options.h" +#include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" @@ -146,6 +147,20 @@ class IExecutionProvider { */ virtual ProviderOptions GetProviderOptions() const { return {}; } + /** + Get provider specific custom op domain list. + Provider has the responsibility to release OrtCustomOpDomain instances it creates. + + NOTE: In the case of ONNX model having EP specific custom nodes and don't want to ask user to register those nodes, + EP might need a way to register those custom nodes. This API is added for the purpose where EP can use it to + leverage ORT custom op to register those custom nodes with one or more custom op domains. + + For example, TensorRT EP uses this API to support TRT plugins where each custom op is mapped to TRT plugin and no + kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as + a role to help pass ONNX model validation. + */ + virtual void GetCustomOpDomainList(std::vector& /*provider custom op domain list*/) const {}; + /** Returns an opaque handle whose exact type varies based on the provider and is interpreted accordingly by the corresponding kernel implementation. 
diff --git a/include/onnxruntime/core/framework/framework_provider_common.h b/include/onnxruntime/core/framework/framework_provider_common.h new file mode 100644 index 0000000000..7c53f82898 --- /dev/null +++ b/include/onnxruntime/core/framework/framework_provider_common.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/session/onnxruntime_c_api.h" + +struct OrtCustomOpDomain { + std::string domain_; + std::vector custom_ops_; +}; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 8fc06cf2c2..1c9f7ecd09 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -40,4 +40,5 @@ struct OrtTensorRTProviderOptionsV2 { int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default // tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS" + const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4a2a925cb1..f9e541a24d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2460,6 +2460,12 @@ void RegisterContribSchemas() { } }); + // ORT will not register TRT plugins as contrib ops, instead it will use custom ops handled by TRT EP. 
+ // In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain backward compatibility, + // we still keep EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT as legacy code. + // We don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us. + // Moving forward, please create TRT plugin node with "trt.plugins" domain. + static const char* EfficientNMS_TRT_ver1_doc = R"DOC(Efficient NMS TensorRT Plugin.)DOC"; @@ -2662,6 +2668,8 @@ void RegisterContribSchemas() { propagateShapeFromInputToOutput(ctx, 0, 0); }); + // Please note that we don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us. + ONNX_CONTRIB_OPERATOR_SCHEMA(Snpe) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 97e5db9a03..2e4d5b6bfa 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -713,9 +713,19 @@ void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) { return g_host->MurmurHash3__x86_128(key, len, seed, out); } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +Status LoadDynamicLibrary(onnxruntime::PathString library_name) { + return g_host->LoadDynamicLibrary(library_name); +} +#endif + #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) { return g_host->ToUTF8String(s); } + +std::wstring ToWideString(const std::string& s) { + return g_host->ToWideString(s); +} #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_host_api.h b/onnxruntime/core/providers/shared_library/provider_host_api.h index 4b877abacf..43d661344d 
100644 --- a/onnxruntime/core/providers/shared_library/provider_host_api.h +++ b/onnxruntime/core/providers/shared_library/provider_host_api.h @@ -26,6 +26,9 @@ struct Provider { // Update provider options from key-value string configuration virtual void UpdateProviderOptions(void* /*provider options to be configured*/, const ProviderOptions& /*key-value string provider options*/){}; + // Get provider specific custom op domain list. Provider has the responsibility to release OrtCustomOpDomain instances it creates. + virtual void GetCustomOpDomainList(IExecutionProviderFactory* /*pointer to factory instance*/, std::vector& /*provider custom op domain list*/){}; + virtual void Initialize() = 0; // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded virtual void Shutdown() = 0; // Called right before unloading the shared library diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index dd9b2b2ddc..123af6ca38 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -891,9 +891,14 @@ struct ProviderHost { #ifdef _WIN32 virtual std::string ToUTF8String(const std::wstring& s) = 0; + virtual std::wstring ToWideString(const std::string& s) = 0; #endif virtual ProviderHostCPU& GetProviderHostCPU() = 0; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; +#endif }; #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7330409568..ce57e54b15 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ 
b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -10,6 +10,7 @@ #include "core/common/safeint.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider_custom_ops.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/providers/cuda/gpu_data_transfer.h" @@ -634,6 +635,7 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { if (!external_stream_ && stream_) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } + ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); } AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const { @@ -724,6 +726,10 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { return Status::OK(); } +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { + custom_op_domain_list = info_.custom_op_domain_list; +} + // Check the graph is the subgraph of control flow op bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const GraphViewer& graph) const { if (graph.IsSubgraph()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 26ecb1f5cf..b66945d806 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -38,6 +38,7 @@ static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE"; static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL"; static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS"; static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES"; +static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS"; // Old env variable for backward 
compatibility static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars @@ -163,6 +164,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override; + void GetCustomOpDomainList(std::vector& custom_op_domain_list) const override; + private: TensorrtExecutionProviderInfo info_; bool external_stream_ = false; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc new file mode 100644 index 0000000000..54a4d16e4e --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/provider_options.h" +#include "tensorrt_execution_provider_custom_ops.h" +#include "tensorrt_execution_provider.h" +#include +#include +#include + +namespace onnxruntime { +extern TensorrtLogger& GetTensorrtLogger(); + +/* + * Create custom op domain list for TRT plugins. + * + * Here, we collect all registered TRT plugins from TRT registry and create custom ops with "trt.plugins" domain. + * Additionally, if users specify extra plugin libraries, TRT EP will load them at runtime which will register those + * plugins to TRT plugin registry and later TRT EP can get them as well. + * + * There are several TRT plugins registered as onnx schema op through contrib op with ONNX domain in the past, + * for example, EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT. + * In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain + * backward compatible, we need to keep those legacy TRT plugins registered with ONNX domain with contrib ops. 
+ * + * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. + * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. + */ +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { + std::unique_ptr custom_op_domain = std::make_unique(); + custom_op_domain->domain_ = "trt.plugins"; + + // Load any extra TRT plugin library if any. + // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. + // This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator). + std::string extra_plugin_lib_paths{""}; + if (info.has_trt_options) { + if (!info.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info.extra_plugin_lib_paths; + } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + + // extra_plugin_lib_paths has the format of "path_1;path_2....;path_n" + if (!extra_plugin_lib_paths.empty()) { + std::stringstream extra_plugin_libs(extra_plugin_lib_paths); + std::string lib; + while (std::getline(extra_plugin_libs, lib, ';')) { + auto status = LoadDynamicLibrary(ToPathString(lib)); + if (status == Status::OK()) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully load " << lib; + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString(); + } + } + } + + // Get all registered TRT plugins from registry + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; + TensorrtLogger trt_logger = GetTensorrtLogger(); + initLibNvInferPlugins(&trt_logger, ""); + + int num_plugin_creator = 0; + auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator); + std::unordered_set registered_plugin_names; + + 
for (int i = 0; i < num_plugin_creator; i++) { + auto plugin_creator = plugin_creators[i]; + std::string plugin_name(plugin_creator->getPluginName()); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion(); + + // plugin has different versions and we only register once + if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) { + continue; + } + + std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); + trt_custom_op->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + registered_plugin_names.insert(plugin_name); + } + info.custom_op_domain_list.push_back(custom_op_domain.release()); + + return common::Status::OK(); +} + +void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { + if (domain != nullptr) { + for (auto ptr : domain->custom_ops_) { + if (ptr != nullptr) { + delete ptr; + } + } + delete domain; + } +} + +void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list) { + for (auto ptr : custom_op_domain_list) { + ReleaseTensorRTCustomOpDomain(ptr); + } +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h new file mode 100644 index 0000000000..98ac3220ab --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#define ORT_API_MANUAL_INIT +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/shared_library/provider_api.h" +#include "tensorrt_execution_provider_info.h" + +using namespace onnxruntime; + +namespace onnxruntime { + +common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); +void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); +void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); + +struct TensorRTCustomKernel { + TensorRTCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream) + : compute_stream_(compute_stream) { + } + + void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here. + + private: + void* compute_stream_; +}; + +struct TensorRTCustomOp : Ort::CustomOpBase { + explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t 
/*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }; + + private: + const char* provider_{onnxruntime::kTensorrtExecutionProvider}; + void* compute_stream_; + const char* name_; + size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index af9000bb17..3cc5b7bdb5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -39,6 +39,7 @@ constexpr const char* kSparsityEnable = "trt_sparsity_enable"; constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level"; constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams"; constexpr const char* kTacticSources = "trt_tactic_sources"; +constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; } // namespace provider_option_names } // namespace tensorrt @@ -83,6 +84,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level) .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams) .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources) + .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths) .Parse(options)); // add new provider option here. 
return info; @@ -117,6 +119,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, + {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, }; return options; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 262fc0854f..0734300883 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -7,7 +7,9 @@ #include "core/framework/ortdevice.h" #include "core/framework/provider_options.h" +#include "core/framework/framework_provider_common.h" #include "core/session/onnxruntime_c_api.h" +#include "core/framework/library_handles.h" namespace onnxruntime { // Information needed to construct trt execution providers. 
@@ -41,9 +43,12 @@ struct TensorrtExecutionProviderInfo { int builder_optimization_level{2}; int auxiliary_streams{-1}; std::string tactic_sources{""}; + std::string extra_plugin_lib_paths{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); + + std::vector custom_op_domain_list; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index aef7db7166..7d13832196 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -8,6 +8,7 @@ #include "tensorrt_provider_factory_creator.h" #include "core/framework/provider_options.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" +#include "core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h" #include using namespace onnxruntime; @@ -23,10 +24,16 @@ struct TensorrtProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider() override; + void GetCustomOpDomainList(std::vector& custom_op_domain_list); + private: TensorrtExecutionProviderInfo info_; }; +void TensorrtProviderFactory::GetCustomOpDomainList(std::vector& custom_op_domain_list) { + custom_op_domain_list = info_.custom_op_domain_list; +} + std::unique_ptr TensorrtProviderFactory::CreateProvider() { return std::make_unique(info_); } @@ -43,6 +50,11 @@ struct Tensorrt_Provider : Provider { TensorrtExecutionProviderInfo info; info.device_id = device_id; info.has_trt_options = false; + + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } return std::make_shared(info); } @@ 
-78,6 +90,13 @@ struct Tensorrt_Provider : Provider { info.builder_optimization_level = options.trt_builder_optimization_level; info.auxiliary_streams = options.trt_auxiliary_streams; info.tactic_sources = options.trt_tactic_sources == nullptr ? "" : options.trt_tactic_sources; + info.extra_plugin_lib_paths = options.trt_extra_plugin_lib_paths == nullptr ? "" : options.trt_extra_plugin_lib_paths; + + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + return std::make_shared(info); } @@ -172,6 +191,11 @@ struct Tensorrt_Provider : Provider { return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options); } + void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) override { + TensorrtProviderFactory* trt_factory = reinterpret_cast(factory); + trt_factory->GetCustomOpDomainList(custom_op_domains_ptr); + } + void Initialize() override { InitializeRegistry(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6ff5d7136a..5ae8ab8cf7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -508,6 +508,18 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Create Custom Op if EP requests it + std::vector custom_op_domains; + p_exec_provider->GetCustomOpDomainList(custom_op_domains); + + if (!custom_op_domains.empty()) { + if (AddCustomOpDomains(custom_op_domains) != Status::OK()) { + LOGS(*session_logger_, WARNING) << "Can't register custom op domains with ORT for " << provider_type; + } + } +#endif + // if any EPs do not support concurrent calls to Run we add locking around graph execution if (p_exec_provider->ConcurrentRunSupported() == 
false) { is_concurrent_run_supported_ = false; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 3cf36a1d80..85721250bf 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -19,6 +19,7 @@ #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" +#include "core/framework/framework_provider_common.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -36,11 +37,6 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE -struct OrtCustomOpDomain { - std::string domain_; - std::vector custom_ops_; -}; - namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6c1d9307c8..e65ba9e06e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -163,6 +163,22 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator { Node::EdgeConstIterator v_; }; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_name) { + const auto& platform_env = onnxruntime::Env::Default(); + void* library_handle = nullptr; + + ORT_RETURN_IF_ERROR(platform_env.LoadDynamicLibrary(library_name, false, &library_handle)); + if (!library_handle) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load dynamic library ", + onnxruntime::PathToUTF8String(library_name)); + } + + return onnxruntime::Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26436) 
@@ -1042,9 +1058,14 @@ struct ProviderHostImpl : ProviderHost { #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); } + std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); } #endif ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; +#endif } provider_host_; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) @@ -1286,6 +1307,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_builder_optimization_level = 2; trt_options_converted.trt_auxiliary_streams = -1; trt_options_converted.trt_tactic_sources = ""; + trt_options_converted.trt_extra_plugin_lib_paths = ""; return trt_options_converted; } @@ -1298,6 +1320,10 @@ std::shared_ptr TensorrtProviderFactoryCreator::Creat return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options); } +void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) { + s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr); +} + std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); } @@ -1457,6 +1483,13 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } + return nullptr; API_IMPL_END } @@ -1481,6 +1514,13 @@ 
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } + return nullptr; API_IMPL_END } @@ -1582,6 +1622,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } return nullptr; API_IMPL_END } @@ -1613,6 +1659,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT (*out)->trt_timing_cache_enable = false; (*out)->trt_force_timing_cache = false; (*out)->trt_detailed_build_log = false; + (*out)->trt_extra_plugin_lib_paths = nullptr; return nullptr; #else ORT_UNUSED_PARAMETER(out); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f5d2e02719..c1036e4a9e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -376,6 +376,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 2, -1, + nullptr, nullptr}; for (auto option : it->second) { if (option.first == "device_id") { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e63c9bbd91..cbadcef09a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -129,6 +129,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device int trt_builder_optimization_level = 2; int trt_auxiliary_streams = -1; std::string trt_tactic_sources = ""; + std::string trt_extra_plugin_lib_paths = ""; #ifdef 
_MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -334,8 +335,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-emtpy string.\n"); } + } else if (key == "trt_extra_plugin_lib_paths") { + if (!value.empty()) { + trt_extra_plugin_lib_paths = value; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-emtpy string.\n"); + } } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback'] \n"); + ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. 
['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_extra_plugin_lib_paths'] \n"); } } OrtTensorRTProviderOptionsV2 tensorrt_options; @@ -367,6 +374,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level; tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams; tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str(); + tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str(); session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); OrtCUDAProviderOptions cuda_options; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 60f5870f09..9886d9fdd7 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -686,7 +686,7 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30, 1, // enable fp16 0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0, - 2, -1, nullptr}; + 2, -1, nullptr, nullptr}; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { OrtTensorRTProviderOptionsV2* ep_option = nullptr; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 75da23ccb5..bbfba5cfe8 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ 
-161,6 +161,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
       0,
       2,
       -1,
+      nullptr,
       nullptr};
 
   params.trt_engine_cache_enable = 1;
@@ -240,6 +241,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
       0,
       2,
       -1,
+      nullptr,
       nullptr};
 
   params.trt_engine_cache_enable = 1;
@@ -333,6 +335,80 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
   ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
 }
 
+// Runs testdata/trt_plugin_custom_op_test.onnx through a session that uses the TensorRT execution provider
+// and expects Load, Initialize and Run to all succeed.
+TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
+  std::string model_name = "testdata/trt_plugin_custom_op_test.onnx";
+  SessionOptions so;
+  so.session_logid = "TensorrtExecutionProviderTRTPluginsTest";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+  InferenceSession session_object{so, GetEnvironment()};
+  onnxruntime::AllocatorManager allocator_manager;
+  auto cuda_provider = DefaultCudaExecutionProvider();
+  cuda_provider->RegisterAllocator(allocator_manager);
+  auto cpu_allocator = cuda_provider->GetAllocator(OrtMemTypeCPU);
+  std::vector<int64_t> dims_op_x = {12, 256, 256};
+  std::vector<float> values_op_x(786432, 1.0f);  // 786432 = 12*256*256 elements, each 1.0f (std::vector takes count first, then value)
+  OrtValue ml_value_x;
+  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x);
+  OrtValue ml_value_y;
+  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_y);
+  OrtValue ml_value_z;
+  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_z);
+  NameMLValMap feeds;
+  feeds.insert(std::make_pair("input1", ml_value_x));
+  feeds.insert(std::make_pair("input2", ml_value_y));
+  feeds.insert(std::make_pair("input3", ml_value_z));
+
+  // prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("output");
+  std::vector<OrtValue> fetches;
+
+  // Positional init of OrtTensorRTProviderOptionsV2; the final nullptr is trt_extra_plugin_lib_paths (no extra plugin libraries).
+  OrtTensorRTProviderOptionsV2 params{
+      0,
+      0,
+      nullptr,
+      1000,
+      1,
+      1 << 30,
+      0,
+      0,
+      nullptr,
+      0,
+      0,
+      0,
+      0,
+      0,
+      nullptr,
+      0,
+      nullptr,
+      0,
+      0,
+      0,
+      0,
+      0,
+      0,
+      0,
+      0,
+      2,
+      -1,
+      nullptr,
+      nullptr};
+
+  std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
+  EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+  std::cout << model_name << std::endl;
+  auto status = session_object.Load(model_name);
+  ASSERT_TRUE(status.IsOK());
+  status = session_object.Initialize();
+  ASSERT_TRUE(status.IsOK());
+  status = session_object.Run(run_options, feeds, output_names, &fetches);
+  ASSERT_TRUE(status.IsOK());
+}
+
 TEST_P(TensorrtExecutionProviderCacheTest, Run) {
   // GetParam() returns the parameter of following format:
   // ##cache type##_##input shape type##
@@ -411,6 +487,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
       0,
       2,
       -1,
+      nullptr,
       nullptr};
 
   if (cache_type.compare("engine") == 0) {
diff --git a/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx b/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx
new file mode 100644
index 0000000000..99ac5a0bd9
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx
@@ -0,0 +1,27 @@
+:œ
+ƒ
+input1
+input2
+input3outputDisentangledAttention_TRT"DisentangledAttention_TRT*
+factormēū= *
+span€ : trt.pluginstrt_plugin_custom_opZ
+input1
+
+
+€
+€Z
+input2
+
+
+€
+€Z
+input3
+
+
+€
+€b
+output
+
+
+€
+€B
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/trt_plugin_custom_op_test.py b/onnxruntime/test/testdata/trt_plugin_custom_op_test.py
new file mode 100644
index 0000000000..e387e80973
--- /dev/null
+++ b/onnxruntime/test/testdata/trt_plugin_custom_op_test.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import TensorProto, helper
+
+
+def generate_model(model_name):
+    """Save a one-node model that calls the DisentangledAttention_TRT op (domain trt.plugins) to model_name."""
+    tensor_shape = [12, 256, 256]
+
+    def float_tensor_info(tensor_name):
+        return helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, tensor_shape)
+
+    input_names = ["input1", "input2", "input3"]
+    output_names = ["output"]
+
+    plugin_node = helper.make_node(
+        "DisentangledAttention_TRT",
+        input_names,
+        output_names,
+        "DisentangledAttention_TRT",
+        domain="trt.plugins",
+        factor=0.123,
+        span=128,
+    )
+
+    plugin_graph = helper.make_graph(
+        [plugin_node],
+        "trt_plugin_custom_op",
+        [float_tensor_info(name) for name in input_names],
+        [float_tensor_info(name) for name in output_names],
+    )
+
+    onnx.save(helper.make_model(plugin_graph), model_name)
+
+
+if __name__ == "__main__":
+    generate_model("trt_plugin_custom_op_test.onnx")