mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
[TRT EP] Fix logic to reach cache encryption code. (#17111)
### Description This is a followup to PR #15519 that is closed in favor of this one. ### Motivation and Context The current implementation of TRT cache has no code execution path possible so that an encrypted TRT engine cache could be created when flags engine_cache_enable and engine_decryption_enable are true. This was originally raised in issue #12551.
This commit is contained in:
parent
ca0159b45d
commit
4eedd3bb46
1 changed files with 39 additions and 22 deletions
|
|
@ -22,6 +22,7 @@
|
|||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <filesystem>
|
||||
// TODO: find a better way to share this
|
||||
#include "core/providers/cuda/cuda_stream_handle.h"
|
||||
|
||||
|
|
@ -1028,6 +1029,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
}
|
||||
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
|
||||
engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt");
|
||||
if (engine_decryption_ == nullptr) {
|
||||
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_));
|
||||
}
|
||||
}
|
||||
|
||||
if (int8_enable_) {
|
||||
|
|
@ -2209,6 +2214,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
if (!has_dynamic_shape) {
|
||||
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
|
||||
std::string timing_cache_path = "";
|
||||
bool engine_update = false;
|
||||
|
|
@ -2231,7 +2237,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
}
|
||||
|
||||
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
|
||||
if (engine_cache_enable_ && engine_file && !engine_update) {
|
||||
if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) {
|
||||
engine_file.seekg(0, std::ios::end);
|
||||
size_t engine_size = engine_file.tellg();
|
||||
engine_file.seekg(0, std::ios::beg);
|
||||
|
|
@ -2243,24 +2249,24 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
|
||||
}
|
||||
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file && !engine_update) {
|
||||
} else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
|
||||
// Decrypt engine
|
||||
size_t engine_size = 0;
|
||||
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not get engine buffer size");
|
||||
}
|
||||
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
|
||||
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine decryption function decrypt");
|
||||
}
|
||||
// Deserialize engine
|
||||
trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
|
||||
if (trt_engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
|
||||
}
|
||||
} else {
|
||||
// Set INT8 per tensor dynamic range
|
||||
|
|
@ -2311,16 +2317,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
|
||||
size_t engine_size = serializedModel->size();
|
||||
if (engine_decryption_enable_) {
|
||||
// Encrypt engine
|
||||
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine encryption function encrypt");
|
||||
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
|
||||
if (engine_encryption_ != nullptr) {
|
||||
if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP call to engine encryption library failed");
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
|
||||
}
|
||||
} else {
|
||||
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
|
||||
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
|
||||
}
|
||||
// serialize and save timing cache
|
||||
if (timing_cache_enable_) {
|
||||
|
|
@ -2465,6 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
// Prepare cache name
|
||||
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
|
||||
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
|
||||
std::string timing_cache_path = "";
|
||||
if (timing_cache_enable_) {
|
||||
|
|
@ -2481,7 +2493,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
|
||||
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
|
||||
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
|
||||
if (engine_file && profile_file) {
|
||||
if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
|
||||
// Deserialize profile
|
||||
shape_ranges = DeserializeProfileV2(profile_file);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
|
||||
|
|
@ -2505,17 +2517,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
trt_engine = trt_state->engine->get();
|
||||
context_update = true;
|
||||
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
|
||||
} else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
|
||||
shape_ranges = DeserializeProfileV2(profile_file);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
|
||||
// Decrypt engine
|
||||
size_t engine_size = 0;
|
||||
if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not get engine buffer size");
|
||||
}
|
||||
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
|
||||
if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine decryption function decrypt");
|
||||
}
|
||||
|
|
@ -2526,9 +2538,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
if (*(trt_state->engine) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
|
||||
trt_engine = trt_state->engine->get();
|
||||
context_update = true;
|
||||
}
|
||||
|
|
@ -2670,16 +2682,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
std::unique_ptr<nvinfer1::IHostMemory> serializedModel(trt_engine->serialize());
|
||||
size_t engine_size = serializedModel->size();
|
||||
if (trt_state->engine_decryption_enable) {
|
||||
// Encrypt engine
|
||||
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine encryption function encrypt");
|
||||
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
|
||||
if (trt_state->engine_encryption != nullptr) {
|
||||
if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine encryption function encrypt");
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
|
||||
}
|
||||
} else {
|
||||
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
|
||||
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
|
||||
}
|
||||
|
||||
// serialize and save timing cache
|
||||
|
|
|
|||
Loading…
Reference in a new issue