From 760828b2d48d619736f577d714fc06fa7e79bd11 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 12 May 2021 15:26:27 -0700 Subject: [PATCH] Add FromProviderOptions()/ToProviderOptions() for TensorRT EP (#7654) * integrate existing provider option configuration method * add GetProviderOptions() * fix bug * Add tests * Update test --- .../tensorrt/tensorrt_execution_provider.cc | 2 +- .../tensorrt/tensorrt_execution_provider.h | 20 ++---- .../tensorrt_execution_provider_info.cc | 69 +++++++++++++++++++ .../tensorrt_execution_provider_info.h | 29 ++++++++ .../test/python/onnxruntime_test_python.py | 54 +++++++++++++++ 5 files changed, 159 insertions(+), 15 deletions(-) create mode 100644 onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc create mode 100644 onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7ca71dd532..40d299bb82 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -384,7 +384,7 @@ TensorrtLogger& GetTensorrtLogger() { } TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, true}, device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, true}, info_(info), device_id_(info.device_id) { CUDA_CALL_THROW(cudaSetDevice(device_id_)); if (info.has_user_compute_stream) { external_stream_ = true; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 2ac979b358..858303430e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ 
b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -6,6 +6,7 @@ #include "NvInfer.h" #include "NvOnnxParser.h" #include "core/platform/ort_mutex.h" +#include "tensorrt_execution_provider_info.h" namespace onnxruntime { @@ -69,20 +70,6 @@ template using unique_pointer = std::unique_ptr; }; // namespace tensorrt_ptr -// Information needed to construct trt execution providers. -struct TensorrtExecutionProviderInfo { - int device_id{0}; - bool has_user_compute_stream{false}; - void* user_compute_stream{nullptr}; - bool has_trt_options{false}; - size_t max_workspace_size{1 << 30}; - bool fp16_enable{false}; - bool int8_enable{false}; - std::string int8_calibration_table_name{""}; - bool int8_use_native_calibration_table{false}; - bool force_sequential_engine_build{false}; -}; - // Information to construct kernel function state. struct TensorrtFuncState { AllocateFunc test_allocate_func = nullptr; @@ -141,7 +128,12 @@ class TensorrtExecutionProvider : public IExecutionProvider { void* GetComputeStream() const override { return static_cast(stream_); } + ProviderOptions GetProviderOptions() const override { + return TensorrtExecutionProviderInfo::ToProviderOptions(info_); + } + private: + TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; int max_partition_iterations_ = 1000; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc new file mode 100644 index 0000000000..529f5a4e7c --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/tensorrt/tensorrt_execution_provider_info.h" + +#include "core/common/make_string.h" +#include "core/common/parse_string.h" +#include "core/framework/provider_options_utils.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace tensorrt { +namespace provider_option_names { +constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasTrtOptions = "has_trt_options"; +constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; +constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; +//constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; +} // namespace provider_option_names +} // namespace tensorrt + +TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { + TensorrtExecutionProviderInfo info{}; + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( + tensorrt::provider_option_names::kDeviceId, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + int num_devices{}; + ORT_RETURN_IF_NOT( + CUDA_CALL(cudaGetDeviceCount(&num_devices)), + "cudaGetDeviceCount() failed."); + ORT_RETURN_IF_NOT( + 0 <= info.device_id && info.device_id < num_devices, + "Invalid device ID: ", info.device_id, + ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); + return Status::OK(); + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kHasTrtOptions, info.has_trt_options) + .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) + 
.AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) + //.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) + .Parse(options)); + + return info; +} + +ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtExecutionProviderInfo& info) { + const ProviderOptions options{ + {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {tensorrt::provider_option_names::kHasTrtOptions, MakeStringWithClassicLocale(info.has_trt_options)}, + {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, + {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, + {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, + {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, + //{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, + }; + + return options; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h new file mode 100644 index 0000000000..7395f52a25 --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include <string> + +#include "core/framework/ortdevice.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" + +namespace onnxruntime { +// Information needed to construct trt execution providers. +struct TensorrtExecutionProviderInfo { + int device_id{0}; + bool has_user_compute_stream{false}; + void* user_compute_stream{nullptr}; + bool has_trt_options{false}; + size_t max_workspace_size{1 << 30}; + bool fp16_enable{false}; + bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; + bool force_sequential_engine_build{false}; + + static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); + static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); +}; +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index f34e458d5b..53c320efe5 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -62,6 +62,60 @@ class TestInferenceSession(unittest.TestCase): self.assertEqual(['CPUExecutionProvider'], sess.get_providers()) def testSetProvidersWithOptions(self): + if 'TensorrtExecutionProvider' in onnxrt.get_available_providers(): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) + self.assertIn('TensorrtExecutionProvider', sess.get_providers()) + + options = sess.get_provider_options() + option = options['TensorrtExecutionProvider'] + self.assertIn('device_id', option) + self.assertIn('has_trt_options', option) + self.assertIn('trt_max_workspace_size', option) + self.assertIn('trt_fp16_enable', option) + self.assertIn('trt_int8_enable', option) + self.assertIn('trt_int8_calibration_table_name', option) + self.assertIn('trt_int8_use_native_calibration_table', option) + + ori_max_workspace_size = option['trt_max_workspace_size'] + 
new_max_workspace_size = int(ori_max_workspace_size) // 2 + + option = {} + option['trt_max_workspace_size'] = new_max_workspace_size + trt_options = "true" + option['has_trt_options'] = trt_options + fp16_enable = "true" + option['trt_fp16_enable'] = fp16_enable + int8_enable = "false" + option['trt_int8_enable'] = int8_enable + calib_table_name = '/home/onnxruntime/table.flatbuffers' + option['trt_int8_calibration_table_name'] = calib_table_name + int8_use_native_calibration_table = "true" + option['trt_int8_use_native_calibration_table'] = int8_use_native_calibration_table + sess.set_providers(['TensorrtExecutionProvider'], [option]) + + options = sess.get_provider_options() + option = options['TensorrtExecutionProvider'] + self.assertEqual(option['trt_max_workspace_size'], str(new_max_workspace_size)) + self.assertEqual(option['trt_int8_calibration_table_name'], str(calib_table_name)) + self.assertEqual(option['has_trt_options'], '1') + self.assertEqual(option['trt_fp16_enable'], '1') + self.assertEqual(option['trt_int8_enable'], '0') + self.assertEqual(option['trt_int8_use_native_calibration_table'], '1') + + + # We currently disable the following test code since not all test machines/GPUs have NVIDIA int8 capability + + ''' + int8_use_native_calibration_table = "false" + option['trt_int8_use_native_calibration_table'] = int8_use_native_calibration_table + int8_enable = "true" + option['trt_int8_enable'] = int8_enable + calib_table_name = '/home/onnxruntime/table.flatbuffers' # this file is not existed + option['trt_int8_calibration_table_name'] = calib_table_name + with self.assertRaises(RuntimeError): + sess.set_providers(['TensorrtExecutionProvider'], [option]) + ''' + if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): import sys import ctypes