mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Add FromProviderOptions()/ToProviderOptions() for TensorRT EP (#7654)
* integrate existed provider option configuration method * add GetProviderOptions() * fix bug * Add tests * Update test
This commit is contained in:
parent
1c7e683a95
commit
760828b2d4
5 changed files with 159 additions and 15 deletions
|
|
@ -384,7 +384,7 @@ TensorrtLogger& GetTensorrtLogger() {
|
|||
}
|
||||
|
||||
TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info)
|
||||
: IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, true}, device_id_(info.device_id) {
|
||||
: IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, true}, info_(info), device_id_(info.device_id) {
|
||||
CUDA_CALL_THROW(cudaSetDevice(device_id_));
|
||||
if (info.has_user_compute_stream) {
|
||||
external_stream_ = true;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "NvInfer.h"
|
||||
#include "NvOnnxParser.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "tensorrt_execution_provider_info.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -69,20 +70,6 @@ template <typename T>
|
|||
using unique_pointer = std::unique_ptr<T, TensorrtInferDeleter>;
|
||||
}; // namespace tensorrt_ptr
|
||||
|
||||
// Information needed to construct trt execution providers.
|
||||
struct TensorrtExecutionProviderInfo {
|
||||
int device_id{0};
|
||||
bool has_user_compute_stream{false};
|
||||
void* user_compute_stream{nullptr};
|
||||
bool has_trt_options{false};
|
||||
size_t max_workspace_size{1 << 30};
|
||||
bool fp16_enable{false};
|
||||
bool int8_enable{false};
|
||||
std::string int8_calibration_table_name{""};
|
||||
bool int8_use_native_calibration_table{false};
|
||||
bool force_sequential_engine_build{false};
|
||||
};
|
||||
|
||||
// Information to construct kernel function state.
|
||||
struct TensorrtFuncState {
|
||||
AllocateFunc test_allocate_func = nullptr;
|
||||
|
|
@ -141,7 +128,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
|
||||
void* GetComputeStream() const override { return static_cast<void*>(stream_); }
|
||||
|
||||
ProviderOptions GetProviderOptions() const override {
|
||||
return TensorrtExecutionProviderInfo::ToProviderOptions(info_);
|
||||
}
|
||||
|
||||
private:
|
||||
TensorrtExecutionProviderInfo info_;
|
||||
bool external_stream_ = false;
|
||||
cudaStream_t stream_ = nullptr;
|
||||
int max_partition_iterations_ = 1000;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/tensorrt/tensorrt_execution_provider_info.h"
|
||||
|
||||
#include "core/common/make_string.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace tensorrt {
|
||||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kHasTrtOptions = "has_trt_options";
|
||||
constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size";
|
||||
constexpr const char* kFp16Enable = "trt_fp16_enable";
|
||||
constexpr const char* kInt8Enable = "trt_int8_enable";
|
||||
constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name";
|
||||
constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table";
|
||||
//constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
|
||||
} // namespace provider_option_names
|
||||
} // namespace tensorrt
|
||||
|
||||
TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
|
||||
TensorrtExecutionProviderInfo info{};
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
.AddValueParser(
|
||||
tensorrt::provider_option_names::kDeviceId,
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
|
||||
int num_devices{};
|
||||
ORT_RETURN_IF_NOT(
|
||||
CUDA_CALL(cudaGetDeviceCount(&num_devices)),
|
||||
"cudaGetDeviceCount() failed.");
|
||||
ORT_RETURN_IF_NOT(
|
||||
0 <= info.device_id && info.device_id < num_devices,
|
||||
"Invalid device ID: ", info.device_id,
|
||||
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
|
||||
return Status::OK();
|
||||
})
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kHasTrtOptions, info.has_trt_options)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table)
|
||||
//.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
|
||||
.Parse(options));
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) {
|
||||
const ProviderOptions options{
|
||||
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{tensorrt::provider_option_names::kHasTrtOptions, MakeStringWithClassicLocale(info.has_trt_options)},
|
||||
{tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)},
|
||||
{tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
|
||||
{tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
|
||||
{tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)},
|
||||
{tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)},
|
||||
//{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
|
||||
};
|
||||
|
||||
return options;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "core/framework/ortdevice.h"
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// Information needed to construct trt execution providers.
|
||||
struct TensorrtExecutionProviderInfo {
|
||||
int device_id{0};
|
||||
bool has_user_compute_stream{false};
|
||||
void* user_compute_stream{nullptr};
|
||||
bool has_trt_options{false};
|
||||
size_t max_workspace_size{1 << 30};
|
||||
bool fp16_enable{false};
|
||||
bool int8_enable{false};
|
||||
std::string int8_calibration_table_name{""};
|
||||
bool int8_use_native_calibration_table{false};
|
||||
bool force_sequential_engine_build{false};
|
||||
|
||||
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -62,6 +62,60 @@ class TestInferenceSession(unittest.TestCase):
|
|||
self.assertEqual(['CPUExecutionProvider'], sess.get_providers())
|
||||
|
||||
def testSetProvidersWithOptions(self):
|
||||
if 'TensorrtExecutionProvider' in onnxrt.get_available_providers():
|
||||
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"))
|
||||
self.assertIn('TensorrtExecutionProvider', sess.get_providers())
|
||||
|
||||
options = sess.get_provider_options()
|
||||
option = options['TensorrtExecutionProvider']
|
||||
self.assertIn('device_id', option)
|
||||
self.assertIn('has_trt_options', option)
|
||||
self.assertIn('trt_max_workspace_size', option)
|
||||
self.assertIn('trt_fp16_enable', option)
|
||||
self.assertIn('trt_int8_enable', option)
|
||||
self.assertIn('trt_int8_calibration_table_name', option)
|
||||
self.assertIn('trt_int8_use_native_calibration_table', option)
|
||||
|
||||
ori_max_workspace_size = option['trt_max_workspace_size']
|
||||
new_max_workspace_size = int(ori_max_workspace_size) // 2
|
||||
|
||||
option = {}
|
||||
option['trt_max_workspace_size'] = new_max_workspace_size
|
||||
trt_options = "true"
|
||||
option['has_trt_options'] = trt_options
|
||||
fp16_enable = "true"
|
||||
option['trt_fp16_enable'] = fp16_enable
|
||||
int8_enable = "false"
|
||||
option['trt_int8_enable'] = int8_enable
|
||||
calib_table_name = '/home/onnxruntime/table.flatbuffers'
|
||||
option['trt_int8_calibration_table_name'] = calib_table_name
|
||||
int8_use_native_calibration_table = "true"
|
||||
option['trt_int8_use_native_calibration_table'] = int8_use_native_calibration_table
|
||||
sess.set_providers(['TensorrtExecutionProvider'], [option])
|
||||
|
||||
options = sess.get_provider_options()
|
||||
option = options['TensorrtExecutionProvider']
|
||||
self.assertEqual(option['trt_max_workspace_size'], str(new_max_workspace_size))
|
||||
self.assertEqual(option['trt_int8_calibration_table_name'], str(calib_table_name))
|
||||
self.assertEqual(option['has_trt_options'], '1')
|
||||
self.assertEqual(option['trt_fp16_enable'], '1')
|
||||
self.assertEqual(option['trt_int8_enable'], '0')
|
||||
self.assertEqual(option['trt_int8_use_native_calibration_table'], '1')
|
||||
|
||||
|
||||
# We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability
|
||||
|
||||
'''
|
||||
int8_use_native_calibration_table = "false"
|
||||
option['trt_int8_use_native_calibration_table'] = int8_use_native_calibration_table
|
||||
int8_enable = "true"
|
||||
option['trt_int8_enable'] = int8_enable
|
||||
calib_table_name = '/home/onnxruntime/table.flatbuffers' # this file is not existed
|
||||
option['trt_int8_calibration_table_name'] = calib_table_name
|
||||
with self.assertRaises(RuntimeError):
|
||||
sess.set_providers(['TensorrtExecutionProvider'], [option])
|
||||
'''
|
||||
|
||||
if 'CUDAExecutionProvider' in onnxrt.get_available_providers():
|
||||
import sys
|
||||
import ctypes
|
||||
|
|
|
|||
Loading…
Reference in a new issue