[TensorRT EP] Refactor OrtTensorRTProviderOptions initialization and make it easy to add new field (#17617)

This PR makes two major modifications:

1. Refactor OrtTensorRTProviderOptions initialization to make it easy to
add new fields.
2. Make the Python API capable of using TensorRT plugins by adding a new
Python binding API, `register_tensorrt_plugins_as_custom_ops`. (The EP's
custom op domain must be registered before model load. The C++ API is
slightly different: calling
SessionOptionsAppendExecutionProvider_TensorRT_XX appends the custom op
domain to the session options, and ORT later registers it from the
session options before model loading.)
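
A minimal sketch of how the new binding is used from Python (assumes an ORT build with TensorRT support; model and plugin paths are hypothetical):

import onnxruntime as ort

# InferenceSession calls register_tensorrt_plugins_as_custom_ops internally before
# the model is loaded, so nodes in the "trt.plugins" custom op domain can resolve.
providers = [
    ("TensorrtExecutionProvider",
     {"trt_extra_plugin_lib_paths": "/path/to/libmy_trt_plugin.so"}),  # hypothetical plugin library
]
sess = ort.InferenceSession("model_with_trt_plugin.onnx", providers=providers)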
Chi Lo 2023-10-06 14:12:20 -07:00 committed by GitHub
parent 6ea493571e
commit 569876fb16
25 changed files with 452 additions and 492 deletions


@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
#ifdef __cplusplus
}
#endif


@ -11,38 +11,38 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
int device_id; // cuda device id.
int has_user_compute_stream; // indicator of user specified CUDA compute stream.
void* user_compute_stream; // user specified CUDA compute stream.
int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size; // maximum workspace size for TensorRT.
int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true
int trt_dla_core; // DLA core number. Default 0
int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true
int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true
const char* trt_engine_cache_path; // specify engine cache path
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_context_memory_sharing_enable; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
int trt_layer_norm_fp32_fallback; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
int trt_timing_cache_enable; // enable TensorRT timing cache. Default 0 = false, nonzero = true
int trt_force_timing_cache; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true
int trt_detailed_build_log; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true
int trt_build_heuristics_enable; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true
int trt_sparsity_enable; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true
int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5]
int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths
const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT.
int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name.
int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
int trt_dla_enable{0}; // enable DLA. Default 0 = false, nonzero = true
int trt_dla_core{0}; // DLA core number. Default 0
int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true
int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true
const char* trt_engine_cache_path{nullptr}; // specify engine cache path
int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path
int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
int trt_timing_cache_enable{0}; // enable TensorRT timing cache. Default 0 = false, nonzero = true
int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true
int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true
int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true
int trt_sparsity_enable{0}; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true
int trt_builder_optimization_level{3}; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5]
int trt_auxiliary_streams{-1}; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources{nullptr};               // Specify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
const char* trt_extra_plugin_lib_paths{nullptr}; // specify extra TensorRT plugin library paths
const char* trt_profile_min_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
};
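
With in-class default member initializers, adding a field no longer requires updating every aggregate initializer across the code base (the motivation for this refactor); callers can default-construct the struct and override only the fields they need. A minimal usage sketch, with a hypothetical cache path:

OrtTensorRTProviderOptionsV2 params;               // every field takes its default
params.trt_fp16_enable = 1;                        // enable FP16 precision
params.trt_engine_cache_enable = 1;                // cache built engines
params.trt_engine_cache_path = "/tmp/trt_cache";   // hypothetical cache directory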


@ -4572,6 +4572,14 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessio
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena);
/*
* This is the old way to add the TensorRT provider to the session; please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality
* This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists
*
* \param device_id CUDA device id, starts from zero.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
#ifdef __cplusplus
}
#endif
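
A hedged sketch of the recommended V2 path via the C API (error handling through the C++ helper; session_options is assumed to be a valid OrtSessionOptions*):

OrtTensorRTProviderOptionsV2* trt_options = nullptr;
Ort::ThrowOnError(Ort::GetApi().CreateTensorRTProviderOptions(&trt_options));
Ort::ThrowOnError(Ort::GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options));
Ort::GetApi().ReleaseTensorRTProviderOptions(trt_options);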


@ -19,7 +19,6 @@
#include "onnxruntime/core/providers/nnapi/nnapi_provider_factory.h"
#include "onnxruntime/core/providers/tvm/tvm_provider_factory.h"
#include "onnxruntime/core/providers/openvino/openvino_provider_factory.h"
#include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h"
#include "onnxruntime/core/providers/acl/acl_provider_factory.h"
#include "onnxruntime/core/providers/armnn/armnn_provider_factory.h"
#include "onnxruntime/core/providers/coreml/coreml_provider_factory.h"


@ -16,7 +16,6 @@
#include "core/providers/dml/dml_provider_factory.h"
#endif
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_provider_factory.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#endif
#ifdef USE_COREML


@ -26,13 +26,63 @@ extern TensorrtLogger& GetTensorrtLogger();
* Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
custom_op_domain->domain_ = "trt.plugins";
// Load extra TRT plugin libraries, if any.
// When a TRT plugin library is loaded, its global static objects are created and the plugins are registered with the TRT registry.
// This is done through a macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
// extra_plugin_lib_paths has the format "path_1;path_2;...;path_n"
static bool is_loaded = false;
if (!extra_plugin_lib_paths.empty() && !is_loaded) {
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
std::string lib;
while (std::getline(extra_plugin_libs, lib, ';')) {
auto status = LoadDynamicLibrary(ToPathString(lib));
if (status == Status::OK()) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully loaded " << lib;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] " << status.ToString();
}
}
is_loaded = true;
}
try {
// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger();
initLibNvInferPlugins(&trt_logger, "");
int num_plugin_creator = 0;
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
std::unordered_set<std::string> registered_plugin_names;
for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
std::string plugin_name(plugin_creator->getPluginName());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
// a plugin may have multiple versions; register it only once
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
continue;
}
std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
registered_plugin_names.insert(plugin_name);
}
domain_list.push_back(custom_op_domain.release());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
return Status::OK();
}
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
std::vector<OrtCustomOpDomain*> domain_list;
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
@ -44,48 +94,11 @@ common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& i
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
// extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
if (!extra_plugin_lib_paths.empty()) {
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
std::string lib;
while (std::getline(extra_plugin_libs, lib, ';')) {
auto status = LoadDynamicLibrary(ToPathString(lib));
if (status == Status::OK()) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully load " << lib;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString();
}
}
auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!domain_list.empty()) {
info.custom_op_domain_list = domain_list;
}
// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger();
initLibNvInferPlugins(&trt_logger, "");
int num_plugin_creator = 0;
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
std::unordered_set<std::string> registered_plugin_names;
for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
std::string plugin_name(plugin_creator->getPluginName());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
// plugin has different versions and we only register once
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
continue;
}
std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
registered_plugin_names.insert(plugin_name);
}
info.custom_op_domain_list.push_back(custom_op_domain.release());
return common::Status::OK();
return Status::OK();
}
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
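
A hedged sketch of calling the new CreateTensorRTCustomOpDomainList overload directly (the plugin library paths are hypothetical):

std::vector<OrtCustomOpDomain*> domain_list;
// extra_plugin_lib_paths uses the "path_1;path_2;...;path_n" format described above.
auto status = CreateTensorRTCustomOpDomainList(domain_list, "/opt/plugins/a.so;/opt/plugins/b.so");
// On success, domain_list holds one "trt.plugins" domain with a TensorRTCustomOp per
// unique registered plugin name; release it via ReleaseTensorRTCustomOpDomainList(domain_list).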


@ -13,6 +13,7 @@ using namespace onnxruntime;
namespace onnxruntime {
common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths);
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);


@ -186,4 +186,211 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
};
return options;
}
/**
* Update OrtTensorRTProviderOptionsV2 instance with ProviderOptions (map of string-based key-value pairs)
*
* Please note that it resets the OrtTensorRTProviderOptionsV2 instance first and then applies the provided provider options;
* see TensorrtExecutionProviderInfo::FromProviderOptions() for details. This function is also called by the C API UpdateTensorRTProviderOptions().
*
* \param provider_options - a pointer to an OrtTensorRTProviderOptionsV2 instance
* \param options - a reference to a ProviderOptions instance
* \param string_copy - if true, strncpy() is used to copy each 'provider option' string from the ProviderOptions instance into a newly allocated buffer that the corresponding const char pointer in the OrtTensorRTProviderOptionsV2 instance then points to;
*                      if false, only the pointer is saved and no strncpy() is performed.
*
* Note: when strings are copied (string_copy is true), remember to deallocate them, or simply call the C API ReleaseTensorRTProviderOptions.
*/
void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) {
if (provider_options == nullptr) {
return;
}
TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
auto& trt_provider_options_v2 = *reinterpret_cast<OrtTensorRTProviderOptionsV2*>(provider_options);
trt_provider_options_v2.device_id = internal_options.device_id;
// The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well
// We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options
if (options.find("has_user_compute_stream") != options.end()) {
trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream;
}
trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size;
trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size;
trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable;
trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable;
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.int8_calibration_table_name.size();
if (str_size == 0) {
trt_provider_options_v2.trt_int8_calibration_table_name = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size);
#else
strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_int8_calibration_table_name = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_int8_calibration_table_name = internal_options.int8_calibration_table_name.c_str();
}
trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable;
trt_provider_options_v2.trt_dla_core = internal_options.dla_core;
trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs;
trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable;
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.engine_cache_path.size();
if (str_size == 0) {
trt_provider_options_v2.trt_engine_cache_path = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.engine_cache_path.c_str(), str_size);
#else
strncpy(dest, internal_options.engine_cache_path.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_engine_cache_path = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_engine_cache_path = internal_options.engine_cache_path.c_str();
}
trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable;
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.engine_decryption_lib_path.size();
if (str_size == 0) {
trt_provider_options_v2.trt_engine_decryption_lib_path = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.engine_decryption_lib_path.c_str(), str_size);
#else
strncpy(dest, internal_options.engine_decryption_lib_path.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_engine_decryption_lib_path = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_engine_decryption_lib_path = internal_options.engine_decryption_lib_path.c_str();
}
trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build;
trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable;
trt_provider_options_v2.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback;
trt_provider_options_v2.trt_timing_cache_enable = internal_options.timing_cache_enable;
trt_provider_options_v2.trt_force_timing_cache = internal_options.force_timing_cache;
trt_provider_options_v2.trt_detailed_build_log = internal_options.detailed_build_log;
trt_provider_options_v2.trt_build_heuristics_enable = internal_options.build_heuristics_enable;
trt_provider_options_v2.trt_sparsity_enable = internal_options.sparsity_enable;
trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level;
trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams;
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.tactic_sources.size();
if (str_size == 0) {
trt_provider_options_v2.trt_tactic_sources = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.tactic_sources.c_str(), str_size);
#else
strncpy(dest, internal_options.tactic_sources.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_tactic_sources = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_tactic_sources = internal_options.tactic_sources.c_str();
}
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.extra_plugin_lib_paths.size();
if (str_size == 0) {
trt_provider_options_v2.trt_extra_plugin_lib_paths = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.extra_plugin_lib_paths.c_str(), str_size);
#else
strncpy(dest, internal_options.extra_plugin_lib_paths.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_extra_plugin_lib_paths = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_extra_plugin_lib_paths = internal_options.extra_plugin_lib_paths.c_str();
}
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.profile_min_shapes.size();
if (str_size == 0) {
trt_provider_options_v2.trt_profile_min_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_min_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_min_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_profile_min_shapes = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_profile_min_shapes = internal_options.profile_min_shapes.c_str();
}
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.profile_max_shapes.size();
if (str_size == 0) {
trt_provider_options_v2.trt_profile_max_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_max_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_max_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_profile_max_shapes = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_profile_max_shapes = internal_options.profile_max_shapes.c_str();
}
if (string_copy) {
char* dest = nullptr;
auto str_size = internal_options.profile_opt_shapes.size();
if (str_size == 0) {
trt_provider_options_v2.trt_profile_opt_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_opt_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_opt_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_provider_options_v2.trt_profile_opt_shapes = (const char*)dest;
}
} else {
trt_provider_options_v2.trt_profile_opt_shapes = internal_options.profile_opt_shapes.c_str();
}
trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
}
} // namespace onnxruntime
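
The per-string blocks above all repeat the same copy-or-null pattern; a hypothetical helper (not part of this commit) could fold them together:

// Hypothetical helper, assuming <cstring> is included: copy a std::string into a
// heap buffer owned by the options struct, or null the pointer when the string is empty.
static void CopyStringOrNull(const std::string& src, const char*& dst) {
  if (src.empty()) {
    dst = nullptr;
    return;
  }
  char* buf = new char[src.size() + 1];
  std::memcpy(buf, src.c_str(), src.size() + 1);  // copies the terminating '\0'
  dst = buf;
}

Each block would then reduce to a single call such as CopyStringOrNull(internal_options.engine_cache_path, trt_provider_options_v2.trt_engine_cache_path).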

View file

@ -54,6 +54,7 @@ struct TensorrtExecutionProviderInfo {
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtTensorRTProviderOptionsV2& info);
static void UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy);
std::vector<OrtCustomOpDomain*> custom_op_domain_list;
};


@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/tensorrt/tensorrt_provider_factory.h"
#include "tensorrt_provider_factory.h"
#include <atomic>
#include "tensorrt_execution_provider.h"
#include "tensorrt_provider_factory_creator.h"
@ -18,22 +18,45 @@ namespace onnxruntime {
void InitializeRegistry();
void DeleteRegistry();
struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT {
OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) override {
auto cuda_err = cudaGetDevice(device_id);
if (cuda_err != cudaSuccess) {
return CreateStatus(ORT_FAIL, "Failed to get device id.");
}
return nullptr;
}
OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) override {
TensorrtExecutionProviderInfo::UpdateProviderOptions(provider_options, options, string_copy);
return nullptr;
}
OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) override {
common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!status.IsOK()) {
return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins.");
}
return nullptr;
}
OrtStatus* ReleaseCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list) override {
ReleaseTensorRTCustomOpDomainList(domain_list);
return nullptr;
}
} g_info;
struct TensorrtProviderFactory : IExecutionProviderFactory {
TensorrtProviderFactory(const TensorrtExecutionProviderInfo& info) : info_{info} {}
~TensorrtProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
private:
TensorrtExecutionProviderInfo info_;
};
void TensorrtProviderFactory::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
custom_op_domain_list = info_.custom_op_domain_list;
}
std::unique_ptr<IExecutionProvider> TensorrtProviderFactory::CreateProvider() {
return std::make_unique<TensorrtExecutionProvider>(info_);
}
@ -46,6 +69,7 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
}
struct Tensorrt_Provider : Provider {
void* GetInfo() override { return &g_info; }
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(int device_id) override {
TensorrtExecutionProviderInfo info;
info.device_id = device_id;
@ -55,6 +79,7 @@ struct Tensorrt_Provider : Provider {
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}
return std::make_shared<TensorrtProviderFactory>(info);
}
@ -104,161 +129,8 @@ struct Tensorrt_Provider : Provider {
return std::make_shared<TensorrtProviderFactory>(info);
}
/**
* This function will be called by the C API UpdateTensorRTProviderOptions().
*
* Please note that it will reset the OrtProviderOptionsV2 instance first and then set up the provided provider options
* See TensorrtExecutionProviderInfo::FromProviderOptions() for more details
*/
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptionsV2*>(provider_options);
trt_options.device_id = internal_options.device_id;
// The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well
// We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options
if (options.find("has_user_compute_stream") != options.end()) {
trt_options.has_user_compute_stream = internal_options.has_user_compute_stream;
}
trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size;
trt_options.trt_max_workspace_size = internal_options.max_workspace_size;
trt_options.trt_fp16_enable = internal_options.fp16_enable;
trt_options.trt_int8_enable = internal_options.int8_enable;
char* dest = nullptr;
auto str_size = internal_options.int8_calibration_table_name.size();
if (str_size == 0) {
trt_options.trt_int8_calibration_table_name = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size);
#else
strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_int8_calibration_table_name = (const char*)dest;
}
trt_options.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
trt_options.trt_dla_enable = internal_options.dla_enable;
trt_options.trt_dla_core = internal_options.dla_core;
trt_options.trt_dump_subgraphs = internal_options.dump_subgraphs;
trt_options.trt_engine_cache_enable = internal_options.engine_cache_enable;
str_size = internal_options.engine_cache_path.size();
if (str_size == 0) {
trt_options.trt_engine_cache_path = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.engine_cache_path.c_str(), str_size);
#else
strncpy(dest, internal_options.engine_cache_path.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_engine_cache_path = (const char*)dest;
}
trt_options.trt_engine_decryption_enable = internal_options.engine_decryption_enable;
str_size = internal_options.engine_decryption_lib_path.size();
if (str_size == 0) {
trt_options.trt_engine_decryption_lib_path = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.engine_decryption_lib_path.c_str(), str_size);
#else
strncpy(dest, internal_options.engine_decryption_lib_path.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_engine_decryption_lib_path = (const char*)dest;
}
trt_options.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build;
trt_options.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable;
trt_options.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback;
trt_options.trt_timing_cache_enable = internal_options.timing_cache_enable;
trt_options.trt_force_timing_cache = internal_options.force_timing_cache;
trt_options.trt_detailed_build_log = internal_options.detailed_build_log;
trt_options.trt_build_heuristics_enable = internal_options.build_heuristics_enable;
trt_options.trt_sparsity_enable = internal_options.sparsity_enable;
trt_options.trt_builder_optimization_level = internal_options.builder_optimization_level;
trt_options.trt_auxiliary_streams = internal_options.auxiliary_streams;
str_size = internal_options.tactic_sources.size();
if (str_size == 0) {
trt_options.trt_tactic_sources = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.tactic_sources.c_str(), str_size);
#else
strncpy(dest, internal_options.tactic_sources.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_tactic_sources = (const char*)dest;
}
str_size = internal_options.extra_plugin_lib_paths.size();
if (str_size == 0) {
trt_options.trt_extra_plugin_lib_paths = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.extra_plugin_lib_paths.c_str(), str_size);
#else
strncpy(dest, internal_options.extra_plugin_lib_paths.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_extra_plugin_lib_paths = (const char*)dest;
}
str_size = internal_options.profile_min_shapes.size();
if (str_size == 0) {
trt_options.trt_profile_min_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_min_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_min_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_profile_min_shapes = (const char*)dest;
}
str_size = internal_options.profile_max_shapes.size();
if (str_size == 0) {
trt_options.trt_profile_max_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_max_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_max_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_profile_max_shapes = (const char*)dest;
}
str_size = internal_options.profile_opt_shapes.size();
if (str_size == 0) {
trt_options.trt_profile_opt_shapes = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.profile_opt_shapes.c_str(), str_size);
#else
strncpy(dest, internal_options.profile_opt_shapes.c_str(), str_size);
#endif
dest[str_size] = '\0';
trt_options.trt_profile_opt_shapes = (const char*)dest;
}
trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
TensorrtExecutionProviderInfo::UpdateProviderOptions(provider_options, options, true);
}
ProviderOptions GetProviderOptions(const void* provider_options) override {
@ -266,11 +138,6 @@ struct Tensorrt_Provider : Provider {
return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options);
}
void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) override {
TensorrtProviderFactory* trt_factory = reinterpret_cast<TensorrtProviderFactory*>(factory);
trt_factory->GetCustomOpDomainList(custom_op_domains_ptr);
}
void Initialize() override {
InitializeRegistry();
}


@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_c_api.h"
#include "core/framework/provider_options.h"
namespace onnxruntime {
struct ProviderInfo_TensorRT {
virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0;
virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0;
virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) = 0;
virtual OrtStatus* ReleaseCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list) = 0;
protected:
~ProviderInfo_TensorRT() = default; // Can only be destroyed through a subclass instance
};
} // namespace onnxruntime


@ -108,6 +108,8 @@ namespace onnxruntime {
ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
ProviderInfo_CUDA& GetProviderInfo_CUDA();
ProviderInfo_TensorRT* TryGetProviderInfo_TensorRT();
ProviderInfo_TensorRT& GetProviderInfo_TensorRT();
ProviderInfo_CANN* TryGetProviderInfo_CANN();
ProviderInfo_CANN& GetProviderInfo_CANN();
ProviderInfo_Dnnl* TryGetProviderInfo_Dnnl();
@ -1418,10 +1420,6 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options);
}
void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) {
s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr);
}
std::shared_ptr<IExecutionProviderFactory> MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) {
return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options);
}
@ -1474,6 +1472,20 @@ ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() {
return reinterpret_cast<ProviderInfo_OpenVINO*>(s_library_openvino.Get().GetInfo());
}
ProviderInfo_TensorRT* TryGetProviderInfo_TensorRT() try {
return reinterpret_cast<ProviderInfo_TensorRT*>(s_library_tensorrt.Get().GetInfo());
} catch (const std::exception& exception) {
LOGS_DEFAULT(ERROR) << exception.what();
return nullptr;
}
ProviderInfo_TensorRT& GetProviderInfo_TensorRT() {
if (auto* info = TryGetProviderInfo_TensorRT())
return *info;
ORT_THROW("TensorRT Provider not available, can't get interface for it");
}
ProviderInfo_CUDA* TryGetProviderInfo_CUDA() try {
return reinterpret_cast<ProviderInfo_CUDA*>(s_library_cuda.Get().GetInfo());
} catch (const std::exception& exception) {
@ -1633,7 +1645,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
@ -1664,7 +1678,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, "");
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
@ -1772,10 +1787,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
options->provider_factories.push_back(factory);
std::vector<OrtCustomOpDomain*> custom_op_domains;
TensorrtProviderGetCustomOpDomainList(factory.get(), custom_op_domains);
std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths;
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
options->custom_op_domains_.push_back(ptr);
}
return nullptr;
API_IMPL_END
}
@ -1784,34 +1802,6 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
API_IMPL_BEGIN
#ifdef USE_TENSORRT
auto options = std::make_unique<OrtTensorRTProviderOptionsV2>();
options->device_id = 0;
options->has_user_compute_stream = 0;
options->user_compute_stream = nullptr;
options->trt_max_partition_iterations = 1000;
options->trt_min_subgraph_size = 1;
options->trt_max_workspace_size = 1 << 30;
options->trt_fp16_enable = false;
options->trt_int8_enable = false;
options->trt_int8_calibration_table_name = nullptr;
options->trt_int8_use_native_calibration_table = false;
options->trt_dla_enable = false;
options->trt_dla_core = false;
options->trt_dump_subgraphs = false;
options->trt_engine_cache_enable = false;
options->trt_engine_cache_path = nullptr;
options->trt_engine_decryption_enable = false;
options->trt_engine_decryption_lib_path = nullptr;
options->trt_force_sequential_engine_build = false;
options->trt_context_memory_sharing_enable = false;
options->trt_layer_norm_fp32_fallback = false;
options->trt_timing_cache_enable = false;
options->trt_force_timing_cache = false;
options->trt_detailed_build_log = false;
options->trt_extra_plugin_lib_paths = nullptr;
options->trt_profile_min_shapes = nullptr;
options->trt_profile_max_shapes = nullptr;
options->trt_profile_opt_shapes = nullptr;
options->trt_cuda_graph_enable = false;
*out = options.release();
return nullptr;
#else


@ -465,6 +465,9 @@ class InferenceSession(Session):
)
session_options = self._sess_options if self._sess_options else C.get_default_session_options()
self._register_ep_custom_ops(session_options, providers, provider_options)
if self._model_path:
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
else:
@ -507,6 +510,13 @@ class InferenceSession(Session):
self._sess_options = self._sess_options_initial
self._create_inference_session(providers, provider_options)
def _register_ep_custom_ops(self, session_options, providers, provider_options):
for i in range(len(providers)):
if providers[i] == "TensorrtExecutionProvider":
C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider":
C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
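
Both accepted provider-list shapes reach the binding; a minimal sketch (model path hypothetical, assuming onnxruntime is imported as ort):

# Plain provider name: options come from the parallel provider_options list.
sess1 = ort.InferenceSession("model.onnx",
                             providers=["TensorrtExecutionProvider"],
                             provider_options=[{}])
# (name, options) tuple: options come from the tuple itself.
sess2 = ort.InferenceSession("model.onnx",
                             providers=[("TensorrtExecutionProvider", {})])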
class IOBinding:
"""


@ -430,6 +430,25 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*
}
#endif
#ifdef USE_TENSORRT
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) {
std::string trt_extra_plugin_lib_paths = "";
const auto it = options.find("trt_extra_plugin_lib_paths");
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<OrtCustomOpDomain*> domain_list;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
for (auto ptr : domain_list) {
so.custom_op_domains_.push_back(ptr);
}
} else {
ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
}
}
#endif
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
const SessionOptions& session_options,
const std::string& type,
@ -443,43 +462,14 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case
// as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies.
if (Env::Default().GetEnvironmentVar("ORT_TENSORRT_UNAVAILABLE").empty()) {
std::string calibration_table, cache_path, lib_path, min_profile, max_profile, opt_profile;
// provider_options_map is only a reference to the ProviderOptionsMap instance, so the application can release it at any time.
// We therefore define these std::string variables here so that they stay alive for the lifetime of the TRT EP and can still be
// accessed through the OrtTensorRTProviderOptionsV2 instance. (Assignments such as params.trt_engine_cache_path = cache_path.c_str()
// store raw pointers, so the OrtTensorRTProviderOptionsV2 instance and the TRT EP instance reference these std::string variables
// directly, and the strings must not be released while in use.)
std::string calibration_table, cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0,
0,
0,
0,
0,
0,
0,
0,
2,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
0};
OrtTensorRTProviderOptionsV2 params;
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
@ -666,13 +656,15 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
}
} else if (option.first == "trt_tactic_sources") {
if (!option.second.empty()) {
params.trt_tactic_sources = option.second.c_str();
trt_tactic_sources = option.second;
params.trt_tactic_sources = trt_tactic_sources.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a string. e.g. \"-CUDNN,+CUBLAS\" available keys: \"CUBLAS\"|\"CUBLAS_LT\"|\"CUDNN\"|\"EDGE_MASK_CONVOLUTIONS\".\n");
}
} else if (option.first == "trt_extra_plugin_lib_paths") {
if (!option.second.empty()) {
params.trt_extra_plugin_lib_paths = option.second.c_str();
trt_extra_plugin_lib_paths = option.second;
params.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a path string.\n");
}
@ -1209,6 +1201,12 @@ void addGlobalMethods(py::module& m) {
});
#endif
#ifdef USE_TENSORRT
m.def(
"register_tensorrt_plugins_as_custom_ops", [](PySessionOptions& so, const ProviderOptions& options) { RegisterTensorRTPluginsAsCustomOps(so, options); },
"Register TensorRT plugins as custom ops.");
#endif
#ifdef ENABLE_ATEN
m.def("register_aten_op_executor",
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {


@ -180,6 +180,13 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy;
} // namespace onnxruntime
#endif
#ifdef USE_TENSORRT
namespace onnxruntime {
ProviderInfo_TensorRT* TryGetProviderInfo_TensorRT();
ProviderInfo_TensorRT& GetProviderInfo_TensorRT();
} // namespace onnxruntime
#endif
#ifdef USE_CANN
namespace onnxruntime {
ProviderInfo_CANN* TryGetProviderInfo_CANN();


@ -690,11 +690,7 @@ TEST_P(ModelTest, Run) {
#endif
else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30,
1, // enable fp16
0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0,
3, -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0};
OrtTensorRTProviderOptionsV2 params;
ortso.AppendExecutionProvider_TensorRT_V2(params);
} else {
OrtTensorRTProviderOptionsV2* ep_option = nullptr;


@ -175,41 +175,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0,
0,
0,
0,
0,
0,
0,
0,
3,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
0};
OrtTensorRTProviderOptionsV2 params;
params.trt_engine_cache_enable = 1;
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
@ -259,41 +225,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
std::vector<int64_t> expected_dims_nonzero_m = {3, 6};
std::vector<int64_t> expected_values_nonzero_m = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1};
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0,
0,
0,
0,
0,
0,
0,
0,
3,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
0};
OrtTensorRTProviderOptionsV2 params;
params.trt_engine_cache_enable = 1;
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
@ -422,41 +354,7 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
output_names.push_back("output");
std::vector<OrtValue> fetches;
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0,
0,
0,
0,
0,
0,
0,
0,
3,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
0};
OrtTensorRTProviderOptionsV2 params;
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
std::cout << model_name << std::endl;
@ -516,41 +414,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0,
0,
0,
0,
0,
0,
0,
0,
3,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
0};
OrtTensorRTProviderOptionsV2 params;
if (cache_type.compare("engine") == 0) {
/* Following code block tests the functionality of engine and optimization profile of ORT TRT, including:
* - engine cache serialization/de-serialization


@ -10,9 +10,6 @@
#ifdef USE_TVM
#include "core/providers/tvm/tvm_provider_factory.h"
#endif
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_provider_factory.h"
#endif
#ifdef USE_OPENVINO
#include "core/providers/openvino/openvino_provider_factory.h"
#endif


@ -67,7 +67,19 @@ with open(args.output_source, "w") as file:
# external symbols are removed, xnnpack ep will be created via the standard ORT API.
# https://github.com/microsoft/onnxruntime/pull/11798
if c not in ("vitisai", "winml", "cuda", "rocm", "migraphx", "qnn", "snpe", "xnnpack", "cann", "dnnl"):
if c not in (
"vitisai",
"winml",
"cuda",
"rocm",
"migraphx",
"qnn",
"snpe",
"xnnpack",
"cann",
"dnnl",
"tensorrt",
):
file.write(f"#include <core/providers/{c}/{c}_provider_factory.h>\n")
file.write("void* GetFunctionEntryByName(const char* name){\n")
for symbol in symbols:


@ -58,7 +58,6 @@ steps:
copy $(Build.SourcesDirectory)\include\onnxruntime\core\session\onnxruntime_*.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include
copy $(Build.SourcesDirectory)\include\onnxruntime\core\framework\provider_options.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include
copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cpu\cpu_provider_factory.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include
copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\tensorrt\tensorrt_provider_factory.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include
copy $(Build.SourcesDirectory)\orttraining\orttraining\training_api\include\onnxruntime_training*.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include
REM copy the README, license and TPN


@ -27,7 +27,6 @@ if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then
fi
if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so" ]]; then
cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so $BINARY_DIR/$ARTIFACT_NAME/lib
cp $SOURCE_DIR/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include
fi
if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_rocm.so" ]]; then
cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_shared.so $BINARY_DIR/$ARTIFACT_NAME/lib


@ -28,4 +28,3 @@ rm $ARTIFACT_DIR/onnxruntime-linux-x64-cuda-*.tgz
cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime.so* onnxruntime-linux-x64-gpu/*/lib
cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_tensorrt.so onnxruntime-linux-x64-gpu/*/lib
cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_shared.so onnxruntime-linux-x64-gpu/*/lib
cp onnxruntime-linux-x64-tensorrt/*/include/*tensorrt* onnxruntime-linux-x64-gpu/*/include


@ -7,7 +7,6 @@ FOR /R %%i IN (*.nupkg) do (
set filename=%%~ni
IF NOT "!filename:~25,7!"=="Managed" (
mkdir build\native\include
copy %BUILD_SOURCESDIRECTORY%\include\onnxruntime\core\providers\tensorrt\tensorrt_provider_factory.h build\native\include\tensorrt_provider_factory.h
7z a %%~ni.nupkg build
)
)


@ -437,14 +437,7 @@ def generate_files(line_list, args):
)
if args.execution_provider == "tensorrt":
files_list.append(
"<file src="
+ '"'
+ os.path.join(
args.sources_path, "include\\onnxruntime\\core\\providers\\tensorrt\\tensorrt_provider_factory.h"
)
+ '" target="build\\native\\include" />'
)
files_list.append("<file src=" + '"' + '" target="build\\native\\include" />')
if args.execution_provider == "dnnl":
files_list.append(


@ -23,7 +23,6 @@ win_gpu_package_libraries = [
]
gpu_related_header_files = [
"cpu_provider_factory.h",
"tensorrt_provider_factory.h",
"onnxruntime_c_api.h",
"onnxruntime_cxx_api.h",
"onnxruntime_cxx_inline.h",