Make user capable of adding new field in OrtTensorRTProviderOptionsV2 as new provider option (#10450)

* modify code for add additional field in OrtTensorRTProviderOptionsV2

* add include file

* fix typo

* fix bug

* add comment

* fix code

* revert change
This commit is contained in:
Chi Lo 2022-02-05 11:15:12 -08:00 committed by GitHub
parent 927f1f18c9
commit 0f5d0a091a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 84 additions and 9 deletions

View file

@ -5,7 +5,9 @@
/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally.
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {

View file

@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
// This is the legacy struct and don't add new fields here.
// For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
// For non-string field, need to create a new separate api to handle it.
} OrtTensorRTProviderOptions;
/** \brief MIGraphX Provider Options

View file

@ -360,6 +360,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn

View file

@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
return *this;
}
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options));
return *this;
}
inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options));
return *this;

View file

@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
// add new provider option name here.
} // namespace provider_option_names
} // namespace tensorrt
@ -63,7 +64,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
.Parse(options));
.Parse(options)); // add new provider option here.
return info;
}
@ -87,6 +88,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
// add new provider option here.
};
return options;
}

View file

@ -6,6 +6,7 @@
#include <atomic>
#include "tensorrt_execution_provider.h"
#include "core/framework/provider_options.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include <string.h>
using namespace onnxruntime;
@ -48,7 +49,7 @@ struct Tensorrt_Provider : Provider {
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const void* provider_options) override {
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptions*>(provider_options);
auto& options = *reinterpret_cast<const OrtTensorRTProviderOptionsV2*>(provider_options);
TensorrtExecutionProviderInfo info;
info.device_id = options.device_id;
info.has_user_compute_stream = options.has_user_compute_stream != 0;
@ -74,7 +75,7 @@ struct Tensorrt_Provider : Provider {
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options);
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptions*>(provider_options);
auto& trt_options = *reinterpret_cast<OrtTensorRTProviderOptionsV2*>(provider_options);
trt_options.device_id = internal_options.device_id;
trt_options.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_options.trt_min_subgraph_size = internal_options.min_subgraph_size;

View file

@ -1148,7 +1148,43 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGrap
return nullptr;
}
// Adapter to convert the legacy OrtTensorRTProviderOptions to the latest OrtTensorRTProviderOptionsV2
OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(const OrtTensorRTProviderOptions* legacy_trt_options) {
OrtTensorRTProviderOptionsV2 trt_options_converted;
trt_options_converted.device_id = legacy_trt_options->device_id;
trt_options_converted.has_user_compute_stream = legacy_trt_options->has_user_compute_stream;
trt_options_converted.user_compute_stream = legacy_trt_options->user_compute_stream;
trt_options_converted.trt_max_partition_iterations = legacy_trt_options->trt_max_partition_iterations;
trt_options_converted.trt_min_subgraph_size = legacy_trt_options->trt_min_subgraph_size;
trt_options_converted.trt_max_workspace_size = legacy_trt_options->trt_max_workspace_size;
trt_options_converted.trt_fp16_enable = legacy_trt_options->trt_fp16_enable;
trt_options_converted.trt_int8_enable = legacy_trt_options->trt_int8_enable;
trt_options_converted.trt_int8_calibration_table_name = legacy_trt_options->trt_int8_calibration_table_name;
trt_options_converted.trt_int8_use_native_calibration_table = legacy_trt_options->trt_int8_use_native_calibration_table;
trt_options_converted.trt_dla_enable = legacy_trt_options->trt_dla_enable;
trt_options_converted.trt_dla_core = legacy_trt_options->trt_dla_core;
trt_options_converted.trt_dump_subgraphs = legacy_trt_options->trt_dump_subgraphs;
trt_options_converted.trt_engine_cache_enable = legacy_trt_options->trt_engine_cache_enable;
trt_options_converted.trt_engine_cache_path = legacy_trt_options->trt_engine_cache_path;
trt_options_converted.trt_engine_decryption_enable = legacy_trt_options->trt_engine_decryption_enable;
trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path;
trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
// Add new provider option below
// Use default value as this field is not available in OrtTensorRTProviderOptionsV
return trt_options_converted;
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) {
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options);
if (auto* provider = s_library_tensorrt.Get())
return provider->CreateExecutionProviderFactory(&trt_options_converted);
return nullptr;
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* provider_options) {
if (auto* provider = s_library_tensorrt.Get())
return provider->CreateExecutionProviderFactory(provider_options);
@ -1420,7 +1456,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) {
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT(options, reinterpret_cast<const OrtTensorRTProviderOptions*>(tensorrt_options));
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library");
}
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) {

View file

@ -27,6 +27,7 @@
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/provider_bridge_ort.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
@ -374,7 +375,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
std::string calibration_table, cache_path, lib_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptions params{
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,

View file

@ -28,6 +28,7 @@ struct OrtStatus {
#include "core/providers/providers.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/providers/cpu/cpu_provider_factory_creator.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#if defined(USE_CUDA) || defined(USE_ROCM)
#define BACKEND_PROC "GPU"
@ -474,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
} // namespace python
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(int device_id);

View file

@ -1,6 +1,7 @@
#include "ort_test_session.h"
#include <core/session/onnxruntime_cxx_api.h>
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include <assert.h>
#include "providers.h"
#include "TestCase.h"
@ -209,7 +210,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
}
}
OrtTensorRTProviderOptions tensorrt_options;
OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
tensorrt_options.has_user_compute_stream = 0;
tensorrt_options.user_compute_stream = nullptr;
@ -228,7 +229,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
session_options.AppendExecutionProvider_TensorRT(tensorrt_options);
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id=device_id;

View file

@ -8,6 +8,7 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "asserts.h"
#include <core/platform/path_lib.h>
#include "default_providers.h"
@ -591,7 +592,7 @@ TEST_P(ModelTest, Run) {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultNupharExecutionProvider()));
} else if (provider_name == "tensorrt") {
if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) {
OrtTensorRTProviderOptions params{
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,

View file

@ -54,6 +54,16 @@ std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const O
return nullptr;
}
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params) {
#ifdef USE_TENSORRT
if (auto factory = CreateExecutionProviderFactory_Tensorrt(params))
return factory->CreateProvider();
#else
ORT_UNUSED_PARAMETER(params);
#endif
return nullptr;
}
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
#ifdef USE_MIGRAPHX
OrtMIGraphXProviderOptions params{

View file

@ -24,6 +24,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVI
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptionsV2* params);
// EP for internal testing
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_InternalTesting(const std::unordered_set<std::string>& supported_ops);
@ -38,6 +39,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider(bool allow_un
//std::unique_ptr<IExecutionProvider> DefaultStvmExecutionProvider();
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider();
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params);
std::unique_ptr<IExecutionProvider> TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptionsV2* params);
std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider();
std::unique_ptr<IExecutionProvider> MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params);
std::unique_ptr<IExecutionProvider> DefaultOpenVINOExecutionProvider();