diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 43b0b938f1..fd4ad0b6d9 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -5,7 +5,9 @@ /// /// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2. -/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally. +/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally. +/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined +/// OrtTensorRTProviderOptions will be deprecated over time. /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. /// struct OrtTensorRTProviderOptionsV2 { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8add1cddaa..84708c27e8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions { int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true const char* trt_engine_decryption_lib_path; // specify engine decryption library path int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true + // This is the legacy struct and don't add new fields here. + // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h + // For non-string field, need to create a new separate api to handle it. } OrtTensorRTProviderOptions; /** \brief MIGraphX Provider Options diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c2463977f1..0007a9d044 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -360,6 +360,7 @@ struct SessionOptions : Base { SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d281bb5542..063acb1702 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or return *this; } +inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options)); + return *this; +} + inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options)); return *this; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index cfc43350a2..7386ce6c88 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path"; constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; +// add new provider option name here. } // namespace provider_option_names } // namespace tensorrt @@ -63,7 +64,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) - .Parse(options)); + .Parse(options)); // add new provider option here. return info; } @@ -87,6 +88,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, + // add new provider option here. }; return options; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index d65c91d88f..0929b193f3 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -6,6 +6,7 @@ #include #include "tensorrt_execution_provider.h" #include "core/framework/provider_options.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" #include using namespace onnxruntime; @@ -48,7 +49,7 @@ struct Tensorrt_Provider : Provider { } std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); + auto& options = *reinterpret_cast(provider_options); TensorrtExecutionProviderInfo info; info.device_id = options.device_id; info.has_user_compute_stream = options.has_user_compute_stream != 0; @@ -74,7 +75,7 @@ struct Tensorrt_Provider : Provider { void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); - auto& trt_options = *reinterpret_cast(provider_options); + auto& trt_options = *reinterpret_cast(provider_options); trt_options.device_id = internal_options.device_id; trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations; trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 950c8ba657..e94a58cb20 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1148,7 +1148,43 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap return nullptr; } +// Adapter to convert the legacy OrtTensorRTProviderOptions to the latest OrtTensorRTProviderOptionsV2 +OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(const OrtTensorRTProviderOptions* legacy_trt_options) { + OrtTensorRTProviderOptionsV2 trt_options_converted; + + trt_options_converted.device_id = legacy_trt_options->device_id; + trt_options_converted.has_user_compute_stream = legacy_trt_options->has_user_compute_stream; + trt_options_converted.user_compute_stream = legacy_trt_options->user_compute_stream; + trt_options_converted.trt_max_partition_iterations = legacy_trt_options->trt_max_partition_iterations; + trt_options_converted.trt_min_subgraph_size = legacy_trt_options->trt_min_subgraph_size; + trt_options_converted.trt_max_workspace_size = legacy_trt_options->trt_max_workspace_size; + trt_options_converted.trt_fp16_enable = legacy_trt_options->trt_fp16_enable; + trt_options_converted.trt_int8_enable = legacy_trt_options->trt_int8_enable; + trt_options_converted.trt_int8_calibration_table_name = legacy_trt_options->trt_int8_calibration_table_name; + trt_options_converted.trt_int8_use_native_calibration_table = legacy_trt_options->trt_int8_use_native_calibration_table; + trt_options_converted.trt_dla_enable = legacy_trt_options->trt_dla_enable; + trt_options_converted.trt_dla_core = legacy_trt_options->trt_dla_core; + trt_options_converted.trt_dump_subgraphs = legacy_trt_options->trt_dump_subgraphs; + trt_options_converted.trt_engine_cache_enable = legacy_trt_options->trt_engine_cache_enable; + trt_options_converted.trt_engine_cache_path = legacy_trt_options->trt_engine_cache_path; + trt_options_converted.trt_engine_decryption_enable = legacy_trt_options->trt_engine_decryption_enable; + trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path; + trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build; + // Add new provider option below + // Use default value as this field is not available in OrtTensorRTProviderOptionsV + + return trt_options_converted; +} + std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) { + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options); + if (auto* provider = s_library_tensorrt.Get()) + return provider->CreateExecutionProviderFactory(&trt_options_converted); + + return nullptr; +} + +std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* provider_options) { if (auto* provider = s_library_tensorrt.Get()) return provider->CreateExecutionProviderFactory(provider_options); @@ -1420,7 +1456,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) { - return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT(options, reinterpret_cast(tensorrt_options)); + API_IMPL_BEGIN + auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 6834c4acb2..56c59406bc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -27,6 +27,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, // GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses @@ -374,7 +375,7 @@ std::unique_ptr CreateExecutionProviderInstance( std::string calibration_table, cache_path, lib_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - OrtTensorRTProviderOptions params{ + OrtTensorRTProviderOptionsV2 params{ 0, 0, nullptr, diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 773db017e5..5477f3c91a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -28,6 +28,7 @@ struct OrtStatus { #include "core/providers/providers.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/providers/cpu/cpu_provider_factory_creator.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" #if defined(USE_CUDA) || defined(USE_ROCM) #define BACKEND_PROC "GPU" @@ -474,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); } // namespace python std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); +std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id); std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 24f283a428..2334f9e11a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1,6 +1,7 @@ #include "ort_test_session.h" #include #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" #include #include "providers.h" #include "TestCase.h" @@ -209,7 +210,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n"); } } - OrtTensorRTProviderOptions tensorrt_options; + OrtTensorRTProviderOptionsV2 tensorrt_options; tensorrt_options.device_id = device_id; tensorrt_options.has_user_compute_stream = 0; tensorrt_options.user_compute_stream = nullptr; @@ -228,7 +229,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable; tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str(); tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build; - session_options.AppendExecutionProvider_TensorRT(tensorrt_options); + session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); OrtCUDAProviderOptions cuda_options; cuda_options.device_id=device_id; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index cba31055b7..ac7c6a69f9 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -8,6 +8,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/inference_session.h" #include "core/session/ort_env.h" +#include "core/providers/tensorrt/tensorrt_provider_options.h" #include "asserts.h" #include #include "default_providers.h" @@ -591,7 +592,7 @@ TEST_P(ModelTest, Run) { ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider())); } else if (provider_name == "tensorrt") { if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) { - OrtTensorRTProviderOptions params{ + OrtTensorRTProviderOptionsV2 params{ 0, 0, nullptr, diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 209d424422..5e0975d667 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -54,6 +54,16 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O return nullptr; } +std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) { +#ifdef USE_TENSORRT + if (auto factory = CreateExecutionProviderFactory_Tensorrt(params)) + return factory->CreateProvider(); +#else + ORT_UNUSED_PARAMETER(params); +#endif + return nullptr; +} + std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX OrtMIGraphXProviderOptions params{ diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 6fa50c61cd..980129e95c 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -24,6 +24,7 @@ std::shared_ptr CreateExecutionProviderFactory_OpenVI std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); +std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params); // EP for internal testing std::shared_ptr CreateExecutionProviderFactory_InternalTesting(const std::unordered_set& supported_ops); @@ -38,6 +39,7 @@ std::unique_ptr DefaultNupharExecutionProvider(bool allow_un //std::unique_ptr DefaultStvmExecutionProvider(); std::unique_ptr DefaultTensorrtExecutionProvider(); std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params); +std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params); std::unique_ptr DefaultMIGraphXExecutionProvider(); std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params); std::unique_ptr DefaultOpenVINOExecutionProvider();