From 4eedd3bb46ad596aa4167aa07c68c0c7edd7ba4b Mon Sep 17 00:00:00 2001 From: simonjub <78098752+simonjub@users.noreply.github.com> Date: Sat, 26 Aug 2023 23:09:03 -0400 Subject: [PATCH] [TRT EP] Fix logic to reach cache encryption code. (#17111) ### Description This is a follow-up to PR #15519, which was closed in favor of this one. ### Motivation and Context The current TRT cache implementation has no reachable code path that creates an encrypted TRT engine cache when both the engine_cache_enable and engine_decryption_enable flags are true. This was originally raised in issue #12551. --- .../tensorrt/tensorrt_execution_provider.cc | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3f0cfdac8a..36ab2f62b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -22,6 +22,7 @@ #include #include #include +#include // TODO: find a better way to share this #include "core/providers/cuda/cuda_stream_handle.h" @@ -1028,6 +1029,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); + if (engine_decryption_ == nullptr) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + } } if (int8_enable_) { @@ -2209,6 +2214,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector engine_buf{new char[engine_size]}; - if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), 
&engine_buf[0], &engine_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); } // Deserialize engine trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } } else { // Set INT8 per tensor dynamic range @@ -2311,16 +2317,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector serializedModel(trt_engine->serialize()); size_t engine_size = serializedModel->size(); if (engine_decryption_enable_) { - // Encrypt engine - if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. + if (engine_encryption_ != nullptr) { + if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP call to engine encryption library failed"); + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. 
No cache is written to disk"; } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); file.write(reinterpret_cast(serializedModel->data()), engine_size); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path; } // serialize and save timing cache if (timing_cache_enable_) { @@ -2465,6 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine_cache_path, trt_state->trt_node_name_with_precision); const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile"; std::string timing_cache_path = ""; if (timing_cache_enable_) { @@ -2481,7 +2493,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine_cache_enable && trt_engine == nullptr) { std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && profile_file) { + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { // Deserialize profile shape_ranges = DeserializeProfileV2(profile_file); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; @@ -2505,17 +2517,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine->get(); context_update = true; - } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { shape_ranges = DeserializeProfileV2(profile_file); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; // Decrypt engine size_t engine_size = 0; - if 
(!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not get engine buffer size"); } std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not call engine decryption function decrypt"); } @@ -2526,9 +2538,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); if (*(trt_state->engine) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); + "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; trt_engine = trt_state->engine->get(); context_update = true; } @@ -2670,16 +2682,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector serializedModel(trt_engine->serialize()); size_t engine_size = serializedModel->size(); if (trt_state->engine_decryption_enable) { - // Encrypt engine - if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first. 
+ if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; } } else { std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); file.write(reinterpret_cast(serializedModel->data()), engine_size); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; } // serialize and save timing cache