From 6115c8fd1f4c32e058ac7a841fdddb0e99e509cc Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 18 Apr 2023 20:24:32 -0700 Subject: [PATCH] Add TRT plugins support using custom ops (#13847) This PR makes ORT support TRT plugins using custom ops. ORT TRT can automatically register all TRT plugins from the TRT plugin registry as custom ops. There is no code change needed for ORT when new TRT plugins are introduced. The previous way for ORT to support TRT plugins was using contrib ops, but there are some concerns about it: - Contrib ops are shipped as part of the ORT binary by default. TRT-related plugins should not be in the default ORT. - Contrib ops are designed for internal ops and developed for CPU and CUDA EPs. Therefore, using custom ops is a good approach to support TRT plugins. The following are the major modifications: 1. Add a new `GetCustomOpDomainList` provider API which allows a provider to create its own custom op domain list so that ORT can register this domain list. The provider has the responsibility to free all the custom op domain instances it created. 2. Move the OrtCustomOpDomain struct definition to framework_provider_common.h since this struct is now used by both the framework and EPs. 3. There are several TRT plugins registered as ONNX schema ops through contrib ops with the ONNX domain. In order not to break the old models using those TRT plugins which were registered with the ONNX domain, and to maintain backward compatibility, we need to keep the old/legacy TRT plugins with the ONNX domain. Moving forward, all newly added TRT plugins should be registered with the `trt.plugins` domain. 4. TRT plugin doesn't have an API to get the number of inputs/outputs of the registered plugins, so ORT TRT uses variadic inputs/outputs to bypass the ONNX node validation. 5. 
Add new trt provider option, `trt_extra_plugin_lib_paths`, user can specify any extra plugin lib, for example, `fastertransformer/build/lib/libvit_plugin.so` or `fastertransformer/build/lib/libvit_plugin.so;fastertransformer/build/lib/libvit_plugin_v2.so` --- .../core/framework/execution_provider.h | 15 +++ .../framework/framework_provider_common.h | 10 ++ .../tensorrt/tensorrt_provider_options.h | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 8 ++ .../provider_bridge_provider.cc | 10 ++ .../shared_library/provider_host_api.h | 3 + .../shared_library/provider_interfaces.h | 5 + .../tensorrt/tensorrt_execution_provider.cc | 6 + .../tensorrt/tensorrt_execution_provider.h | 3 + .../tensorrt_execution_provider_custom_ops.cc | 108 ++++++++++++++++++ .../tensorrt_execution_provider_custom_ops.h | 65 +++++++++++ .../tensorrt_execution_provider_info.cc | 3 + .../tensorrt_execution_provider_info.h | 5 + .../tensorrt/tensorrt_provider_factory.cc | 24 ++++ onnxruntime/core/session/inference_session.cc | 12 ++ onnxruntime/core/session/inference_session.h | 6 +- .../core/session/provider_bridge_ort.cc | 47 ++++++++ .../python/onnxruntime_pybind_state.cc | 1 + onnxruntime/test/perftest/ort_test_session.cc | 10 +- onnxruntime/test/providers/cpu/model_tests.cc | 2 +- .../providers/tensorrt/tensorrt_basic_test.cc | 74 ++++++++++++ .../testdata/trt_plugin_custom_op_test.onnx | 27 +++++ .../testdata/trt_plugin_custom_op_test.py | 40 +++++++ 23 files changed, 478 insertions(+), 7 deletions(-) create mode 100644 include/onnxruntime/core/framework/framework_provider_common.h create mode 100644 onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc create mode 100644 onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h create mode 100644 onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx create mode 100644 onnxruntime/test/testdata/trt_plugin_custom_op_test.py diff --git a/include/onnxruntime/core/framework/execution_provider.h 
b/include/onnxruntime/core/framework/execution_provider.h index 4cc7b14433..a65e5af723 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -29,6 +29,7 @@ class Node; #include "core/framework/allocatormgr.h" #include "core/framework/func_api.h" #include "core/framework/provider_options.h" +#include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" @@ -146,6 +147,20 @@ class IExecutionProvider { */ virtual ProviderOptions GetProviderOptions() const { return {}; } + /** + Get provider specific custom op domain list. + Provider has the responsibility to release OrtCustomOpDomain instances it creates. + + NOTE: When an ONNX model has EP specific custom nodes and we don't want to ask the user to register those nodes, + the EP might need a way to register those custom nodes. This API is added so that the EP can use it to + leverage ORT custom ops to register those custom nodes with one or more custom op domains. + + For example, TensorRT EP uses this API to support TRT plugins where each custom op is mapped to a TRT plugin and no + kernel implementation is needed for the custom op since the real implementation is inside TRT. This custom op only + serves to help pass ONNX model validation. + */ + virtual void GetCustomOpDomainList(std::vector& /*provider custom op domain list*/) const {}; + /** Returns an opaque handle whose exact type varies based on the provider and is interpreted accordingly by the corresponding kernel implementation. diff --git a/include/onnxruntime/core/framework/framework_provider_common.h b/include/onnxruntime/core/framework/framework_provider_common.h new file mode 100644 index 0000000000..7c53f82898 --- /dev/null +++ b/include/onnxruntime/core/framework/framework_provider_common.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#pragma once +#include "core/session/onnxruntime_c_api.h" + +struct OrtCustomOpDomain { + std::string domain_; + std::vector custom_ops_; +}; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 8fc06cf2c2..1c9f7ecd09 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -40,4 +40,5 @@ struct OrtTensorRTProviderOptionsV2 { int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default // tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS" + const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4a2a925cb1..f9e541a24d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2460,6 +2460,12 @@ void RegisterContribSchemas() { } }); + // ORT will not register TRT plugins as contrib ops; instead it will use custom ops handled by TRT EP. + // In order not to break the old models using those TRT plugins which were registered with ONNX domain and to maintain backward compatibility, + // we still keep EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT as legacy code. + // We don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us. 
+ // Moving forward, please create TRT plugin node with "trt.plugins" domain. + static const char* EfficientNMS_TRT_ver1_doc = R"DOC(Efficient NMS TensorRT Plugin.)DOC"; @@ -2662,6 +2668,8 @@ void RegisterContribSchemas() { propagateShapeFromInputToOutput(ctx, 0, 0); }); + // Please note that we don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us. + ONNX_CONTRIB_OPERATOR_SCHEMA(Snpe) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 97e5db9a03..2e4d5b6bfa 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -713,9 +713,19 @@ void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) { return g_host->MurmurHash3__x86_128(key, len, seed, out); } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +Status LoadDynamicLibrary(onnxruntime::PathString library_name) { + return g_host->LoadDynamicLibrary(library_name); +} +#endif + #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) { return g_host->ToUTF8String(s); } + +std::wstring ToWideString(const std::string& s) { + return g_host->ToWideString(s); +} #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_host_api.h b/onnxruntime/core/providers/shared_library/provider_host_api.h index 4b877abacf..43d661344d 100644 --- a/onnxruntime/core/providers/shared_library/provider_host_api.h +++ b/onnxruntime/core/providers/shared_library/provider_host_api.h @@ -26,6 +26,9 @@ struct Provider { // Update provider options from key-value string configuration virtual void UpdateProviderOptions(void* /*provider options to be configured*/, const ProviderOptions& /*key-value string provider options*/){}; + // Get provider 
specific custom op domain list. Provider has the responsibility to release OrtCustomOpDomain instances it creates. + virtual void GetCustomOpDomainList(IExecutionProviderFactory* /*pointer to factory instance*/, std::vector& /*provider custom op domain list*/){}; + virtual void Initialize() = 0; // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded virtual void Shutdown() = 0; // Called right before unloading the shared library diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index dd9b2b2ddc..123af6ca38 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -891,9 +891,14 @@ struct ProviderHost { #ifdef _WIN32 virtual std::string ToUTF8String(const std::wstring& s) = 0; + virtual std::wstring ToWideString(const std::string& s) = 0; #endif virtual ProviderHostCPU& GetProviderHostCPU() = 0; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; +#endif }; #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7330409568..ce57e54b15 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -10,6 +10,7 @@ #include "core/common/safeint.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider_custom_ops.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/providers/cuda/gpu_data_transfer.h" @@ -634,6 +635,7 @@ 
TensorrtExecutionProvider::~TensorrtExecutionProvider() { if (!external_stream_ && stream_) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } + ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); } AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const { @@ -724,6 +726,10 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { return Status::OK(); } +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { + custom_op_domain_list = info_.custom_op_domain_list; +} + // Check the graph is the subgraph of control flow op bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const GraphViewer& graph) const { if (graph.IsSubgraph()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 26ecb1f5cf..b66945d806 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -38,6 +38,7 @@ static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE"; static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL"; static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS"; static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES"; +static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS"; // Old env variable for backward compatibility static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars @@ -163,6 +164,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override; + void GetCustomOpDomainList(std::vector& custom_op_domain_list) const override; + private: TensorrtExecutionProviderInfo info_; bool external_stream_ = 
false; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc new file mode 100644 index 0000000000..54a4d16e4e --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/provider_options.h" +#include "tensorrt_execution_provider_custom_ops.h" +#include "tensorrt_execution_provider.h" +#include +#include +#include + +namespace onnxruntime { +extern TensorrtLogger& GetTensorrtLogger(); + +/* + * Create custom op domain list for TRT plugins. + * + * Here, we collect all registered TRT plugins from TRT registry and create custom ops with "trt.plugins" domain. + * Additionally, if users specify extra plugin libraries, TRT EP will load them at runtime which will register those + * plugins to TRT plugin registry and later TRT EP can get them as well. + * + * There are several TRT plugins registered as onnx schema op through contrib op with ONNX domain in the past, + * for example, EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT. + * In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain + * backward compatible, we need to keep those legacy TRT plugins registered with ONNX domain with contrib ops. + * + * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. + * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. + */ +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { + std::unique_ptr custom_op_domain = std::make_unique(); + custom_op_domain->domain_ = "trt.plugins"; + + // Load any extra TRT plugin library if any. 
+ // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. + // This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator). + std::string extra_plugin_lib_paths{""}; + if (info.has_trt_options) { + if (!info.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info.extra_plugin_lib_paths; + } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + + // extra_plugin_lib_paths has the format of "path_1;path_2....;path_n" + if (!extra_plugin_lib_paths.empty()) { + std::stringstream extra_plugin_libs(extra_plugin_lib_paths); + std::string lib; + while (std::getline(extra_plugin_libs, lib, ';')) { + auto status = LoadDynamicLibrary(ToPathString(lib)); + if (status == Status::OK()) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully load " << lib; + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString(); + } + } + } + + // Get all registered TRT plugins from registry + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; + TensorrtLogger trt_logger = GetTensorrtLogger(); + initLibNvInferPlugins(&trt_logger, ""); + + int num_plugin_creator = 0; + auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator); + std::unordered_set registered_plugin_names; + + for (int i = 0; i < num_plugin_creator; i++) { + auto plugin_creator = plugin_creators[i]; + std::string plugin_name(plugin_creator->getPluginName()); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion(); + + // plugin has different versions and we only register once + if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) { + continue; + } 
+ + std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); + trt_custom_op->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + registered_plugin_names.insert(plugin_name); + } + info.custom_op_domain_list.push_back(custom_op_domain.release()); + + return common::Status::OK(); +} + +void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { + if (domain != nullptr) { + for (auto ptr : domain->custom_ops_) { + if (ptr != nullptr) { + delete ptr; + } + } + delete domain; + } +} + +void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list) { + for (auto ptr : custom_op_domain_list) { + ReleaseTensorRTCustomOpDomain(ptr); + } +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h new file mode 100644 index 0000000000..98ac3220ab --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#define ORT_API_MANUAL_INIT +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/shared_library/provider_api.h" +#include "tensorrt_execution_provider_info.h" + +using namespace onnxruntime; + +namespace onnxruntime { + +common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); +void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); +void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); + +struct TensorRTCustomKernel { + TensorRTCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream) + : compute_stream_(compute_stream) { + } + + void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here. + + private: + void* compute_stream_; +}; + +struct TensorRTCustomOp : Ort::CustomOpBase { + explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t 
/*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }; + + private: + const char* provider_{onnxruntime::kTensorrtExecutionProvider}; + void* compute_stream_; + const char* name_; + size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index af9000bb17..3cc5b7bdb5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -39,6 +39,7 @@ constexpr const char* kSparsityEnable = "trt_sparsity_enable"; constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level"; constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams"; constexpr const char* kTacticSources = "trt_tactic_sources"; +constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; } // namespace provider_option_names } // namespace tensorrt @@ -83,6 +84,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level) .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams) .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources) + .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths) .Parse(options)); // add new provider option here. 
return info; @@ -117,6 +119,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)}, {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)}, {tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)}, + {tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)}, }; return options; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 262fc0854f..0734300883 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -7,7 +7,9 @@ #include "core/framework/ortdevice.h" #include "core/framework/provider_options.h" +#include "core/framework/framework_provider_common.h" #include "core/session/onnxruntime_c_api.h" +#include "core/framework/library_handles.h" namespace onnxruntime { // Information needed to construct trt execution providers. 
@@ -41,9 +43,12 @@ struct TensorrtExecutionProviderInfo { int builder_optimization_level{2}; int auxiliary_streams{-1}; std::string tactic_sources{""}; + std::string extra_plugin_lib_paths{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info); + + std::vector custom_op_domain_list; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index aef7db7166..7d13832196 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -8,6 +8,7 @@ #include "tensorrt_provider_factory_creator.h" #include "core/framework/provider_options.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" +#include "core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h" #include using namespace onnxruntime; @@ -23,10 +24,16 @@ struct TensorrtProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider() override; + void GetCustomOpDomainList(std::vector& custom_op_domain_list); + private: TensorrtExecutionProviderInfo info_; }; +void TensorrtProviderFactory::GetCustomOpDomainList(std::vector& custom_op_domain_list) { + custom_op_domain_list = info_.custom_op_domain_list; +} + std::unique_ptr TensorrtProviderFactory::CreateProvider() { return std::make_unique(info_); } @@ -43,6 +50,11 @@ struct Tensorrt_Provider : Provider { TensorrtExecutionProviderInfo info; info.device_id = device_id; info.has_trt_options = false; + + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } return std::make_shared(info); } @@ 
-78,6 +90,13 @@ struct Tensorrt_Provider : Provider { info.builder_optimization_level = options.trt_builder_optimization_level; info.auxiliary_streams = options.trt_auxiliary_streams; info.tactic_sources = options.trt_tactic_sources == nullptr ? "" : options.trt_tactic_sources; + info.extra_plugin_lib_paths = options.trt_extra_plugin_lib_paths == nullptr ? "" : options.trt_extra_plugin_lib_paths; + + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + return std::make_shared(info); } @@ -172,6 +191,11 @@ struct Tensorrt_Provider : Provider { return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options); } + void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) override { + TensorrtProviderFactory* trt_factory = reinterpret_cast(factory); + trt_factory->GetCustomOpDomainList(custom_op_domains_ptr); + } + void Initialize() override { InitializeRegistry(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6ff5d7136a..5ae8ab8cf7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -508,6 +508,18 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Create Custom Op if EP requests it + std::vector custom_op_domains; + p_exec_provider->GetCustomOpDomainList(custom_op_domains); + + if (!custom_op_domains.empty()) { + if (AddCustomOpDomains(custom_op_domains) != Status::OK()) { + LOGS(*session_logger_, WARNING) << "Can't register custom op domains with ORT for " << provider_type; + } + } +#endif + // if any EPs do not support concurrent calls to Run we add locking around graph execution if (p_exec_provider->ConcurrentRunSupported() == 
false) { is_concurrent_run_supported_ = false; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 3cf36a1d80..85721250bf 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -19,6 +19,7 @@ #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" +#include "core/framework/framework_provider_common.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -36,11 +37,6 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE -struct OrtCustomOpDomain { - std::string domain_; - std::vector custom_ops_; -}; - namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6c1d9307c8..e65ba9e06e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -163,6 +163,22 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator { Node::EdgeConstIterator v_; }; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_name) { + const auto& platform_env = onnxruntime::Env::Default(); + void* library_handle = nullptr; + + ORT_RETURN_IF_ERROR(platform_env.LoadDynamicLibrary(library_name, false, &library_handle)); + if (!library_handle) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load dynamic library ", + onnxruntime::PathToUTF8String(library_name)); + } + + return onnxruntime::Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26436) 
@@ -1042,9 +1058,14 @@ struct ProviderHostImpl : ProviderHost { #ifdef _WIN32 std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); } + std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); } #endif ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; +#endif } provider_host_; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) @@ -1286,6 +1307,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_builder_optimization_level = 2; trt_options_converted.trt_auxiliary_streams = -1; trt_options_converted.trt_tactic_sources = ""; + trt_options_converted.trt_extra_plugin_lib_paths = ""; return trt_options_converted; } @@ -1298,6 +1320,10 @@ std::shared_ptr TensorrtProviderFactoryCreator::Creat return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options); } +void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) { + s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr); +} + std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); } @@ -1457,6 +1483,13 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } + return nullptr; API_IMPL_END } @@ -1481,6 +1514,13 @@ 
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } + return nullptr; API_IMPL_END } @@ -1582,6 +1622,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, } options->provider_factories.push_back(factory); + + std::vector custom_op_domains; + TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains); + for (auto ptr : custom_op_domains) { + options->custom_op_domains_.push_back(ptr); + } return nullptr; API_IMPL_END } @@ -1613,6 +1659,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT (*out)->trt_timing_cache_enable = false; (*out)->trt_force_timing_cache = false; (*out)->trt_detailed_build_log = false; + (*out)->trt_extra_plugin_lib_paths = nullptr; return nullptr; #else ORT_UNUSED_PARAMETER(out); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f5d2e02719..c1036e4a9e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -376,6 +376,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 2, -1, + nullptr, nullptr}; for (auto option : it->second) { if (option.first == "device_id") { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e63c9bbd91..cbadcef09a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -129,6 +129,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device int trt_builder_optimization_level = 2; int trt_auxiliary_streams = -1; std::string trt_tactic_sources = ""; + std::string trt_extra_plugin_lib_paths = ""; #ifdef 
_MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -334,8 +335,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-emtpy string.\n"); } + } else if (key == "trt_extra_plugin_lib_paths") { + if (!value.empty()) { + trt_extra_plugin_lib_paths = value; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-emtpy string.\n"); + } } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback'] \n"); + ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. 
['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_extra_plugin_lib_paths'] \n"); } } OrtTensorRTProviderOptionsV2 tensorrt_options; @@ -367,6 +374,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level; tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams; tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str(); + tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str(); session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); OrtCUDAProviderOptions cuda_options; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 60f5870f09..9886d9fdd7 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -686,7 +686,7 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30, 1, // enable fp16 0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0, - 2, -1, nullptr}; + 2, -1, nullptr, nullptr}; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { OrtTensorRTProviderOptionsV2* ep_option = nullptr; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 75da23ccb5..bbfba5cfe8 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ 
-161,6 +161,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string 0, 2, -1, + nullptr, nullptr}; params.trt_engine_cache_enable = 1; @@ -240,6 +241,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string 0, 2, -1, + nullptr, nullptr}; params.trt_engine_cache_enable = 1; @@ -333,6 +335,77 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded"; } +TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { /* Loads testdata/trt_plugin_custom_op_test.onnx through the TensorRT EP and expects Load, Initialize and Run to all succeed. */ + std::string model_name = "testdata/trt_plugin_custom_op_test.onnx"; + SessionOptions so; + so.session_logid = "TensorrtExecutionProviderTRTPluginsTest"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + InferenceSession session_object{so, GetEnvironment()}; + onnxruntime::AllocatorManager allocator_manager; + auto cuda_provider = DefaultCudaExecutionProvider(); + cuda_provider->RegisterAllocator(allocator_manager); + auto cpu_allocator = cuda_provider->GetAllocator(OrtMemTypeCPU); + std::vector<int64_t> dims_op_x = {12, 256, 256}; + std::vector<float> values_op_x(786432, 1.0f); // 786432=12*256*256 elements, each 1.0f (std::vector takes the count first, then the fill value) + OrtValue ml_value_x; + CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x); + OrtValue ml_value_y; + CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_y); + OrtValue ml_value_z; + CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_z); + NameMLValMap feeds; + feeds.insert(std::make_pair("input1", ml_value_x)); + feeds.insert(std::make_pair("input2", ml_value_y)); + feeds.insert(std::make_pair("input3", ml_value_z)); + + // prepare outputs + std::vector<std::string> output_names; + output_names.push_back("output"); + std::vector<OrtValue> fetches; + + OrtTensorRTProviderOptionsV2 params{ + 0, + 0, + nullptr, + 1000, + 1, + 1 << 30, + 0, + 0, + nullptr, + 0, + 0, + 0, + 0, + 0, + nullptr, + 0, + nullptr, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, +
2, + -1, + nullptr, + nullptr}; + + std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params); + EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + std::cout << model_name << std::endl; + auto status = session_object.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()); + status = session_object.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); +} + TEST_P(TensorrtExecutionProviderCacheTest, Run) { // GetParam() returns the parameter of following format: // ##cache type##_##input shape type## @@ -411,6 +484,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { 0, 2, -1, + nullptr, nullptr}; if (cache_type.compare("engine") == 0) { diff --git a/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx b/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx new file mode 100644 index 0000000000..99ac5a0bd9 --- /dev/null +++ b/onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx @@ -0,0 +1,27 @@ +:œ +ƒ +input1 +input2 +input3outputDisentangledAttention_TRT"DisentangledAttention_TRT* +factormēū= * +span€ : trt.pluginstrt_plugin_custom_opZ +input1 + + +€ +€Z +input2 + + +€ +€Z +input3 + + +€ +€b +output + + +€ +€B \ No newline at end of file diff --git a/onnxruntime/test/testdata/trt_plugin_custom_op_test.py b/onnxruntime/test/testdata/trt_plugin_custom_op_test.py new file mode 100644 index 0000000000..e387e80973 --- /dev/null +++ b/onnxruntime/test/testdata/trt_plugin_custom_op_test.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License.
+ +import onnx +from onnx import TensorProto, helper + + +def generate_model(model_name): + nodes = [ + helper.make_node( + "DisentangledAttention_TRT", + ["input1", "input2", "input3"], + ["output"], + "DisentangledAttention_TRT", + domain="trt.plugins", + factor=0.123, + span=128, + ), + ] + + graph = helper.make_graph( + nodes, + "trt_plugin_custom_op", + [ # input + helper.make_tensor_value_info("input1", TensorProto.FLOAT, [12, 256, 256]), + helper.make_tensor_value_info("input2", TensorProto.FLOAT, [12, 256, 256]), + helper.make_tensor_value_info("input3", TensorProto.FLOAT, [12, 256, 256]), + ], + [ # output + helper.make_tensor_value_info("output", TensorProto.FLOAT, [12, 256, 256]), + ], + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +if __name__ == "__main__": + generate_model("trt_plugin_custom_op_test.onnx")