Add engine decryption in TensorRT EP (#6612)

* add trt engine decryption

* update document

* add windows support to decryption

* fix issues

* remove redundant get() from engine/context check

* fix issue
This commit is contained in:
stevenlix 2021-02-09 00:46:14 -08:00 committed by GitHub
parent 0b89f931d0
commit e9d03983fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 6 deletions

View file

@ -21,6 +21,18 @@
#include "flatbuffers/idl.h"
#include "ort_trt_int8_cal_table.fbs.h"
// Thin portability layer for loading a shared library at runtime:
//   LIBTYPE               - type of the handle produced by OPENLIB
//   OPENLIB(path)         - load the library located at `path`
//   LIBFUNC(handle, name) - look up the exported symbol `name` in `handle`
#if !defined(_WIN32)
// POSIX: dlopen/dlsym; RTLD_LAZY is the binding mode passed to dlopen.
#include <dlfcn.h>
#define LIBTYPE void*
#define OPENLIB(path) dlopen((path), RTLD_LAZY)
#define LIBFUNC(handle, symbol) dlsym((handle), (symbol))
#else
// Windows: LoadLibrary/GetProcAddress.
#include <windows.h>
#define LIBTYPE HINSTANCE
#define OPENLIB(path) LoadLibrary(path)
#define LIBFUNC(handle, symbol) GetProcAddress((handle), (symbol))
#endif
#define CUDA_RETURN_IF_ERROR(expr) \
ORT_RETURN_IF_ERROR(CUDA_CALL(expr) \
? common::Status::OK() \
@ -443,6 +455,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
runtime_ = nvinfer1::createInferRuntime(GetTensorrtLogger());
}
const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable);
if (!engine_decryption_enable_env.empty()) {
engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? false : true);
}
if (engine_decryption_enable_) {
std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath);
LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str());
if (handle == nullptr) {
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not open shared library from " + engine_decryption_lib_path);
}
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
}
}
TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@ -1098,6 +1125,29 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
engine_file.read((char*)engine_buf.get(), engine_size);
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
// Decrypt engine
size_t engine_size = 0;
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
} else {
// Set INT8 per tensor dynamic range
if (int8_enable_ && trt_builder->platformHasFastInt8()) {
@ -1174,7 +1224,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &int8_enable_, &max_workspace_size_,
trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_,
allocator_, dynamic_range_map};
allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_};
*state = p.release();
return 0;
};
@ -1210,7 +1260,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
const std::string profile_cache_path = cache_path + ".profile";
if ((trt_state->engine_cache_enable && trt_engine == nullptr)) {
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && profile_file) {
@ -1228,14 +1278,45 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
auto runtime_ = trt_state->runtime;
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine->get() == nullptr) {
if (trt_state->engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (trt_state->context->get() == nullptr) {
if (trt_state->context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine
size_t engine_size = 0;
if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_state->engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (trt_state->context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
@ -1408,7 +1489,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Build engine
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
if (trt_state->engine->get() == nullptr) {
if (trt_state->engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
trt_engine = trt_state->engine->get();
@ -1428,7 +1509,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Build context
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (trt_state->context->get() == nullptr) {
if (trt_state->context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();

View file

@ -22,6 +22,8 @@ static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"
// Name of the cache-path environment variable.
static const std::string kCachePath{"ORT_TENSORRT_CACHE_PATH"};
// Old env variable for backward compatibility
static const std::string kEngineCachePath{"ORT_TENSORRT_ENGINE_CACHE_PATH"};
// Engine decryption switch: the provider parses the value as an integer and
// treats any non-zero value as "enabled".
static const std::string kDecryptionEnable{"ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"};
// Path of the shared library the provider loads to obtain the "decrypt" function.
static const std::string kDecryptionLibPath{"ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"};
} // namespace tensorrt_env_vars
class TensorrtLogger : public nvinfer1::ILogger {
@ -94,6 +96,8 @@ struct TensorrtFuncState {
nvinfer1::IRuntime* runtime = nullptr;
AllocatorPtr scratch_allocator;
std::unordered_map<std::string, float> dynamic_range_map;
// Whether an encrypted engine cache should be decrypted instead of read directly.
// Defaults to off so a state object that is not fully populated never attempts decryption.
bool engine_decryption_enable = false;
// Decryption callback taking (engine cache path, output buffer, in/out size).
// Callers first pass a null buffer to query the size, then call again with an
// allocated buffer; a zero return value is treated as failure.
// Null until the provider supplies a resolved function.
int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
};
// Logical device representation.
@ -142,6 +146,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
int device_id_;
AllocatorPtr allocator_;
mutable char model_path_[4096]; // Reserved for max path length
// Engine decryption is opt-in; it is switched on from the environment in the constructor.
bool engine_decryption_enable_ = false;
// "decrypt" entry point resolved from the user-supplied decryption library,
// called as (engine cache path, output buffer, in/out size).
// It is only assigned when decryption is enabled, but is always copied into the
// per-node function state, so it must start out as a well-defined null pointer.
int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;