Add engine encryption in TensorRT EP (#8732)

* add engine encryption

* Update tensorrt_execution_provider.cc

* Update tensorrt_execution_provider.h

* Update tensorrt_execution_provider.cc

* Update tensorrt_execution_provider.h

* clean up

* update encryption signature
This commit is contained in:
stevenlix 2021-08-17 08:34:22 -07:00 committed by GitHub
parent f668a79532
commit 11a618b2ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 8 deletions

View file

@ -554,6 +554,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
"TensorRT EP could not open shared library from " + engine_decryption_lib_path_));
}
engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt");
engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
@ -1309,8 +1310,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
if (engine_cache_enable_) {
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
size_t engine_size = serializedModel->size();
if (engine_decryption_enable_) {
// Encrypt engine
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
}
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
@ -1367,7 +1377,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, dla_enable_,
dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr,
allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_};
allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
*state = p.release();
return 0;
};
@ -1648,10 +1658,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
// Serialize engine
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
}
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
// Build context
@ -1967,7 +1985,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
auto output_tensor_ptr = ort.GetTensorMutableData<double>(output_tensor[i]);
if (output_tensor_ptr != nullptr) {
cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
}
}
}
}
return Status::OK();
@ -1977,4 +1995,4 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -99,6 +99,7 @@ struct TensorrtFuncState {
std::unordered_map<std::string, float> dynamic_range_map;
bool engine_decryption_enable;
int (*engine_decryption)(const char*, char*, size_t*);
int (*engine_encryption)(const char*, char*, size_t);
};
// Logical device representation.
@ -157,6 +158,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
mutable char model_path_[4096]; // Reserved for max path length
bool engine_decryption_enable_ = false;
int (*engine_decryption_)(const char*, char*, size_t*);
int (*engine_encryption_)(const char*, char*, size_t);
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;