diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 43b0b938f1..fd4ad0b6d9 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -5,7 +5,9 @@
///
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
-/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally.
+/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
+/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
+/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
///
struct OrtTensorRTProviderOptionsV2 {
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 8add1cddaa..84708c27e8 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
+ // This is the legacy struct and don't add new fields here.
+ // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+ // For non-string field, need to create a new separate api to handle it.
} OrtTensorRTProviderOptions;
/** \brief MIGraphX Provider Options
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index c2463977f1..0007a9d044 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -360,6 +360,7 @@ struct SessionOptions : Base {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
+ SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index d281bb5542..063acb1702 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
return *this;
}
+inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options));
+ return *this;
+}
+
inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options));
return *this;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
index cfc43350a2..7386ce6c88 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
+// add new provider option name here.
} // namespace provider_option_names
} // namespace tensorrt
@@ -63,7 +64,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
- .Parse(options));
+ .Parse(options)); // add new provider option here.
return info;
}
@@ -87,6 +88,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
+ // add new provider option here.
};
return options;
}
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index d65c91d88f..0929b193f3 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -6,6 +6,7 @@
#include
#include "tensorrt_execution_provider.h"
#include "core/framework/provider_options.h"
+#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include
using namespace onnxruntime;
@@ -48,7 +49,7 @@ struct Tensorrt_Provider : Provider {
}
std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override {
- auto& options = *reinterpret_cast(provider_options);
+ auto& options = *reinterpret_cast(provider_options);
TensorrtExecutionProviderInfo info;
info.device_id = options.device_id;
info.has_user_compute_stream = options.has_user_compute_stream != 0;
@@ -74,7 +75,7 @@ struct Tensorrt_Provider : Provider {
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
- auto& trt_options = *reinterpret_cast(provider_options);
+ auto& trt_options = *reinterpret_cast(provider_options);
trt_options.device_id = internal_options.device_id;
trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size;
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 950c8ba657..e94a58cb20 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -1148,7 +1148,43 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap
return nullptr;
}
+// Adapter to convert the legacy OrtTensorRTProviderOptions to the latest OrtTensorRTProviderOptionsV2
+OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(const OrtTensorRTProviderOptions* legacy_trt_options) {
+ OrtTensorRTProviderOptionsV2 trt_options_converted;
+
+ trt_options_converted.device_id = legacy_trt_options->device_id;
+ trt_options_converted.has_user_compute_stream = legacy_trt_options->has_user_compute_stream;
+ trt_options_converted.user_compute_stream = legacy_trt_options->user_compute_stream;
+ trt_options_converted.trt_max_partition_iterations = legacy_trt_options->trt_max_partition_iterations;
+ trt_options_converted.trt_min_subgraph_size = legacy_trt_options->trt_min_subgraph_size;
+ trt_options_converted.trt_max_workspace_size = legacy_trt_options->trt_max_workspace_size;
+ trt_options_converted.trt_fp16_enable = legacy_trt_options->trt_fp16_enable;
+ trt_options_converted.trt_int8_enable = legacy_trt_options->trt_int8_enable;
+ trt_options_converted.trt_int8_calibration_table_name = legacy_trt_options->trt_int8_calibration_table_name;
+ trt_options_converted.trt_int8_use_native_calibration_table = legacy_trt_options->trt_int8_use_native_calibration_table;
+ trt_options_converted.trt_dla_enable = legacy_trt_options->trt_dla_enable;
+ trt_options_converted.trt_dla_core = legacy_trt_options->trt_dla_core;
+ trt_options_converted.trt_dump_subgraphs = legacy_trt_options->trt_dump_subgraphs;
+ trt_options_converted.trt_engine_cache_enable = legacy_trt_options->trt_engine_cache_enable;
+ trt_options_converted.trt_engine_cache_path = legacy_trt_options->trt_engine_cache_path;
+ trt_options_converted.trt_engine_decryption_enable = legacy_trt_options->trt_engine_decryption_enable;
+ trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path;
+ trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
+ // Add new provider option below
+ // Use default value as this field is not available in OrtTensorRTProviderOptionsV
+
+ return trt_options_converted;
+}
+
std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) {
+ OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options);
+ if (auto* provider = s_library_tensorrt.Get())
+ return provider->CreateExecutionProviderFactory(&trt_options_converted);
+
+ return nullptr;
+}
+
+std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* provider_options) {
if (auto* provider = s_library_tensorrt.Get())
return provider->CreateExecutionProviderFactory(provider_options);
@@ -1420,7 +1456,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) {
- return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT(options, reinterpret_cast(tensorrt_options));
+ API_IMPL_BEGIN
+ auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options);
+ if (!factory) {
+ return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library");
+ }
+
+ options->provider_factories.push_back(factory);
+ return nullptr;
+ API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) {
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 6834c4acb2..56c59406bc 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -27,6 +27,7 @@
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
+#include "core/providers/tensorrt/tensorrt_provider_options.h"
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
@@ -374,7 +375,7 @@ std::unique_ptr CreateExecutionProviderInstance(
std::string calibration_table, cache_path, lib_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
- OrtTensorRTProviderOptions params{
+ OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h
index 773db017e5..5477f3c91a 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.h
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -28,6 +28,7 @@ struct OrtStatus {
#include "core/providers/providers.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/cpu_provider_factory_creator.h"
+#include "core/providers/tensorrt/tensorrt_provider_options.h"
#if defined(USE_CUDA) || defined(USE_ROCM)
#define BACKEND_PROC "GPU"
@@ -474,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
} // namespace python
std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
+std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id);
std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params);
std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id);
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 24f283a428..2334f9e11a 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -1,6 +1,7 @@
#include "ort_test_session.h"
#include
#include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include
#include "providers.h"
#include "TestCase.h"
@@ -209,7 +210,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
}
}
- OrtTensorRTProviderOptions tensorrt_options;
+ OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
tensorrt_options.has_user_compute_stream = 0;
tensorrt_options.user_compute_stream = nullptr;
@@ -228,7 +229,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
- session_options.AppendExecutionProvider_TensorRT(tensorrt_options);
+ session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id=device_id;
diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc
index cba31055b7..ac7c6a69f9 100644
--- a/onnxruntime/test/providers/cpu/model_tests.cc
+++ b/onnxruntime/test/providers/cpu/model_tests.cc
@@ -8,6 +8,7 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
+#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "asserts.h"
#include
#include "default_providers.h"
@@ -591,7 +592,7 @@ TEST_P(ModelTest, Run) {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider()));
} else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
- OrtTensorRTProviderOptions params{
+ OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 209d424422..5e0975d667 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -54,6 +54,16 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O
return nullptr;
}
+std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
+#ifdef USE_TENSORRT
+ if (auto factory = CreateExecutionProviderFactory_Tensorrt(params))
+ return factory->CreateProvider();
+#else
+ ORT_UNUSED_PARAMETER(params);
+#endif
+ return nullptr;
+}
+
std::unique_ptr DefaultMIGraphXExecutionProvider() {
#ifdef USE_MIGRAPHX
OrtMIGraphXProviderOptions params{
diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h
index 6fa50c61cd..980129e95c 100644
--- a/onnxruntime/test/util/include/default_providers.h
+++ b/onnxruntime/test/util/include/default_providers.h
@@ -24,6 +24,7 @@ std::shared_ptr CreateExecutionProviderFactory_OpenVI
std::shared_ptr CreateExecutionProviderFactory_Rknpu();
std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options);
std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
+std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
// EP for internal testing
std::shared_ptr CreateExecutionProviderFactory_InternalTesting(const std::unordered_set& supported_ops);
@@ -38,6 +39,7 @@ std::unique_ptr DefaultNupharExecutionProvider(bool allow_un
//std::unique_ptr DefaultStvmExecutionProvider();
std::unique_ptr DefaultTensorrtExecutionProvider();
std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
+std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
std::unique_ptr DefaultMIGraphXExecutionProvider();
std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
std::unique_ptr DefaultOpenVINOExecutionProvider();