mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add engine decryption in TensorRT EP (#6612)
* add trt engine decryption * update document * add windows support to decryption * fix issues * remove redundant get() from engine/context check * fix issue
This commit is contained in:
parent
0b89f931d0
commit
e9d03983fc
2 changed files with 93 additions and 6 deletions
|
|
@ -21,6 +21,18 @@
|
|||
#include "flatbuffers/idl.h"
|
||||
#include "ort_trt_int8_cal_table.fbs.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define LIBTYPE HINSTANCE
|
||||
#define OPENLIB(libname) LoadLibrary(libname)
|
||||
#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn))
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define LIBTYPE void*
|
||||
#define OPENLIB(libname) dlopen((libname), RTLD_LAZY)
|
||||
#define LIBFUNC(lib, fn) dlsym((lib), (fn))
|
||||
#endif
|
||||
|
||||
#define CUDA_RETURN_IF_ERROR(expr) \
|
||||
ORT_RETURN_IF_ERROR(CUDA_CALL(expr) \
|
||||
? common::Status::OK() \
|
||||
|
|
@ -443,6 +455,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
}
|
||||
runtime_ = nvinfer1::createInferRuntime(GetTensorrtLogger());
|
||||
}
|
||||
|
||||
const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable);
|
||||
if (!engine_decryption_enable_env.empty()) {
|
||||
engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
if (engine_decryption_enable_) {
|
||||
std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath);
|
||||
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
|
||||
if (handle == nullptr) {
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
|
||||
}
|
||||
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
|
||||
}
|
||||
}
|
||||
|
||||
TensorrtExecutionProvider::~TensorrtExecutionProvider() {
|
||||
|
|
@ -1098,6 +1125,29 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
engine_file.read((char*)engine_buf.get(), engine_size);
|
||||
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
if (trt_engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
|
||||
}
|
||||
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
|
||||
// Decrypt engine
|
||||
size_t engine_size = 0;
|
||||
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not get engine buffer size");
|
||||
}
|
||||
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
|
||||
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine decryption function decrypt");
|
||||
}
|
||||
// Deserialize engine
|
||||
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
if (trt_engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
|
||||
}
|
||||
} else {
|
||||
// Set INT8 per tensor dynamic range
|
||||
if (int8_enable_ && trt_builder->platformHasFastInt8()) {
|
||||
|
|
@ -1174,7 +1224,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
|
||||
input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &int8_enable_, &max_workspace_size_,
|
||||
trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_,
|
||||
allocator_, dynamic_range_map};
|
||||
allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_};
|
||||
*state = p.release();
|
||||
return 0;
|
||||
};
|
||||
|
|
@ -1210,7 +1260,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
|
||||
const std::string engine_cache_path = cache_path + ".engine";
|
||||
const std::string profile_cache_path = cache_path + ".profile";
|
||||
if ((trt_state->engine_cache_enable && trt_engine == nullptr)) {
|
||||
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
|
||||
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
|
||||
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
|
||||
if (engine_file && profile_file) {
|
||||
|
|
@ -1228,14 +1278,45 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
auto runtime_ = trt_state->runtime;
|
||||
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
|
||||
runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
if (trt_state->engine->get() == nullptr) {
|
||||
if (trt_state->engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
trt_engine = trt_state->engine->get();
|
||||
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
if (trt_state->context->get() == nullptr) {
|
||||
if (trt_state->context == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
|
||||
shape_ranges = DeserializeProfile(profile_file);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
|
||||
// Decrypt engine
|
||||
size_t engine_size = 0;
|
||||
if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not get engine buffer size");
|
||||
}
|
||||
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
|
||||
if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not call engine decryption function decrypt");
|
||||
}
|
||||
// Deserialize engine
|
||||
trt_state->context->reset();
|
||||
trt_state->engine->reset();
|
||||
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
if (trt_state->engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
trt_engine = trt_state->engine->get();
|
||||
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
if (trt_state->context == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
|
|
@ -1408,7 +1489,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
// Build engine
|
||||
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
|
||||
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
|
||||
if (trt_state->engine->get() == nullptr) {
|
||||
if (trt_state->engine == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
|
||||
}
|
||||
trt_engine = trt_state->engine->get();
|
||||
|
|
@ -1428,7 +1509,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
|
|||
// Build context
|
||||
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
if (trt_state->context->get() == nullptr) {
|
||||
if (trt_state->context == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"
|
|||
static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
|
||||
// Old env variable for backward compatibility
|
||||
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
|
||||
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
|
||||
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
|
||||
} // namespace tensorrt_env_vars
|
||||
|
||||
class TensorrtLogger : public nvinfer1::ILogger {
|
||||
|
|
@ -94,6 +96,8 @@ struct TensorrtFuncState {
|
|||
nvinfer1::IRuntime* runtime = nullptr;
|
||||
AllocatorPtr scratch_allocator;
|
||||
std::unordered_map<std::string, float> dynamic_range_map;
|
||||
bool engine_decryption_enable;
|
||||
int (*engine_decryption)(const char*, char*, size_t*);
|
||||
};
|
||||
|
||||
// Logical device representation.
|
||||
|
|
@ -142,6 +146,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
int device_id_;
|
||||
AllocatorPtr allocator_;
|
||||
mutable char model_path_[4096]; // Reserved for max path length
|
||||
bool engine_decryption_enable_ = false;
|
||||
int (*engine_decryption_)(const char*, char*, size_t*);
|
||||
|
||||
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
|
||||
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue