[TRT EP] Fix logic to reach cache encryption code. (#17111)

### Description
This is a followup to PR #15519 that is closed in favor of this one.


### Motivation and Context
The current implementation of TRT cache has no code execution path
possible so that an encrypted TRT engine cache could be created when
flags engine_cache_enable and engine_decryption_enable are true. This
was originally raised in issue #12551.
This commit is contained in:
simonjub 2023-08-26 23:09:03 -04:00 committed by GitHub
parent ca0159b45d
commit 4eedd3bb46
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -22,6 +22,7 @@
#include <limits>
#include <map>
#include <memory>
#include <filesystem>
// TODO: find a better way to share this
#include "core/providers/cuda/cuda_stream_handle.h"
@ -1028,6 +1029,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt");
if (engine_decryption_ == nullptr) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_));
}
}
if (int8_enable_) {
@ -2209,6 +2214,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
if (!has_dynamic_shape) {
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
std::string timing_cache_path = "";
bool engine_update = false;
@ -2231,7 +2237,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
if (engine_cache_enable_ && engine_file && !engine_update) {
if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) {
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
@ -2243,24 +2249,24 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file && !engine_update) {
} else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
// Decrypt engine
size_t engine_size = 0;
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
} else {
// Set INT8 per tensor dynamic range
@ -2311,16 +2317,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
size_t engine_size = serializedModel->size();
if (engine_decryption_enable_) {
// Encrypt engine
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
if (engine_encryption_ != nullptr) {
if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP call to engine encryption library failed");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
}
// serialize and save timing cache
if (timing_cache_enable_) {
@ -2465,6 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
// Prepare cache name
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
std::string timing_cache_path = "";
if (timing_cache_enable_) {
@ -2481,7 +2493,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && profile_file) {
if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
// Deserialize profile
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
@ -2505,17 +2517,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
} else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine
size_t engine_size = 0;
if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
@ -2526,9 +2538,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
}
@ -2670,16 +2682,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
if (trt_state->engine_encryption != nullptr) {
if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
// serialize and save timing cache