From e9d03983fc63828f84a9dcc8538392e1b22b46bb Mon Sep 17 00:00:00 2001 From: stevenlix <38092805+stevenlix@users.noreply.github.com> Date: Tue, 9 Feb 2021 00:46:14 -0800 Subject: [PATCH] Add engine decryption in TensorRT EP (#6612) * add trt engine decryption * update document * add windows support to decryption * fix issues * remove redundant get() from engine/context check * fix issue --- .../tensorrt/tensorrt_execution_provider.cc | 93 +++++++++++++++++-- .../tensorrt/tensorrt_execution_provider.h | 6 ++ 2 files changed, 93 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 4e065fa6a4..6faae63c49 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -21,6 +21,18 @@ #include "flatbuffers/idl.h" #include "ort_trt_int8_cal_table.fbs.h" +#ifdef _WIN32 +#include +#define LIBTYPE HINSTANCE +#define OPENLIB(libname) LoadLibrary(libname) +#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn)) +#else +#include +#define LIBTYPE void* +#define OPENLIB(libname) dlopen((libname), RTLD_LAZY) +#define LIBFUNC(lib, fn) dlsym((lib), (fn)) +#endif + #define CUDA_RETURN_IF_ERROR(expr) \ ORT_RETURN_IF_ERROR(CUDA_CALL(expr) \ ? common::Status::OK() \ @@ -443,6 +455,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } runtime_ = nvinfer1::createInferRuntime(GetTensorrtLogger()); } + + const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); + if (!engine_decryption_enable_env.empty()) { + engine_decryption_enable_ = (std::stoi(engine_decryption_enable_env) == 0 ? false : true); + } + + if (engine_decryption_enable_) { + std::string engine_decryption_lib_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionLibPath); + LIBTYPE handle = OPENLIB(engine_decryption_lib_path.c_str()); + if (handle == nullptr) { + ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not open shared library from " + engine_decryption_lib_path); + } + engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); + } } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -1098,6 +1125,29 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse engine_file.read((char*)engine_buf.get(), engine_size); trt_engine = tensorrt_ptr::unique_pointer(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); + } + } else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) { + // Decrypt engine + size_t engine_size = 0; + if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + trt_engine = tensorrt_ptr::unique_pointer(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); + } } else { // Set INT8 per tensor dynamic range if (int8_enable_ && trt_builder->platformHasFastInt8()) { @@ -1174,7 +1224,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &int8_enable_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_, - allocator_, dynamic_range_map}; + allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_}; *state = p.release(); return 0; }; @@ -1210,7 +1260,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); const std::string engine_cache_path = cache_path + ".engine"; const std::string profile_cache_path = cache_path + ".profile"; - if ((trt_state->engine_cache_enable && trt_engine == nullptr)) { + if (trt_state->engine_cache_enable && trt_engine == nullptr) { std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); if (engine_file && profile_file) { @@ -1228,14 +1278,45 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse auto runtime_ = trt_state->runtime; *(trt_state->engine) = tensorrt_ptr::unique_pointer( runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (trt_state->engine->get() == nullptr) { + if (trt_state->engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; trt_engine = trt_state->engine->get(); *(trt_state->context) = tensorrt_ptr::unique_pointer( trt_state->engine->get()->createExecutionContext()); - if (trt_state->context->get() == nullptr) { + if (trt_state->context == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); + } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { + shape_ranges = DeserializeProfile(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + trt_state->context->reset(); + trt_state->engine->reset(); + *(trt_state->engine) = tensorrt_ptr::unique_pointer(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + if (trt_state->engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + *(trt_state->context) = tensorrt_ptr::unique_pointer( + trt_state->engine->get()->createExecutionContext()); + if (trt_state->context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); @@ -1408,7 +1489,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse // Build engine *(trt_state->engine) = tensorrt_ptr::unique_pointer( trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); - if (trt_state->engine->get() == nullptr) { + if (trt_state->engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } trt_engine = trt_state->engine->get(); @@ -1428,7 +1509,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse // Build context *(trt_state->context) = tensorrt_ptr::unique_pointer( trt_state->engine->get()->createExecutionContext()); - if (trt_state->context->get() == nullptr) { + if (trt_state->context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 5bc13bcab3..1490a0ffd8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -22,6 +22,8 @@ static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE" static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH"; // Old env variable for backward compatibility static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; +static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; +static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; } // namespace tensorrt_env_vars class TensorrtLogger : public nvinfer1::ILogger { @@ -94,6 +96,8 @@ struct TensorrtFuncState { nvinfer1::IRuntime* runtime = nullptr; AllocatorPtr scratch_allocator; std::unordered_map dynamic_range_map; + bool engine_decryption_enable; + int (*engine_decryption)(const char*, char*, size_t*); }; // Logical device representation. @@ -142,6 +146,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { int device_id_; AllocatorPtr allocator_; mutable char model_path_[4096]; // Reserved for max path length + bool engine_decryption_enable_ = false; + int (*engine_decryption_)(const char*, char*, size_t*); std::unordered_map> parsers_; std::unordered_map> engines_;