mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Make user capable of adding new field in OrtTensorRTProviderOptionsV2 as new provider option (#10450)
* modify code for add additional field in OrtTensorRTProviderOptionsV2 * add include file * fix typo * fix bug * add comment * fix code * revert change
This commit is contained in:
parent
927f1f18c9
commit
0f5d0a091a
13 changed files with 84 additions and 9 deletions
|
|
@ -5,7 +5,9 @@
|
|||
|
||||
/// <summary>
|
||||
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
|
||||
/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally.
|
||||
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
|
||||
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
|
||||
/// OrtTensorRTProviderOptions will be deprecated over time.
|
||||
/// Users can only obtain an instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
|
||||
/// </summary>
|
||||
struct OrtTensorRTProviderOptionsV2 {
|
||||
|
|
|
|||
|
|
@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions {
|
|||
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
|
||||
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
|
||||
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
|
||||
// This is the legacy struct; do not add new fields here.
|
||||
// For a new field that can be represented by a string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
|
||||
// For a non-string field, a new separate API needs to be created to handle it.
|
||||
} OrtTensorRTProviderOptions;
|
||||
|
||||
/** \brief MIGraphX Provider Options
|
||||
|
|
|
|||
|
|
@ -360,6 +360,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
|
||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
|
||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
||||
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
|
||||
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
|
||||
|
||||
SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
|
||||
|
|
|
|||
|
|
@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options));
|
||||
return *this;
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
|
|||
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
|
||||
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
|
||||
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
|
||||
// add new provider option name here.
|
||||
} // namespace provider_option_names
|
||||
} // namespace tensorrt
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
|
|||
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
|
||||
.Parse(options));
|
||||
.Parse(options)); // add new provider option here.
|
||||
|
||||
return info;
|
||||
}
|
||||
|
|
@ -87,6 +88,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
|
|||
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
|
||||
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
|
||||
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
|
||||
// add new provider option here.
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <atomic>
|
||||
#include "tensorrt_execution_provider.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
#include <string.h>
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
|
@ -48,7 +49,7 @@ struct Tensorrt_Provider : Provider {
|
|||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* provider_options) override {
|
||||
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptions*>(provider_options);
|
||||
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptionsV2*>(provider_options);
|
||||
TensorrtExecutionProviderInfo info;
|
||||
info.device_id = options.device_id;
|
||||
info.has_user_compute_stream = options.has_user_compute_stream != 0;
|
||||
|
|
@ -74,7 +75,7 @@ struct Tensorrt_Provider : Provider {
|
|||
|
||||
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
|
||||
auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
|
||||
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptions*>(provider_options);
|
||||
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptionsV2*>(provider_options);
|
||||
trt_options.device_id = internal_options.device_id;
|
||||
trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations;
|
||||
trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size;
|
||||
|
|
|
|||
|
|
@ -1148,7 +1148,43 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGrap
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Adapter to convert the legacy OrtTensorRTProviderOptions to the latest OrtTensorRTProviderOptionsV2
|
||||
OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(const OrtTensorRTProviderOptions* legacy_trt_options) {
|
||||
OrtTensorRTProviderOptionsV2 trt_options_converted;
|
||||
|
||||
trt_options_converted.device_id = legacy_trt_options->device_id;
|
||||
trt_options_converted.has_user_compute_stream = legacy_trt_options->has_user_compute_stream;
|
||||
trt_options_converted.user_compute_stream = legacy_trt_options->user_compute_stream;
|
||||
trt_options_converted.trt_max_partition_iterations = legacy_trt_options->trt_max_partition_iterations;
|
||||
trt_options_converted.trt_min_subgraph_size = legacy_trt_options->trt_min_subgraph_size;
|
||||
trt_options_converted.trt_max_workspace_size = legacy_trt_options->trt_max_workspace_size;
|
||||
trt_options_converted.trt_fp16_enable = legacy_trt_options->trt_fp16_enable;
|
||||
trt_options_converted.trt_int8_enable = legacy_trt_options->trt_int8_enable;
|
||||
trt_options_converted.trt_int8_calibration_table_name = legacy_trt_options->trt_int8_calibration_table_name;
|
||||
trt_options_converted.trt_int8_use_native_calibration_table = legacy_trt_options->trt_int8_use_native_calibration_table;
|
||||
trt_options_converted.trt_dla_enable = legacy_trt_options->trt_dla_enable;
|
||||
trt_options_converted.trt_dla_core = legacy_trt_options->trt_dla_core;
|
||||
trt_options_converted.trt_dump_subgraphs = legacy_trt_options->trt_dump_subgraphs;
|
||||
trt_options_converted.trt_engine_cache_enable = legacy_trt_options->trt_engine_cache_enable;
|
||||
trt_options_converted.trt_engine_cache_path = legacy_trt_options->trt_engine_cache_path;
|
||||
trt_options_converted.trt_engine_decryption_enable = legacy_trt_options->trt_engine_decryption_enable;
|
||||
trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path;
|
||||
trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
|
||||
// Add new provider option below
|
||||
// Use the default value for any field that is not available in the legacy OrtTensorRTProviderOptions
|
||||
|
||||
return trt_options_converted;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) {
|
||||
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options);
|
||||
if (auto* provider = s_library_tensorrt.Get())
|
||||
return provider->CreateExecutionProviderFactory(&trt_options_converted);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* provider_options) {
|
||||
if (auto* provider = s_library_tensorrt.Get())
|
||||
return provider->CreateExecutionProviderFactory(provider_options);
|
||||
|
||||
|
|
@ -1420,7 +1456,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or
|
|||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) {
|
||||
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT(options, reinterpret_cast<const OrtTensorRTProviderOptions*>(tensorrt_options));
|
||||
API_IMPL_BEGIN
|
||||
auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options);
|
||||
if (!factory) {
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library");
|
||||
}
|
||||
|
||||
options->provider_factories.push_back(factory);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) {
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@
|
|||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/provider_bridge_ort.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
|
||||
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
|
||||
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
|
||||
|
|
@ -374,7 +375,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
std::string calibration_table, cache_path, lib_path;
|
||||
auto it = provider_options_map.find(type);
|
||||
if (it != provider_options_map.end()) {
|
||||
OrtTensorRTProviderOptions params{
|
||||
OrtTensorRTProviderOptionsV2 params{
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ struct OrtStatus {
|
|||
#include "core/providers/providers.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/providers/cpu/cpu_provider_factory_creator.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
#define BACKEND_PROC "GPU"
|
||||
|
|
@ -474,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
|
|||
} // namespace python
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(int device_id);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include "ort_test_session.h"
|
||||
#include <core/session/onnxruntime_cxx_api.h>
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
#include <assert.h>
|
||||
#include "providers.h"
|
||||
#include "TestCase.h"
|
||||
|
|
@ -209,7 +210,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
|
||||
}
|
||||
}
|
||||
OrtTensorRTProviderOptions tensorrt_options;
|
||||
OrtTensorRTProviderOptionsV2 tensorrt_options;
|
||||
tensorrt_options.device_id = device_id;
|
||||
tensorrt_options.has_user_compute_stream = 0;
|
||||
tensorrt_options.user_compute_stream = nullptr;
|
||||
|
|
@ -228,7 +229,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
|
||||
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
|
||||
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
|
||||
session_options.AppendExecutionProvider_TensorRT(tensorrt_options);
|
||||
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);
|
||||
|
||||
OrtCUDAProviderOptions cuda_options;
|
||||
cuda_options.device_id=device_id;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "core/providers/tensorrt/tensorrt_provider_options.h"
|
||||
#include "asserts.h"
|
||||
#include <core/platform/path_lib.h>
|
||||
#include "default_providers.h"
|
||||
|
|
@ -591,7 +592,7 @@ TEST_P(ModelTest, Run) {
|
|||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider()));
|
||||
} else if (provider_name == "tensorrt") {
|
||||
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
|
||||
OrtTensorRTProviderOptions params{
|
||||
OrtTensorRTProviderOptionsV2 params{
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
|
|
|
|||
|
|
@ -54,6 +54,16 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
|
||||
#ifdef USE_TENSORRT
|
||||
if (auto factory = CreateExecutionProviderFactory_Tensorrt(params))
|
||||
return factory->CreateProvider();
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(params);
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
|
||||
#ifdef USE_MIGRAPHX
|
||||
OrtMIGraphXProviderOptions params{
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVI
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
|
||||
|
||||
// EP for internal testing
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_InternalTesting(const std::unordered_set<std::string>& supported_ops);
|
||||
|
|
@ -38,6 +39,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider(bool allow_un
|
|||
//std::unique_ptr<IExecutionProvider> DefaultStvmExecutionProvider();
|
||||
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
|
||||
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
|
||||
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
|
||||
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
|
||||
std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
|
||||
std::unique_ptr<IExecutionProvider> DefaultOpenVINOExecutionProvider();
|
||||
|
|
|
|||
Loading…
Reference in a new issue