[TRT EP] Fix multithreading bug of getting the corrupted trt engine instance (#17507)

Revert to the old TRT EP behavior of securing the whole compute_function
by lock_guard.

Current TRT EP which only puts lock_guard around a critical section
(obvious wrong) inside compute_function.
The issue can happen where one thread is updating the engine in
compute_function whereas another thread still accesses the
stale/corrupted engine instance in compute_function, for example, the
code outside the critical section, `int total_bindings =
trt_engine->getNbBindings()`.

So, make the whole compute_function the critical section should be okay.
This commit is contained in:
Chi Lo 2023-09-12 07:37:45 -07:00 committed by GitHub
parent db558ef9b4
commit aa5e36456a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2433,6 +2433,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
Ort::KernelContext ctx(context);
TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
// The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
// save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
// More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
@ -2475,239 +2480,232 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
timing_cache_path = GetTimingCachePath(cache_path_, prop);
}
// Following block is the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
// save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
// More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
{
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
// Load serialized engine
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
// Deserialize profile
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Load serialized engine
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
// Deserialize profile
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Prepare buffer
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
// Deserialize engine
// Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
trt_state->engine->reset();
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
} else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine
size_t engine_size = 0;
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
// Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
trt_state->engine->reset();
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
}
}
// Check and update shape ranges for dynamic shape inputs.
for (int i = 0, end = num_inputs; i < end; ++i) {
auto input = trt_state->network->get()->getInput(i);
const std::string& input_name = input->getName();
input_names.insert(input_name);
// If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
// TRT EP will help determine the min/max/opt profile values based on current input tensor value.
if (shape_ranges.find(input_name) != shape_ranges.end()) {
auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
}
}
}
// Regenerate engine
if (engine_update) {
// Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) {
GetPerThreadContext().ResetTensorRTContext(fused_node_name);
}
// Prepare buffer
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
// Deserialize engine
// Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
trt_state->engine->reset();
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
for (auto trt_profile : trt_profiles) {
trt_config->addOptimizationProfile(trt_profile);
}
// Set INT8 Per Tensor Dynamic range
if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
}
}
// Set precision
if (trt_state->fp16_enable && trt_state->int8_enable) {
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
} else if (trt_state->fp16_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
} else if (trt_state->int8_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
}
// Set DLA (DLA can only run with FP16 or INT8)
if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
trt_config->setDLACore(trt_state->dla_core);
}
// enable sparse weights
if (trt_state->sparsity_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
}
// enable builder heuristics
if (trt_state->build_heuristics_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
}
#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8
// switch optimizaion level
if (trt_state->builder_optimization_level != 3) {
trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
}
// limit auxiliary streams
if (trt_state->auxiliary_streams >= 0) {
trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams;
}
#else
if (trt_state->builder_optimization_level != 3) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
}
if (trt_state->auxiliary_streams >= 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
}
#endif
// limit used tactic sources
if (trt_state->filter_tactic_sources) {
nvinfer1::TacticSources tactics = trt_config->getTacticSources();
tactics |= trt_state->tactic_sources;
trt_config->setTacticSources(tactics);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
}
// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
if (trt_state->timing_cache_enable) {
std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not create timing cache: " + timing_cache_path);
}
trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
if (detailed_build_log_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
}
}
// Build engine
{
auto lock = GetApiLock();
std::chrono::steady_clock::time_point engine_build_start;
if (detailed_build_log_) {
engine_build_start = std::chrono::steady_clock::now();
}
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
if (detailed_build_log_) {
auto engine_build_stop = std::chrono::steady_clock::now();
LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
}
}
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
if (trt_state->engine_cache_enable) {
// Serialize engine profile
SerializeProfileV2(profile_cache_path, shape_ranges);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
// Serialize engine
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
if (trt_state->engine_encryption != nullptr) {
if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
context_update = true;
} else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine
size_t engine_size = 0;
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
// serialize and save timing cache
if (trt_state->timing_cache_enable) {
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not serialize timing cache: " + timing_cache_path);
}
saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
if (detailed_build_log_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
// Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
trt_state->engine->reset();
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
}
}
// Check and update shape ranges for dynamic shape inputs.
for (int i = 0, end = num_inputs; i < end; ++i) {
auto input = trt_state->network->get()->getInput(i);
const std::string& input_name = input->getName();
input_names.insert(input_name);
// If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
// TRT EP will help determine the min/max/opt profile values based on current input tensor value.
if (shape_ranges.find(input_name) != shape_ranges.end()) {
auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
}
}
}
// Regenerate engine
if (engine_update) {
// Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) {
GetPerThreadContext().ResetTensorRTContext(fused_node_name);
}
trt_state->engine->reset();
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
for (auto trt_profile : trt_profiles) {
trt_config->addOptimizationProfile(trt_profile);
}
// Set INT8 Per Tensor Dynamic range
if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
}
}
// Set precision
if (trt_state->fp16_enable && trt_state->int8_enable) {
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
} else if (trt_state->fp16_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
} else if (trt_state->int8_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
}
// Set DLA (DLA can only run with FP16 or INT8)
if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
trt_config->setDLACore(trt_state->dla_core);
}
// enable sparse weights
if (trt_state->sparsity_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
}
// enable builder heuristics
if (trt_state->build_heuristics_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
}
#if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8
// switch optimizaion level
if (trt_state->builder_optimization_level != 3) {
trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
}
// limit auxiliary streams
if (trt_state->auxiliary_streams >= 0) {
trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams;
}
#else
if (trt_state->builder_optimization_level != 3) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
}
if (trt_state->auxiliary_streams >= 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
}
#endif
// limit used tactic sources
if (trt_state->filter_tactic_sources) {
nvinfer1::TacticSources tactics = trt_config->getTacticSources();
tactics |= trt_state->tactic_sources;
trt_config->setTacticSources(tactics);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
}
// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
if (trt_state->timing_cache_enable) {
std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not create timing cache: " + timing_cache_path);
}
trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
if (detailed_build_log_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
}
}
// Build engine
{
auto lock = GetApiLock();
std::chrono::steady_clock::time_point engine_build_start;
if (detailed_build_log_) {
engine_build_start = std::chrono::steady_clock::now();
}
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
if (detailed_build_log_) {
auto engine_build_stop = std::chrono::steady_clock::now();
LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
}
}
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
trt_engine = trt_state->engine->get();
if (trt_state->engine_cache_enable) {
// Serialize engine profile
SerializeProfileV2(profile_cache_path, shape_ranges);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
// Serialize engine
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
if (trt_state->engine_encryption != nullptr) {
if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
}
// serialize and save timing cache
if (trt_state->timing_cache_enable) {
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not serialize timing cache: " + timing_cache_path);
}
saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
if (detailed_build_log_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
}
}
context_update = true;
}
// Build execution context if either of the following conditions is true:
// (1) The engine is built or updated by this thread.
// (2) The first inference run for this thread where there is no IExecutionContext object yet.