mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add TRT plugins support using custom ops (#13847)
This PR makes ORT support TRT plugin using custom ops. ORT TRT can automatically register all TRT plugins from TRT plugins registry as custom ops. There is no code change needed for ORT when new TRT plugins are introduced. Previous way for ORT to support TRT plugins was using contrib ops, but there are some concerns about it: - Contrib ops are shipped as part of the ORT binary by default. TRT related plugins should not be in the default ORT. - Contrib ops are designed for internal ops and developed for cpu and cuda EPs. Therefore, using custom ops is a good approach to support TRT plugins. Followings are the major modifications: 1. Add new `GetCustomOpDomainList` provider api which allows provider to create its own custom op domain list and ORT can register this domain list. Provider has the responsibility to free all the custom op domain instances it created. 2. Move OrtCustomOpDomain struct definition to framework_provider_common.h since this struct is being used by framework and EPs now. 3. There are several TRT plugins registered as onnx schema op through contrib op with onnx domain. In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain backward compatible, we need to keep the old/legacy TRT plugins with onnx domain. Moving forward, all newly added TRT plugins should be registered with `trt.plugins` domain. 4. TRT plugin doesn't have an api to get number of inputs/outputs of the registered plugins, so ORT TRT uses variadic inputs/outputs to bypass the onnx node validation. 5. Add new trt provider option, `trt_extra_plugin_lib_paths`, user can specify any extra plugin lib, for example, `fastertransformer/build/lib/libvit_plugin.so` or `fastertransformer/build/lib/libvit_plugin.so;fastertransformer/build/lib/libvit_plugin_v2.so`
This commit is contained in:
parent
cb83d2b1a9
commit
6115c8fd1f
23 changed files with 478 additions and 7 deletions
|
|
@ -29,6 +29,7 @@ class Node;
|
|||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/func_api.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/framework/framework_provider_common.h"
|
||||
#include "core/framework/stream_handles.h"
|
||||
#include "core/framework/tuning_context.h"
|
||||
|
||||
|
|
@ -146,6 +147,20 @@ class IExecutionProvider {
|
|||
*/
|
||||
virtual ProviderOptions GetProviderOptions() const { return {}; }
|
||||
|
||||
/**
|
||||
Get provider specific custom op domain list.
|
||||
Provider has the responsibility to release OrtCustomOpDomain instances it creates.
|
||||
|
||||
NOTE: In the case of ONNX model having EP specific custom nodes and don't want to ask user to register those nodes,
|
||||
EP might need to a way to register those custom nodes. This API is added for the purpose where EP can use it to
|
||||
leverage ORT custom op to register those custom nodes with one or more custom op domains.
|
||||
|
||||
For example, TensorRT EP uses this API to support TRT plugins where each custom op is mapped to TRT plugin and no
|
||||
kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as
|
||||
a role to help pass ONNX model validation.
|
||||
*/
|
||||
virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {};
|
||||
|
||||
/**
|
||||
Returns an opaque handle whose exact type varies based on the provider
|
||||
and is interpreted accordingly by the corresponding kernel implementation.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
struct OrtCustomOpDomain {
|
||||
std::string domain_;
|
||||
std::vector<const OrtCustomOp*> custom_ops_;
|
||||
};
|
||||
|
|
@ -40,4 +40,5 @@ struct OrtTensorRTProviderOptionsV2 {
|
|||
int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
|
||||
const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default
|
||||
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
|
||||
const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2460,6 +2460,12 @@ void RegisterContribSchemas() {
|
|||
}
|
||||
});
|
||||
|
||||
// ORT will not regsiter TRT plugins as contrib ops, instead it will use custom ops handled by TRT EP.
|
||||
// In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain backward compatible,
|
||||
// we still keep EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT as legacy code.
|
||||
// We don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us.
|
||||
// Moving forward, please create TRT plugin node with "trt.plugins" domain.
|
||||
|
||||
static const char* EfficientNMS_TRT_ver1_doc =
|
||||
R"DOC(Efficient NMS TensorRT Plugin.)DOC";
|
||||
|
||||
|
|
@ -2662,6 +2668,8 @@ void RegisterContribSchemas() {
|
|||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
});
|
||||
|
||||
// Please note that we don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us.
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Snpe)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -713,9 +713,19 @@ void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) {
|
|||
return g_host->MurmurHash3__x86_128(key, len, seed, out);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
Status LoadDynamicLibrary(onnxruntime::PathString library_name) {
|
||||
return g_host->LoadDynamicLibrary(library_name);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
std::string ToUTF8String(const std::wstring& s) {
|
||||
return g_host->ToUTF8String(s);
|
||||
}
|
||||
|
||||
std::wstring ToWideString(const std::string& s) {
|
||||
return g_host->ToWideString(s);
|
||||
}
|
||||
#endif
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ struct Provider {
|
|||
// Update provider options from key-value string configuration
|
||||
virtual void UpdateProviderOptions(void* /*provider options to be configured*/, const ProviderOptions& /*key-value string provider options*/){};
|
||||
|
||||
// Get provider specific custom op domain list. Provider has the resposibility to release OrtCustomOpDomain instances it creates.
|
||||
virtual void GetCustomOpDomainList(IExecutionProviderFactory* /*pointer to factory instance*/, std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/){};
|
||||
|
||||
virtual void Initialize() = 0; // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded
|
||||
virtual void Shutdown() = 0; // Called right before unloading the shared library
|
||||
|
||||
|
|
|
|||
|
|
@ -891,9 +891,14 @@ struct ProviderHost {
|
|||
|
||||
#ifdef _WIN32
|
||||
virtual std::string ToUTF8String(const std::wstring& s) = 0;
|
||||
virtual std::wstring ToWideString(const std::string& s) = 0;
|
||||
#endif
|
||||
|
||||
virtual ProviderHostCPU& GetProviderHostCPU() = 0;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "core/common/safeint.h"
|
||||
#include "tensorrt_execution_provider.h"
|
||||
#include "tensorrt_execution_provider_utils.h"
|
||||
#include "tensorrt_execution_provider_custom_ops.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_call.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/gpu_data_transfer.h"
|
||||
|
|
@ -634,6 +635,7 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
|
|||
if (!external_stream_ && stream_) {
|
||||
ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_)));
|
||||
}
|
||||
ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list);
|
||||
}
|
||||
|
||||
AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const {
|
||||
|
|
@ -724,6 +726,10 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
|
||||
custom_op_domain_list = info_.custom_op_domain_list;
|
||||
}
|
||||
|
||||
// Check the graph is the subgraph of control flow op
|
||||
bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const GraphViewer& graph) const {
|
||||
if (graph.IsSubgraph()) {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE";
|
|||
static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL";
|
||||
static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS";
|
||||
static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES";
|
||||
static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS";
|
||||
// Old env variable for backward compatibility
|
||||
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
|
||||
} // namespace tensorrt_env_vars
|
||||
|
|
@ -163,6 +164,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
|
||||
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override;
|
||||
|
||||
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const override;
|
||||
|
||||
private:
|
||||
TensorrtExecutionProviderInfo info_;
|
||||
bool external_stream_ = false;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "tensorrt_execution_provider_custom_ops.h"
|
||||
#include "tensorrt_execution_provider.h"
|
||||
#include <NvInferRuntime.h>
|
||||
#include <NvInferPlugin.h>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace onnxruntime {
|
||||
extern TensorrtLogger& GetTensorrtLogger();
|
||||
|
||||
/*
|
||||
* Create custom op domain list for TRT plugins.
|
||||
*
|
||||
* Here, we collect all registered TRT plugins from TRT registry and create custom ops with "trt.plugins" domain.
|
||||
* Additionally, if users specify extra plugin libraries, TRT EP will load them at runtime which will register those
|
||||
* plugins to TRT plugin registry and later TRT EP can get them as well.
|
||||
*
|
||||
* There are several TRT plugins registered as onnx schema op through contrib op with ONNX domain in the past,
|
||||
* for example, EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT.
|
||||
* In order not to break the old models using those TRT plugins which were registered with ONNX domain and maintain
|
||||
* backward compatible, we need to keep those legacy TRT plugins registered with ONNX domain with contrib ops.
|
||||
*
|
||||
* Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
|
||||
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
|
||||
*/
|
||||
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
|
||||
std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
|
||||
custom_op_domain->domain_ = "trt.plugins";
|
||||
|
||||
// Load any extra TRT plugin library if any.
|
||||
// When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
|
||||
// This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
|
||||
std::string extra_plugin_lib_paths{""};
|
||||
if (info.has_trt_options) {
|
||||
if (!info.extra_plugin_lib_paths.empty()) {
|
||||
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
|
||||
}
|
||||
} else {
|
||||
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
|
||||
if (!extra_plugin_lib_paths_env.empty()) {
|
||||
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
|
||||
}
|
||||
}
|
||||
|
||||
// extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
|
||||
if (!extra_plugin_lib_paths.empty()) {
|
||||
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
|
||||
std::string lib;
|
||||
while (std::getline(extra_plugin_libs, lib, ';')) {
|
||||
auto status = LoadDynamicLibrary(ToPathString(lib));
|
||||
if (status == Status::OK()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully load " << lib;
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get all registered TRT plugins from registry
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
|
||||
TensorrtLogger trt_logger = GetTensorrtLogger();
|
||||
initLibNvInferPlugins(&trt_logger, "");
|
||||
|
||||
int num_plugin_creator = 0;
|
||||
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
|
||||
std::unordered_set<std::string> registered_plugin_names;
|
||||
|
||||
for (int i = 0; i < num_plugin_creator; i++) {
|
||||
auto plugin_creator = plugin_creators[i];
|
||||
std::string plugin_name(plugin_creator->getPluginName());
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
|
||||
|
||||
// plugin has different versions and we only register once
|
||||
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
|
||||
trt_custom_op->SetName(plugin_creator->getPluginName());
|
||||
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
|
||||
registered_plugin_names.insert(plugin_name);
|
||||
}
|
||||
info.custom_op_domain_list.push_back(custom_op_domain.release());
|
||||
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
||||
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
|
||||
if (domain != nullptr) {
|
||||
for (auto ptr : domain->custom_ops_) {
|
||||
if (ptr != nullptr) {
|
||||
delete ptr;
|
||||
}
|
||||
}
|
||||
delete domain;
|
||||
}
|
||||
}
|
||||
|
||||
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
|
||||
for (auto ptr : custom_op_domain_list) {
|
||||
ReleaseTensorRTCustomOpDomain(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "tensorrt_execution_provider_info.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
|
||||
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
|
||||
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
|
||||
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
|
||||
|
||||
struct TensorRTCustomKernel {
|
||||
TensorRTCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream)
|
||||
: compute_stream_(compute_stream) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here.
|
||||
|
||||
private:
|
||||
void* compute_stream_;
|
||||
};
|
||||
|
||||
struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKernel> {
|
||||
explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}
|
||||
|
||||
void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); };
|
||||
|
||||
const char* GetName() const { return name_; };
|
||||
|
||||
void SetName(const char* name) { name_ = name; };
|
||||
|
||||
const char* GetExecutionProviderType() const { return provider_; };
|
||||
|
||||
size_t GetInputTypeCount() const { return num_inputs_; };
|
||||
|
||||
void SetInputTypeCount(size_t num) { num_inputs_ = num; };
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; };
|
||||
|
||||
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; };
|
||||
|
||||
size_t GetOutputTypeCount() const { return num_outputs_; };
|
||||
|
||||
void SetOutputTypeCount(size_t num) { num_outputs_ = num; };
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; };
|
||||
|
||||
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; };
|
||||
|
||||
private:
|
||||
const char* provider_{onnxruntime::kTensorrtExecutionProvider};
|
||||
void* compute_stream_;
|
||||
const char* name_;
|
||||
size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input
|
||||
size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -39,6 +39,7 @@ constexpr const char* kSparsityEnable = "trt_sparsity_enable";
|
|||
constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level";
|
||||
constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams";
|
||||
constexpr const char* kTacticSources = "trt_tactic_sources";
|
||||
constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths";
|
||||
} // namespace provider_option_names
|
||||
} // namespace tensorrt
|
||||
|
||||
|
|
@ -83,6 +84,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
|
|||
.AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths)
|
||||
.Parse(options)); // add new provider option here.
|
||||
|
||||
return info;
|
||||
|
|
@ -117,6 +119,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
|
|||
{tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)},
|
||||
{tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)},
|
||||
{tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)},
|
||||
{tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@
|
|||
|
||||
#include "core/framework/ortdevice.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/framework/framework_provider_common.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/framework/library_handles.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// Information needed to construct trt execution providers.
|
||||
|
|
@ -41,9 +43,12 @@ struct TensorrtExecutionProviderInfo {
|
|||
int builder_optimization_level{2};
|
||||
int auxiliary_streams{-1};
|
||||
std::string tactic_sources{""};
|
||||
std::string extra_plugin_lib_paths{""};
|
||||
|
||||
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
|
||||
static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info);
|
||||
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domain_list;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "tensorrt_provider_factory_creator.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
#include "core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h"
|
||||
#include <string.h>
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
|
@ -23,10 +24,16 @@ struct TensorrtProviderFactory : IExecutionProviderFactory {
|
|||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
|
||||
|
||||
private:
|
||||
TensorrtExecutionProviderInfo info_;
|
||||
};
|
||||
|
||||
void TensorrtProviderFactory::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
|
||||
custom_op_domain_list = info_.custom_op_domain_list;
|
||||
}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> TensorrtProviderFactory::CreateProvider() {
|
||||
return std::make_unique<TensorrtExecutionProvider>(info_);
|
||||
}
|
||||
|
|
@ -43,6 +50,11 @@ struct Tensorrt_Provider : Provider {
|
|||
TensorrtExecutionProviderInfo info;
|
||||
info.device_id = device_id;
|
||||
info.has_trt_options = false;
|
||||
|
||||
common::Status status = CreateTensorRTCustomOpDomainList(info);
|
||||
if (!status.IsOK()) {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
|
||||
}
|
||||
return std::make_shared<TensorrtProviderFactory>(info);
|
||||
}
|
||||
|
||||
|
|
@ -78,6 +90,13 @@ struct Tensorrt_Provider : Provider {
|
|||
info.builder_optimization_level = options.trt_builder_optimization_level;
|
||||
info.auxiliary_streams = options.trt_auxiliary_streams;
|
||||
info.tactic_sources = options.trt_tactic_sources == nullptr ? "" : options.trt_tactic_sources;
|
||||
info.extra_plugin_lib_paths = options.trt_extra_plugin_lib_paths == nullptr ? "" : options.trt_extra_plugin_lib_paths;
|
||||
|
||||
common::Status status = CreateTensorRTCustomOpDomainList(info);
|
||||
if (!status.IsOK()) {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
|
||||
}
|
||||
|
||||
return std::make_shared<TensorrtProviderFactory>(info);
|
||||
}
|
||||
|
||||
|
|
@ -172,6 +191,11 @@ struct Tensorrt_Provider : Provider {
|
|||
return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options);
|
||||
}
|
||||
|
||||
void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) override {
|
||||
TensorrtProviderFactory* trt_factory = reinterpret_cast<TensorrtProviderFactory*>(factory);
|
||||
trt_factory->GetCustomOpDomainList(custom_op_domains_ptr);
|
||||
}
|
||||
|
||||
void Initialize() override {
|
||||
InitializeRegistry();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -508,6 +508,18 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
|
|||
}
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
// Create Custom Op if EP requests it
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domains;
|
||||
p_exec_provider->GetCustomOpDomainList(custom_op_domains);
|
||||
|
||||
if (!custom_op_domains.empty()) {
|
||||
if (AddCustomOpDomains(custom_op_domains) != Status::OK()) {
|
||||
LOGS(*session_logger_, WARNING) << "Can't register custom op domains with ORT for " << provider_type;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// if any EPs do not support concurrent calls to Run we add locking around graph execution
|
||||
if (p_exec_provider->ConcurrentRunSupported() == false) {
|
||||
is_concurrent_run_supported_ = false;
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include "core/framework/prepacked_weights_container.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tuning_results.h"
|
||||
#include "core/framework/framework_provider_common.h"
|
||||
#include "core/graph/basic_types.h"
|
||||
#include "core/optimizer/graph_transformer_level.h"
|
||||
#include "core/optimizer/graph_transformer_mgr.h"
|
||||
|
|
@ -36,11 +37,6 @@ namespace ONNX_NAMESPACE {
|
|||
class ModelProto;
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
struct OrtCustomOpDomain {
|
||||
std::string domain_;
|
||||
std::vector<const OrtCustomOp*> custom_ops_;
|
||||
};
|
||||
|
||||
namespace onnxruntime { // forward declarations
|
||||
class CustomRegistry;
|
||||
class Environment;
|
||||
|
|
|
|||
|
|
@ -163,6 +163,22 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator {
|
|||
|
||||
Node::EdgeConstIterator v_;
|
||||
};
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_name) {
|
||||
const auto& platform_env = onnxruntime::Env::Default();
|
||||
void* library_handle = nullptr;
|
||||
|
||||
ORT_RETURN_IF_ERROR(platform_env.LoadDynamicLibrary(library_name, false, &library_handle));
|
||||
if (!library_handle) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load dynamic library ",
|
||||
onnxruntime::PathToUTF8String(library_name));
|
||||
}
|
||||
|
||||
return onnxruntime::Status::OK();
|
||||
}
|
||||
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26436)
|
||||
|
|
@ -1042,9 +1058,14 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
#ifdef _WIN32
|
||||
std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); }
|
||||
std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); }
|
||||
#endif
|
||||
|
||||
ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); }
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); };
|
||||
#endif
|
||||
} provider_host_;
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
|
|
@ -1286,6 +1307,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
|
|||
trt_options_converted.trt_builder_optimization_level = 2;
|
||||
trt_options_converted.trt_auxiliary_streams = -1;
|
||||
trt_options_converted.trt_tactic_sources = "";
|
||||
trt_options_converted.trt_extra_plugin_lib_paths = "";
|
||||
return trt_options_converted;
|
||||
}
|
||||
|
||||
|
|
@ -1298,6 +1320,10 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
|
|||
return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options);
|
||||
}
|
||||
|
||||
void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) {
|
||||
s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) {
|
||||
return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options);
|
||||
}
|
||||
|
|
@ -1457,6 +1483,13 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
|
|||
}
|
||||
|
||||
options->provider_factories.push_back(factory);
|
||||
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domains;
|
||||
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
|
||||
for (auto ptr : custom_op_domains) {
|
||||
options->custom_op_domains_.push_back(ptr);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
|
@ -1481,6 +1514,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
|
|||
}
|
||||
|
||||
options->provider_factories.push_back(factory);
|
||||
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domains;
|
||||
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
|
||||
for (auto ptr : custom_op_domains) {
|
||||
options->custom_op_domains_.push_back(ptr);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
|
@ -1582,6 +1622,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
|
|||
}
|
||||
|
||||
options->provider_factories.push_back(factory);
|
||||
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domains;
|
||||
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
|
||||
for (auto ptr : custom_op_domains) {
|
||||
options->custom_op_domains_.push_back(ptr);
|
||||
}
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
|
@ -1613,6 +1659,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
|
|||
(*out)->trt_timing_cache_enable = false;
|
||||
(*out)->trt_force_timing_cache = false;
|
||||
(*out)->trt_detailed_build_log = false;
|
||||
(*out)->trt_extra_plugin_lib_paths = nullptr;
|
||||
return nullptr;
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(out);
|
||||
|
|
|
|||
|
|
@ -376,6 +376,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
0,
|
||||
2,
|
||||
-1,
|
||||
nullptr,
|
||||
nullptr};
|
||||
for (auto option : it->second) {
|
||||
if (option.first == "device_id") {
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
int trt_builder_optimization_level = 2;
|
||||
int trt_auxiliary_streams = -1;
|
||||
std::string trt_tactic_sources = "";
|
||||
std::string trt_extra_plugin_lib_paths = "";
|
||||
|
||||
#ifdef _MSC_VER
|
||||
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
|
|
@ -334,8 +335,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
} else {
|
||||
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-emtpy string.\n");
|
||||
}
|
||||
} else if (key == "trt_extra_plugin_lib_paths") {
|
||||
if (!value.empty()) {
|
||||
trt_extra_plugin_lib_paths = value;
|
||||
} else {
|
||||
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-emtpy string.\n");
|
||||
}
|
||||
} else {
|
||||
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback'] \n");
|
||||
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_extra_plugin_lib_paths'] \n");
|
||||
}
|
||||
}
|
||||
OrtTensorRTProviderOptionsV2 tensorrt_options;
|
||||
|
|
@ -367,6 +374,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level;
|
||||
tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams;
|
||||
tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str();
|
||||
tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
|
||||
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);
|
||||
|
||||
OrtCUDAProviderOptions cuda_options;
|
||||
|
|
|
|||
|
|
@ -686,7 +686,7 @@ TEST_P(ModelTest, Run) {
|
|||
OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30,
|
||||
1, // enable fp16
|
||||
0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
2, -1, nullptr};
|
||||
2, -1, nullptr, nullptr};
|
||||
ortso.AppendExecutionProvider_TensorRT_V2(params);
|
||||
} else {
|
||||
OrtTensorRTProviderOptionsV2* ep_option = nullptr;
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
|
|||
0,
|
||||
2,
|
||||
-1,
|
||||
nullptr,
|
||||
nullptr};
|
||||
|
||||
params.trt_engine_cache_enable = 1;
|
||||
|
|
@ -240,6 +241,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
|
|||
0,
|
||||
2,
|
||||
-1,
|
||||
nullptr,
|
||||
nullptr};
|
||||
|
||||
params.trt_engine_cache_enable = 1;
|
||||
|
|
@ -333,6 +335,77 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
|
|||
ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
|
||||
}
|
||||
|
||||
TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
|
||||
std::string model_name = "testdata/trt_plugin_custom_op_test.onnx";
|
||||
SessionOptions so;
|
||||
so.session_logid = "TensorrtExecutionProviderTRTPluginsTest";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
onnxruntime::AllocatorManager allocator_manager;
|
||||
auto cuda_provider = DefaultCudaExecutionProvider();
|
||||
cuda_provider->RegisterAllocator(allocator_manager);
|
||||
auto cpu_allocator = cuda_provider->GetAllocator(OrtMemTypeCPU);
|
||||
std::vector<int64_t> dims_op_x = {12, 256, 256};
|
||||
std::vector<float> values_op_x(1.0f, 786432); // 786432=12*256*256
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x);
|
||||
OrtValue ml_value_y;
|
||||
CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_y);
|
||||
OrtValue ml_value_z;
|
||||
CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_z);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("input1", ml_value_x));
|
||||
feeds.insert(std::make_pair("input2", ml_value_y));
|
||||
feeds.insert(std::make_pair("input3", ml_value_z));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("output");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
OrtTensorRTProviderOptionsV2 params{
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
1000,
|
||||
1,
|
||||
1 << 30,
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
0,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
2,
|
||||
-1,
|
||||
nullptr,
|
||||
nullptr};
|
||||
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(¶ms);
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
std::cout << model_name << std::endl;
|
||||
auto status = session_object.Load(model_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
||||
// GetParam() returns the parameter of following format:
|
||||
// ##cache type##_##input shape type##
|
||||
|
|
@ -411,6 +484,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
|||
0,
|
||||
2,
|
||||
-1,
|
||||
nullptr,
|
||||
nullptr};
|
||||
|
||||
if (cache_type.compare("engine") == 0) {
|
||||
|
|
|
|||
27
onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx
vendored
Normal file
27
onnxruntime/test/testdata/trt_plugin_custom_op_test.onnx
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
:œ
|
||||
ƒ
|
||||
input1
|
||||
input2
|
||||
input3outputDisentangledAttention_TRT"DisentangledAttention_TRT*
|
||||
factormçû= *
|
||||
span€ :trt.pluginstrt_plugin_custom_opZ
|
||||
input1
|
||||
|
||||
|
||||
€
|
||||
€Z
|
||||
input2
|
||||
|
||||
|
||||
€
|
||||
€Z
|
||||
input3
|
||||
|
||||
|
||||
€
|
||||
€b
|
||||
output
|
||||
|
||||
|
||||
€
|
||||
€B
|
||||
40
onnxruntime/test/testdata/trt_plugin_custom_op_test.py
vendored
Normal file
40
onnxruntime/test/testdata/trt_plugin_custom_op_test.py
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
|
||||
def generate_model(model_name):
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"DisentangledAttention_TRT",
|
||||
["input1", "input2", "input3"],
|
||||
["output"],
|
||||
"DisentangledAttention_TRT",
|
||||
domain="trt.plugins",
|
||||
factor=0.123,
|
||||
span=128,
|
||||
),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"trt_plugin_custom_op",
|
||||
[ # input
|
||||
helper.make_tensor_value_info("input1", TensorProto.FLOAT, [12, 256, 256]),
|
||||
helper.make_tensor_value_info("input2", TensorProto.FLOAT, [12, 256, 256]),
|
||||
helper.make_tensor_value_info("input3", TensorProto.FLOAT, [12, 256, 256]),
|
||||
],
|
||||
[ # output
|
||||
helper.make_tensor_value_info("output", TensorProto.FLOAT, [12, 256, 256]),
|
||||
],
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_model("trt_plugin_custom_op_test.onnx")
|
||||
Loading…
Reference in a new issue