diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index e215f95cd4..248e9604fa 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -554,6 +554,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
                                          "TensorRT EP could not open shared library from " + engine_decryption_lib_path_));
     }
     engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
+    engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt");
   }
   LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
                         << "device_id: " << device_id_
@@ -1309,8 +1310,21 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse
         }
         if (engine_cache_enable_) {
           nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
-          std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
-          file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
+          size_t engine_size = serializedModel->size();
+          if (engine_decryption_enable_) {
+            // Encrypt engine: the library's encrypt() receives the cache path, the serialized
+            // bytes and their size. A null pointer or a zero return is reported as a failure.
+            if (engine_encryption_ == nullptr ||
+                !engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
+              // Release the serialized engine before the early return so it is not leaked.
+              serializedModel->destroy();
+              return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                     "TensorRT EP could not call engine encryption function encrypt");
+            }
+          } else {
+            std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
+            file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
+          }
           serializedModel->destroy();
           LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
         }
@@ -1367,7 +1381,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse
             &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
             input_shape_ranges_[context->node_name], &tensorrt_mu_,
             fp16_enable_, int8_enable_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr,
-            allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_};
+            allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
       *state = p.release();
       return 0;
     };
@@ -1648,10 +1662,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse
 
           // Serialize engine
           nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
-          std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
-          file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
+          size_t engine_size = serializedModel->size();
+          if (trt_state->engine_decryption_enable) {
+            // Encrypt engine: the library's encrypt() receives the cache path, the serialized
+            // bytes and their size. A null pointer or a zero return is reported as a failure.
+            if (trt_state->engine_encryption == nullptr ||
+                !trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
+              // Release the serialized engine before the early return so it is not leaked.
+              serializedModel->destroy();
+              return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                     "TensorRT EP could not call engine encryption function encrypt");
+            }
+          } else {
+            std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
+            file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
+          }
           serializedModel->destroy();
           LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
         }
 
         // Build context
@@ -1967,7 +1994,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse
           auto output_tensor_ptr = ort.GetTensorMutableData(output_tensor[i]);
           if (output_tensor_ptr != nullptr) {
             cuda::Impl_Cast(stream, reinterpret_cast(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
-          }
+          }
         }
       }
       return Status::OK();
@@ -1977,4 +2004,4 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse
   }
   return Status::OK();
 }
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 0dc649a29d..0c888fc63f 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -99,6 +99,7 @@ struct TensorrtFuncState {
   std::unordered_map dynamic_range_map;
   bool engine_decryption_enable;
   int (*engine_decryption)(const char*, char*, size_t*);
+  int (*engine_encryption)(const char*, char*, size_t);
 };
 
 // Logical device representation.
@@ -157,6 +158,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   mutable char model_path_[4096]; // Reserved for max path length
   bool engine_decryption_enable_ = false;
   int (*engine_decryption_)(const char*, char*, size_t*);
+  int (*engine_encryption_)(const char*, char*, size_t) = nullptr;  // Set from the "encrypt" symbol of the decryption library
 
   std::unordered_map> parsers_;
   std::unordered_map> engines_;