Add TRT plugins support using custom ops (#13847)

This PR makes ORT support TRT plugin using custom ops. ORT TRT can
automatically register all TRT plugins from TRT plugins registry as
custom ops. There is no code change needed for ORT when new TRT plugins
are introduced.

Previous way for ORT to support TRT plugins was using contrib ops, but
there are some concerns about it:

- Contrib ops are shipped as part of the ORT binary by default. TRT
related plugins should not be in the default ORT.
- Contrib ops are designed for internal ops and developed for cpu and
cuda EPs.

Therefore, using custom ops is a good approach to support TRT plugins. 

The following are the major modifications:

1. Add new `GetCustomOpDomainList` provider api which allows provider to
create its own custom op domain list and ORT can register this domain
list. Provider has the responsibility to free all the custom op domain
instances it created.
2. Move OrtCustomOpDomain struct definition to
framework_provider_common.h since this struct is being used by framework
and EPs now.
3. There are several TRT plugins registered as onnx schema ops through
contrib ops with the onnx domain. In order not to break old models that use
those TRT plugins registered with the ONNX domain, and to maintain
backward compatibility, we need to keep the old/legacy TRT plugins with the
onnx domain. Moving forward, all newly added TRT plugins should be
registered with the `trt.plugins` domain.
4. TRT plugin doesn't have an api to get number of inputs/outputs of the
registered plugins, so ORT TRT uses variadic inputs/outputs to bypass
the onnx node validation.
5. Add new trt provider option, `trt_extra_plugin_lib_paths`, user can
specify any extra plugin lib, for example,
`fastertransformer/build/lib/libvit_plugin.so` or
`fastertransformer/build/lib/libvit_plugin.so;fastertransformer/build/lib/libvit_plugin_v2.so`
This commit is contained in:
Chi Lo 2023-04-18 20:24:32 -07:00 committed by GitHub
parent cb83d2b1a9
commit 6115c8fd1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 478 additions and 7 deletions

View file

@ -29,6 +29,7 @@ class Node;
#include "core/framework/allocatormgr.h"
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/framework_provider_common.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"
@ -146,6 +147,20 @@ class IExecutionProvider {
*/
virtual ProviderOptions GetProviderOptions() const { return {}; }
/**
Get the provider specific custom op domain list.
The provider is responsible for releasing the OrtCustomOpDomain instances it creates.

NOTE: When an ONNX model contains EP specific custom nodes and users should not have to register those nodes
themselves, the EP needs a way to register them. This API exists for that purpose: the EP can use it to
leverage ORT custom ops and register those custom nodes under one or more custom op domains.
For example, TensorRT EP uses this API to support TRT plugins, where each custom op is mapped to a TRT plugin and no
kernel implementation is needed for the custom op since the real implementation is inside TRT. The custom op only
serves to help the model pass ONNX model validation.

The default implementation does nothing and leaves the output list untouched.
*/
virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {};
/**
Returns an opaque handle whose exact type varies based on the provider
and is interpreted accordingly by the corresponding kernel implementation.

View file

@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>

#include "core/session/onnxruntime_c_api.h"
// A named custom op domain together with the custom ops registered under it.
// The ops are held as raw pointers and this struct declares no destructor, so it does not
// release them itself; whoever creates the ops has to manage their lifetime.
struct OrtCustomOpDomain {
// Name of the domain (for example, TensorRT EP uses "trt.plugins").
std::string domain_;
// Custom ops that belong to this domain.
std::vector<const OrtCustomOp*> custom_ops_;
};

View file

@ -40,4 +40,5 @@ struct OrtTensorRTProviderOptionsV2 {
int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths
};

View file

@ -2460,6 +2460,12 @@ void RegisterContribSchemas() {
}
});
// ORT will not register TRT plugins as contrib ops; instead it will use custom ops handled by TRT EP.
// In order not to break old models that use the TRT plugins which were registered with the ONNX domain, and to maintain backward compatibility,
// we still keep EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT as legacy code.
// We don't need to add a new schema definition when a new TRT plugin is introduced; TRT EP will register it as a custom op for us.
// Moving forward, please create TRT plugin nodes with the "trt.plugins" domain.
static const char* EfficientNMS_TRT_ver1_doc =
R"DOC(Efficient NMS TensorRT Plugin.)DOC";
@ -2662,6 +2668,8 @@ void RegisterContribSchemas() {
propagateShapeFromInputToOutput(ctx, 0, 0);
});
// Please note that we don't need to add new schema definition when a new TRT plugin is introduced, TRT EP will register it as custom op for us.
ONNX_CONTRIB_OPERATOR_SCHEMA(Snpe)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -713,9 +713,19 @@ void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) {
return g_host->MurmurHash3__x86_128(key, len, seed, out);
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Provider-side entry point for loading a shared library: the request is forwarded to the
// ORT host through g_host and the host's status is returned unchanged.
Status LoadDynamicLibrary(onnxruntime::PathString library_name) {
  Status load_status = g_host->LoadDynamicLibrary(library_name);
  return load_status;
}
#endif
#ifdef _WIN32
// Converts a wide string to UTF-8 by delegating to the ORT host (Windows-only code path).
std::string ToUTF8String(const std::wstring& s) {
  std::string utf8 = g_host->ToUTF8String(s);
  return utf8;
}
// Converts a narrow string to a wide string by delegating to the ORT host (Windows-only code path).
std::wstring ToWideString(const std::string& s) {
  std::wstring wide = g_host->ToWideString(s);
  return wide;
}
#endif
} // namespace onnxruntime

View file

@ -26,6 +26,9 @@ struct Provider {
// Update provider options from key-value string configuration
virtual void UpdateProviderOptions(void* /*provider options to be configured*/, const ProviderOptions& /*key-value string provider options*/){};
// Get the provider specific custom op domain list for the given factory.
// The provider has the responsibility to release the OrtCustomOpDomain instances it creates.
// The default implementation does nothing and leaves the output list untouched.
virtual void GetCustomOpDomainList(IExecutionProviderFactory* /*pointer to factory instance*/, std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/){};
virtual void Initialize() = 0; // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded
virtual void Shutdown() = 0; // Called right before unloading the shared library

View file

@ -891,9 +891,14 @@ struct ProviderHost {
#ifdef _WIN32
virtual std::string ToUTF8String(const std::wstring& s) = 0;
virtual std::wstring ToWideString(const std::string& s) = 0;
#endif
virtual ProviderHostCPU& GetProviderHostCPU() = 0;
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0;
#endif
};
#if defined(_MSC_VER) && !defined(__clang__)

View file

@ -10,6 +10,7 @@
#include "core/common/safeint.h"
#include "tensorrt_execution_provider.h"
#include "tensorrt_execution_provider_utils.h"
#include "tensorrt_execution_provider_custom_ops.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
#include "core/providers/cuda/gpu_data_transfer.h"
@ -634,6 +635,7 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
if (!external_stream_ && stream_) {
ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_)));
}
ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list);
}
AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const {
@ -724,6 +726,10 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
return Status::OK();
}
// Copies the custom op domain pointers held in this provider's info into the caller's list.
// Only the pointers are copied; the domains themselves are released by the provider's destructor.
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
  const std::vector<OrtCustomOpDomain*>& provider_domains = info_.custom_op_domain_list;
  custom_op_domain_list = provider_domains;
}
// Check the graph is the subgraph of control flow op
bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const GraphViewer& graph) const {
if (graph.IsSubgraph()) {

View file

@ -38,6 +38,7 @@ static const std::string kSparsityEnable = "ORT_TENSORRT_SPARSITY_ENABLE";
static const std::string kBuilderOptimizationLevel = "ORT_TENSORRT_BUILDER_OPTIMIZATION_LEVEL";
static const std::string kAuxiliaryStreams = "ORT_TENSORRT_AUXILIARY_STREAMS";
static const std::string kTacticSources = "ORT_TENSORRT_TACTIC_SOURCES";
static const std::string kExtraPluginLibPaths = "ORT_TENSORRT_EXTRA_PLUGIN_LIB_PATHS";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
@ -163,6 +164,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry) const override;
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const override;
private:
TensorrtExecutionProviderInfo info_;
bool external_stream_ = false;

View file

@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/provider_options.h"
#include "tensorrt_execution_provider_custom_ops.h"
#include "tensorrt_execution_provider.h"
#include <NvInferRuntime.h>
#include <NvInferPlugin.h>
#include <unordered_set>
namespace onnxruntime {
extern TensorrtLogger& GetTensorrtLogger();
/*
 * Create custom op domain list for TRT plugins.
 *
 * Here, we collect all registered TRT plugins from the TRT registry and create custom ops with the "trt.plugins"
 * domain. Additionally, if users specify extra plugin libraries, TRT EP will load them at runtime, which registers
 * those plugins with the TRT plugin registry so that TRT EP can pick them up as well.
 *
 * There are several TRT plugins registered as onnx schema ops through contrib ops with the ONNX domain in the past,
 * for example, EfficientNMS_TRT, MultilevelCropAndResize_TRT, PyramidROIAlign_TRT and DisentangledAttention_TRT.
 * In order not to break old models using those TRT plugins, and to maintain backward compatibility, those legacy
 * TRT plugins stay registered with the ONNX domain through contrib ops.
 *
 * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
 * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
 *
 * On return, one newly allocated OrtCustomOpDomain (holding one TensorRTCustomOp per distinct plugin name) has been
 * appended to info.custom_op_domain_list as a raw pointer. It is meant to be freed with
 * ReleaseTensorRTCustomOpDomainList / ReleaseTensorRTCustomOpDomain.
 *
 * The function currently always returns Status::OK(); a failure to load an extra plugin library is only logged
 * as a warning.
 */
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
  auto custom_op_domain = std::make_unique<OrtCustomOpDomain>();
  custom_op_domain->domain_ = "trt.plugins";

  // Load any extra TRT plugin library if any.
  // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
  // This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
  // The paths come from the provider options when those were supplied, otherwise from the environment variable.
  std::string extra_plugin_lib_paths;
  if (info.has_trt_options) {
    extra_plugin_lib_paths = info.extra_plugin_lib_paths;
  } else {
    extra_plugin_lib_paths = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
  }

  // extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
  if (!extra_plugin_lib_paths.empty()) {
    std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
    std::string lib;
    while (std::getline(extra_plugin_libs, lib, ';')) {
      // Tolerate stray separators such as "a;;b" or a trailing ';' instead of trying to load an empty path.
      if (lib.empty()) {
        continue;
      }
      const auto status = LoadDynamicLibrary(ToPathString(lib));
      if (status.IsOK()) {
        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully load " << lib;
      } else {
        LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString();
      }
    }
  }

  // Get all registered TRT plugins from registry
  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
  // Bind to the logger returned by reference rather than copying it, so the address handed to TensorRT
  // does not refer to a function-local object.
  TensorrtLogger& trt_logger = GetTensorrtLogger();
  initLibNvInferPlugins(&trt_logger, "");

  int num_plugin_creator = 0;
  auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
  std::unordered_set<std::string> registered_plugin_names;

  for (int i = 0; i < num_plugin_creator; i++) {
    auto plugin_creator = plugin_creators[i];
    std::string plugin_name(plugin_creator->getPluginName());
    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();

    // plugin has different versions and we only register once
    // (insert() reports whether the name was new, which avoids a separate lookup)
    if (!registered_plugin_names.insert(plugin_name).second) {
      continue;
    }

    auto trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
    // The op keeps the pointer returned by the plugin creator rather than a copy of the name.
    trt_custom_op->SetName(plugin_creator->getPluginName());
    custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
  }

  info.custom_op_domain_list.push_back(custom_op_domain.release());
  return common::Status::OK();
}
// Frees a custom op domain created by CreateTensorRTCustomOpDomainList together with every custom op it holds.
// A null domain is ignored.
//
// Precondition: every entry in domain->custom_ops_ was allocated as a TensorRTCustomOp (which is what
// CreateTensorRTCustomOpDomainList does). The ops are stored as `const OrtCustomOp*`; deleting them through that
// base type, which has no virtual destructor, would be undefined behaviour, so each pointer is cast back to the
// type it was allocated as before being deleted.
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
  if (domain == nullptr) {
    return;
  }
  for (const OrtCustomOp* op : domain->custom_ops_) {
    // static_cast of a null pointer yields null, and deleting null is a no-op.
    delete static_cast<const TensorRTCustomOp*>(op);
  }
  delete domain;
}
// Frees every domain in the list via ReleaseTensorRTCustomOpDomain and then empties the list so that it
// no longer holds pointers to freed domains.
// NOTE(review): other copies of these pointers (e.g. lists filled by GetCustomOpDomainList) are not
// affected and must not be used or released again after this call.
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
  for (OrtCustomOpDomain* domain : custom_op_domain_list) {
    ReleaseTensorRTCustomOpDomain(domain);
  }
  custom_op_domain_list.clear();
}
} // namespace onnxruntime

View file

@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#define ORT_API_MANUAL_INIT
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/shared_library/provider_api.h"
#include "tensorrt_execution_provider_info.h"
using namespace onnxruntime;
namespace onnxruntime {
common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
// Placeholder kernel for a TRT plugin custom op. It only remembers the compute stream it was given;
// the kernel info passed to the constructor is not used.
struct TensorRTCustomKernel {
  TensorRTCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream)
      : compute_stream_{compute_stream} {}

  // Intentionally empty: the implementation is in TensorRT plugin. No need to implement it here.
  void Compute(OrtKernelContext* /*context*/) {}

 private:
  void* compute_stream_;
};
// Custom op that stands in for a TRT plugin so that nodes using the plugin can be registered with ORT.
// All inputs and outputs are declared variadic with an undefined element type; the kernel it creates
// (TensorRTCustomKernel) has an empty Compute.
//
// The op does not copy its name or provider strings: SetName() and the constructor store the pointers
// they are given, so the caller must keep those strings alive for as long as the op is used.
struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKernel> {
  // provider: execution provider type string reported by GetExecutionProviderType().
  // compute_stream: opaque stream handle forwarded to every kernel created by this op.
  explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}

  // Allocates a new TensorRTCustomKernel; the returned pointer is handed to the caller.
  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }

  // Returns the name set via SetName(), or nullptr if no name has been set yet.
  const char* GetName() const { return name_; }
  void SetName(const char* name) { name_ = name; }

  const char* GetExecutionProviderType() const { return provider_; }

  size_t GetInputTypeCount() const { return num_inputs_; }
  void SetInputTypeCount(size_t num) { num_inputs_ = num; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }

  size_t GetOutputTypeCount() const { return num_outputs_; }
  void SetOutputTypeCount(size_t num) { num_outputs_ = num; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }
  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; }

 private:
  const char* provider_{onnxruntime::kTensorrtExecutionProvider};
  void* compute_stream_{nullptr};
  const char* name_{nullptr};  // initialized so GetName() never returns an indeterminate pointer
  size_t num_inputs_ = 1;   // set to 1 to match with default min_arity for variadic input
  size_t num_outputs_ = 1;  // set to 1 to match with default min_arity for variadic output
};
} // namespace onnxruntime

View file

@ -39,6 +39,7 @@ constexpr const char* kSparsityEnable = "trt_sparsity_enable";
constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level";
constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams";
constexpr const char* kTacticSources = "trt_tactic_sources";
constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths";
} // namespace provider_option_names
} // namespace tensorrt
@ -83,6 +84,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level)
.AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams)
.AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources)
.AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths)
.Parse(options)); // add new provider option here.
return info;
@ -117,6 +119,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.builder_optimization_level)},
{tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.auxiliary_streams)},
{tensorrt::provider_option_names::kTacticSources, MakeStringWithClassicLocale(info.tactic_sources)},
{tensorrt::provider_option_names::kExtraPluginLibPaths, MakeStringWithClassicLocale(info.extra_plugin_lib_paths)},
};
return options;
}

View file

@ -7,7 +7,9 @@
#include "core/framework/ortdevice.h"
#include "core/framework/provider_options.h"
#include "core/framework/framework_provider_common.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/library_handles.h"
namespace onnxruntime {
// Information needed to construct trt execution providers.
@ -41,9 +43,12 @@ struct TensorrtExecutionProviderInfo {
int builder_optimization_level{2};
int auxiliary_streams{-1};
std::string tactic_sources{""};
std::string extra_plugin_lib_paths{""};
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info);
std::vector<OrtCustomOpDomain*> custom_op_domain_list;
};
} // namespace onnxruntime

View file

@ -8,6 +8,7 @@
#include "tensorrt_provider_factory_creator.h"
#include "core/framework/provider_options.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h"
#include <string.h>
using namespace onnxruntime;
@ -23,10 +24,16 @@ struct TensorrtProviderFactory : IExecutionProviderFactory {
std::unique_ptr<IExecutionProvider> CreateProvider() override;
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
private:
TensorrtExecutionProviderInfo info_;
};
// Copies the custom op domain pointers stored in this factory's provider info into the caller's list.
// Only the pointers are copied, not the domains they point to.
void TensorrtProviderFactory::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
  const std::vector<OrtCustomOpDomain*>& factory_domains = info_.custom_op_domain_list;
  custom_op_domain_list = factory_domains;
}
// Builds a new TensorRT execution provider from the info captured by this factory.
std::unique_ptr<IExecutionProvider> TensorrtProviderFactory::CreateProvider() {
  std::unique_ptr<IExecutionProvider> provider = std::make_unique<TensorrtExecutionProvider>(info_);
  return provider;
}
@ -43,6 +50,11 @@ struct Tensorrt_Provider : Provider {
TensorrtExecutionProviderInfo info;
info.device_id = device_id;
info.has_trt_options = false;
common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}
return std::make_shared<TensorrtProviderFactory>(info);
}
@ -78,6 +90,13 @@ struct Tensorrt_Provider : Provider {
info.builder_optimization_level = options.trt_builder_optimization_level;
info.auxiliary_streams = options.trt_auxiliary_streams;
info.tactic_sources = options.trt_tactic_sources == nullptr ? "" : options.trt_tactic_sources;
info.extra_plugin_lib_paths = options.trt_extra_plugin_lib_paths == nullptr ? "" : options.trt_extra_plugin_lib_paths;
common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}
return std::make_shared<TensorrtProviderFactory>(info);
}
@ -172,6 +191,11 @@ struct Tensorrt_Provider : Provider {
return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options);
}
// Fills custom_op_domains_ptr with the custom op domain pointers held by the given factory.
// The factory must be a TensorrtProviderFactory; a null factory leaves the list untouched.
void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) override {
  if (factory == nullptr) {
    return;
  }
  // Downcast within the class hierarchy: static_cast is the appropriate cast here (reinterpret_cast
  // performs no base-to-derived pointer adjustment).
  auto* trt_factory = static_cast<TensorrtProviderFactory*>(factory);
  trt_factory->GetCustomOpDomainList(custom_op_domains_ptr);
}
void Initialize() override {
InitializeRegistry();
}

View file

@ -508,6 +508,18 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
}
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Create Custom Op if EP requests it
std::vector<OrtCustomOpDomain*> custom_op_domains;
p_exec_provider->GetCustomOpDomainList(custom_op_domains);
if (!custom_op_domains.empty()) {
if (AddCustomOpDomains(custom_op_domains) != Status::OK()) {
LOGS(*session_logger_, WARNING) << "Can't register custom op domains with ORT for " << provider_type;
}
}
#endif
// if any EPs do not support concurrent calls to Run we add locking around graph execution
if (p_exec_provider->ConcurrentRunSupported() == false) {
is_concurrent_run_supported_ = false;

View file

@ -19,6 +19,7 @@
#include "core/framework/prepacked_weights_container.h"
#include "core/framework/session_state.h"
#include "core/framework/tuning_results.h"
#include "core/framework/framework_provider_common.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/graph_transformer_mgr.h"
@ -36,11 +37,6 @@ namespace ONNX_NAMESPACE {
class ModelProto;
} // namespace ONNX_NAMESPACE
struct OrtCustomOpDomain {
std::string domain_;
std::vector<const OrtCustomOp*> custom_ops_;
};
namespace onnxruntime { // forward declarations
class CustomRegistry;
class Environment;

View file

@ -163,6 +163,22 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator {
Node::EdgeConstIterator v_;
};
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Loads the shared library at library_name through the default platform Env on behalf of a provider.
// Returns the loader's error if loading fails, a FAIL status if the loader reports success but yields
// no handle, and OK otherwise. The handle is not kept and the library is not unloaded here.
common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_name) {
  const onnxruntime::Env& env = onnxruntime::Env::Default();
  void* handle = nullptr;
  ORT_RETURN_IF_ERROR(env.LoadDynamicLibrary(library_name, false, &handle));
  if (handle != nullptr) {
    return onnxruntime::Status::OK();
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load dynamic library ",
                         onnxruntime::PathToUTF8String(library_name));
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26436)
@ -1042,9 +1058,14 @@ struct ProviderHostImpl : ProviderHost {
#ifdef _WIN32
std::string ToUTF8String(const std::wstring& s) override { return onnxruntime::ToUTF8String(s); }
std::wstring ToWideString(const std::string& s) override { return onnxruntime::ToWideString(s); }
#endif
ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); }
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Host-side implementation of the provider bridge call: forwards to LoadDynamicLibraryFromProvider.
Status LoadDynamicLibrary(onnxruntime::PathString library_name) override {
  return LoadDynamicLibraryFromProvider(library_name);
}
#endif
} provider_host_;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
@ -1286,6 +1307,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
trt_options_converted.trt_builder_optimization_level = 2;
trt_options_converted.trt_auxiliary_streams = -1;
trt_options_converted.trt_tactic_sources = "";
trt_options_converted.trt_extra_plugin_lib_paths = "";
return trt_options_converted;
}
@ -1298,6 +1320,10 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options);
}
// Asks the TensorRT provider library for the custom op domains associated with the given factory
// and stores them in custom_op_domains_ptr.
void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) {
  auto& tensorrt_provider = s_library_tensorrt.Get();
  tensorrt_provider.GetCustomOpDomainList(factory, custom_op_domains_ptr);
}
// Creates a MIGraphX execution provider factory by delegating to the MIGraphX provider library.
std::shared_ptr<IExecutionProviderFactory> MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) {
  auto& migraphx_provider = s_library_migraphx.Get();
  return migraphx_provider.CreateExecutionProviderFactory(provider_options);
}
@ -1457,6 +1483,13 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
}
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
return nullptr;
API_IMPL_END
}
@ -1481,6 +1514,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
}
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
return nullptr;
API_IMPL_END
}
@ -1582,6 +1622,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
}
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
return nullptr;
API_IMPL_END
}
@ -1613,6 +1659,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
(*out)->trt_timing_cache_enable = false;
(*out)->trt_force_timing_cache = false;
(*out)->trt_detailed_build_log = false;
(*out)->trt_extra_plugin_lib_paths = nullptr;
return nullptr;
#else
ORT_UNUSED_PARAMETER(out);

View file

@ -376,6 +376,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
0,
2,
-1,
nullptr,
nullptr};
for (auto option : it->second) {
if (option.first == "device_id") {

View file

@ -129,6 +129,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
int trt_builder_optimization_level = 2;
int trt_auxiliary_streams = -1;
std::string trt_tactic_sources = "";
std::string trt_extra_plugin_lib_paths = "";
#ifdef _MSC_VER
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
@ -334,8 +335,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-emtpy string.\n");
}
} else if (key == "trt_extra_plugin_lib_paths") {
if (!value.empty()) {
trt_extra_plugin_lib_paths = value;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-emtpy string.\n");
}
} else {
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback'] \n");
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_extra_plugin_lib_paths'] \n");
}
}
OrtTensorRTProviderOptionsV2 tensorrt_options;
@ -367,6 +374,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level;
tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams;
tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str();
tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);
OrtCUDAProviderOptions cuda_options;

View file

@ -686,7 +686,7 @@ TEST_P(ModelTest, Run) {
OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30,
1, // enable fp16
0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0,
2, -1, nullptr};
2, -1, nullptr, nullptr};
ortso.AppendExecutionProvider_TensorRT_V2(params);
} else {
OrtTensorRTProviderOptionsV2* ep_option = nullptr;

View file

@ -161,6 +161,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
0,
2,
-1,
nullptr,
nullptr};
params.trt_engine_cache_enable = 1;
@ -240,6 +241,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
0,
2,
-1,
nullptr,
nullptr};
params.trt_engine_cache_enable = 1;
@ -333,6 +335,77 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
}
// Loads a model whose only node is a TRT plugin op in the "trt.plugins" domain, registers the TensorRT EP
// and checks that the session can load, initialize and run the model.
TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
  std::string model_name = "testdata/trt_plugin_custom_op_test.onnx";
  SessionOptions so;
  so.session_logid = "TensorrtExecutionProviderTRTPluginsTest";
  RunOptions run_options;
  run_options.run_tag = so.session_logid;
  InferenceSession session_object{so, GetEnvironment()};
  onnxruntime::AllocatorManager allocator_manager;
  auto cuda_provider = DefaultCudaExecutionProvider();
  cuda_provider->RegisterAllocator(allocator_manager);
  auto cpu_allocator = cuda_provider->GetAllocator(OrtMemTypeCPU);

  std::vector<int64_t> dims_op_x = {12, 256, 256};
  // std::vector's fill constructor takes (count, value). The buffer must hold one float per tensor
  // element (12 * 256 * 256 = 786432), each set to 1.0f.
  std::vector<float> values_op_x(12 * 256 * 256, 1.0f);
  OrtValue ml_value_x;
  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x);
  OrtValue ml_value_y;
  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_y);
  OrtValue ml_value_z;
  CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_z);
  NameMLValMap feeds;
  feeds.insert(std::make_pair("input1", ml_value_x));
  feeds.insert(std::make_pair("input2", ml_value_y));
  feeds.insert(std::make_pair("input3", ml_value_z));

  // prepare outputs
  std::vector<std::string> output_names;
  output_names.push_back("output");
  std::vector<OrtValue> fetches;

  // Positional initialization of every provider option field.
  OrtTensorRTProviderOptionsV2 params{
      0,
      0,
      nullptr,
      1000,
      1,
      1 << 30,
      0,
      0,
      nullptr,
      0,
      0,
      0,
      0,
      0,
      nullptr,
      0,
      nullptr,
      0,
      0,
      0,
      0,
      0,
      0,
      0,
      0,
      2,
      -1,        // trt_auxiliary_streams
      nullptr,   // trt_tactic_sources
      nullptr};  // trt_extra_plugin_lib_paths

  std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
  EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
  std::cout << model_name << std::endl;
  auto status = session_object.Load(model_name);
  ASSERT_TRUE(status.IsOK());
  status = session_object.Initialize();
  ASSERT_TRUE(status.IsOK());
  status = session_object.Run(run_options, feeds, output_names, &fetches);
  ASSERT_TRUE(status.IsOK());
}
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
// GetParam() returns the parameter of following format:
// ##cache type##_##input shape type##
@ -411,6 +484,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
0,
2,
-1,
nullptr,
nullptr};
if (cache_type.compare("engine") == 0) {

View file

@ -0,0 +1,27 @@

ƒ
input1
input2
input3outputDisentangledAttention_TRT"DisentangledAttention_TRT*
factormçû= *
span : trt.pluginstrt_plugin_custom_opZ
input1



Z
input2



Z
input3



b
output



B

View file

@ -0,0 +1,40 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import TensorProto, helper
def generate_model(model_name):
    """Build a single-node ONNX model that uses a TRT plugin op and save it to ``model_name``.

    The model has one ``DisentangledAttention_TRT`` node in the ``trt.plugins`` domain with
    three float inputs and one float output, all of shape [12, 256, 256].
    """
    tensor_shape = [12, 256, 256]
    input_names = ["input1", "input2", "input3"]
    output_names = ["output"]

    def float_tensor_info(name):
        # Every graph input/output shares the same element type and shape.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, tensor_shape)

    plugin_node = helper.make_node(
        "DisentangledAttention_TRT",
        input_names,
        output_names,
        "DisentangledAttention_TRT",
        domain="trt.plugins",
        factor=0.123,
        span=128,
    )

    graph = helper.make_graph(
        [plugin_node],
        "trt_plugin_custom_op",
        [float_tensor_info(name) for name in input_names],
        [float_tensor_info(name) for name in output_names],
    )

    onnx.save(helper.make_model(graph), model_name)


if __name__ == "__main__":
    generate_model("trt_plugin_custom_op_test.onnx")