Add multithreading test and put a lock on nvinfer1::createInferRuntime() for TRT EP (#10714)

* Add multithread unit test and put lock on library call

* update code

* remove debug code

* add comment

* add one session multi-threads inference

* Put lock for build engine all the time

* Update naming and comment

* remove unnecessary lock

* Revert "remove unnecessary lock"

This reverts commit 9c2317b1d2273dec0ebdeb52160bc757839e5edc.
This commit is contained in:
Chi Lo 2022-03-16 09:19:33 -07:00 committed by GitHub
parent ce204d0744
commit 42d7112f03
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 14 deletions

View file

@ -251,6 +251,11 @@ TensorrtLogger& GetTensorrtLogger() {
return trt_logger;
}
std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
static OrtMutex singleton;
return std::unique_lock<OrtMutex>(singleton);
}
TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, true}, info_(info), device_id_(info.device_id) {
CUDA_CALL_THROW(cudaSetDevice(device_id_));
@ -396,7 +401,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
throw std::runtime_error("Failed to create directory " + cache_path_);
}
}
runtime_ = tensorrt_ptr::unique_pointer<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger()));
{
auto lock = GetApiLock();
runtime_ = tensorrt_ptr::unique_pointer<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger()));
}
}
if (engine_decryption_enable_) {
@ -1001,13 +1009,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
return result;
}
std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetEngineBuildLock() const {
static OrtMutex singleton;
// Acquire a lock only when force_sequential_engine_build_ is true;
return force_sequential_engine_build_ ? std::unique_lock<OrtMutex>(singleton) : std::unique_lock<OrtMutex>();
}
common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
@ -1197,7 +1198,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Build engine
{
auto lock = GetEngineBuildLock();
auto lock = GetApiLock();
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
}
if (trt_engine == nullptr) {
@ -1538,7 +1539,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Build engine
{
auto lock = GetEngineBuildLock();
auto lock = GetApiLock();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
}

View file

@ -194,10 +194,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
void RemoveTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph) const;
/**
Get a unique_lock object to control the concurrency behavior of TensorRT engine building. When force_sequential_engine_build
is set to true, the lock object is associated with a mutex shared across all providers to enforce sequential engine build.
Otherwise, the constructed unique_lock is not associated with any mutex therefore no locking/unlocking will happen.
Get a unique_lock object to control the concurrency behavior.
Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading)
should be protected by a lock when invoked by multiple threads concurrently.
*/
std::unique_lock<OrtMutex> GetEngineBuildLock() const;
std::unique_lock<OrtMutex> GetApiLock() const;
};
} // namespace onnxruntime

View file

@ -10,6 +10,7 @@
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "core/providers/tensorrt/tensorrt_execution_provider_utils.h"
#include <string>
#include <thread>
using namespace std;
using namespace ONNX_NAMESPACE;
@ -87,6 +88,190 @@ void CreateBaseModel(std::string model_name, std::string graph_name, std::vector
status = onnxruntime::Model::Save(model, model_name);
}
void RunSession(InferenceSession& session_object,
RunOptions& run_options,
NameMLValMap& feeds,
std::vector<std::string> output_names,
std::vector<int64_t> expected_dims,
std::vector<float> expected_values) {
std::vector<OrtValue> fetches;
auto status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());
VerifyOutputs(fetches, expected_dims, expected_values);
}
void RunWithOneSessionSingleThreadInference(std::string model_name, std::string sess_log_id) {
SessionOptions so;
so.session_logid = sess_log_id;
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSession session_object{so, GetEnvironment()};
auto allocator_manager = session_object.GetAllocatorManager();
auto cuda_provider = DefaultCudaExecutionProvider();
cuda_provider->RegisterAllocator(allocator_manager);
auto cpu_allocator = cuda_provider->GetAllocator(0, OrtMemTypeCPU);
std::vector<int64_t> dims_mul_x = {1, 3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
OrtValue ml_value_x;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_y);
OrtValue ml_value_z;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("M");
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0};
params.trt_engine_cache_enable = 1;
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
auto status = session_object.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
// run inference
// TRT engine will be created and cached
// TRT profile will be created and cached only for dynamic input shape
// Data in profile,
// X: 1, 3, 3, 2, 2, 2
// Y: 1, 3, 3, 2, 2, 2
// Z: 1, 3, 3, 2, 2, 2
RunSession(session_object, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}
void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string sess_log_id) {
SessionOptions so;
so.session_logid = sess_log_id;
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSession session_object{so, GetEnvironment()};
auto allocator_manager = session_object.GetAllocatorManager();
auto cuda_provider = DefaultCudaExecutionProvider();
cuda_provider->RegisterAllocator(allocator_manager);
auto cpu_allocator = cuda_provider->GetAllocator(0, OrtMemTypeCPU);
std::vector<int64_t> dims_mul_x = {1, 3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
OrtValue ml_value_x;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_y);
OrtValue ml_value_z;
CreateMLValue<float>(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("M");
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
OrtTensorRTProviderOptionsV2 params{
0,
0,
nullptr,
1000,
1,
1 << 30,
0,
0,
nullptr,
0,
0,
0,
0,
0,
nullptr,
0,
nullptr,
0};
params.trt_engine_cache_enable = 1;
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
auto status = session_object.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
// run inference with multi-threads
// TRT engine will be created and cached
// TRT profile will be created and cached only for dynamic input shape
// Data in profile,
// X: 1, 3, 3, 2, 2, 2
// Y: 1, 3, 3, 2, 2, 2
// Z: 1, 3, 3, 2, 2, 2
std::vector<std::thread> threads;
int num_thread = 5;
for (int i = 0; i < num_thread; ++i)
threads.push_back(std::thread(RunSession, std::ref(session_object), std::ref(run_options), std::ref(feeds), std::ref(output_names), std::ref(expected_dims_mul_m), std::ref(expected_values_mul_m)));
for (auto& th : threads)
th.join();
}
TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionSingleThreadInference) {
std::vector<std::thread> threads;
std::string model_name = "trt_execution_provider_multithreading_test.onnx";
std::string graph_name = "multithreading_test";
std::string sess_log_id = "TRTEPMultiThreadingTestWithOneSessionSingleThread";
std::vector<int> dims = {1, 3, 2};
int num_thread = 5;
CreateBaseModel(model_name, graph_name, dims);
for (int i = 0; i < num_thread; ++i)
threads.push_back(std::thread(RunWithOneSessionSingleThreadInference, model_name, sess_log_id));
for (auto& th : threads)
th.join();
}
TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionMultiThreadsInference) {
std::string model_name = "trt_execution_provider_multithreading_test.onnx";
std::string graph_name = "multithreading_test";
std::string sess_log_id = "TRTEPMultiThreadingTestWithOneSessionMultiThreads";
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model_name, graph_name, dims);
RunWithOneSessionMultiThreadsInference(model_name, sess_log_id);
}
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
// GetParam() returns the parameter of following format:
// ##cache type##_##input shape type##