TRT EP race condition fix during ep compile time (#13356)

### Description
The TRT EP can encounter a race condition when multiple threads
perform engine serialization/deserialization during EP compile time.
Suppose one thread is serializing the engine and has not yet
finished writing all the data to the file. If, at that moment, another
thread finds that the engine file exists and begins to deserialize the
engine, it will end up deserializing a corrupt file.
The fix is to put a lock around engine deserialization/serialization,
engine build and context build.



### Motivation and Context
The TensorRT EP Windows CI sometimes fails because the
`TensorrtExecutionProviderTest.MultiThreadsTestWithOneSessionSingleThreadInference`
unit test fails (this PR renames it to
SessionCreationWithMultiThreadsAndInferenceWithMultiThreads). The
failure is most likely due to a race condition.
The TensorRT CI failure has also been reported
[here](https://github.com/microsoft/onnxruntime/issues/13030).
This commit is contained in:
Chi Lo 2022-10-19 11:19:10 -07:00 committed by GitHub
parent 565da71275
commit 86c5c07ea4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 66 deletions

View file

@ -1325,72 +1325,74 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
if (!has_dynamic_shape) {
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
if (engine_cache_enable_ && engine_file) {
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
// Decrypt engine
size_t engine_size = 0;
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
} else {
// Set INT8 per tensor dynamic range
if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
}
}
{
// ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading.
auto lock = GetApiLock();
// Build engine
{
auto lock = GetApiLock();
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
}
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build engine for fused node: " + fused_node.Name());
}
if (engine_cache_enable_) {
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
size_t engine_size = serializedModel->size();
if (engine_decryption_enable_) {
// Encrypt engine
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
if (engine_cache_enable_ && engine_file) {
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
// Decrypt engine
size_t engine_size = 0;
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
} else {
// Set INT8 per tensor dynamic range
if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
}
}
// Build engine
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build engine for fused node: " + fused_node.Name());
}
if (engine_cache_enable_) {
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
size_t engine_size = serializedModel->size();
if (engine_decryption_enable_) {
// Encrypt engine
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
}
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
}

View file

@ -247,7 +247,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
th.join();
}
TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionSingleThreadInference) {
TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
std::vector<std::thread> threads;
std::string model_name = "trt_execution_provider_multithreading_test.onnx";
std::string graph_name = "multithreading_test";
@ -264,7 +264,7 @@ TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionSingleThreadIn
th.join();
}
TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionMultiThreadsInference) {
TEST(TensorrtExecutionProviderTest, SessionCreationWithSingleThreadAndInferenceWithMultiThreads) {
std::string model_name = "trt_execution_provider_multithreading_test.onnx";
std::string graph_name = "multithreading_test";
std::string sess_log_id = "TRTEPMultiThreadingTestWithOneSessionMultiThreads";