Fix TensorRT unnecessary file cache operations (#6601)

* Fix TensorRT unnecessary file cache operations

* fix compile
This commit is contained in:
nietras 2021-02-08 05:09:30 +01:00 committed by GitHub
parent 19c130f561
commit 1dd920fa7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1210,34 +1210,36 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
const std::string profile_cache_path = cache_path + ".profile";
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && profile_file && (trt_state->engine_cache_enable && trt_engine == nullptr)) {
// Deserialize profile
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
engine_file.seekg(0, std::ios::end);
int engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
auto runtime_ = trt_state->runtime;
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine->get() == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
if ((trt_state->engine_cache_enable && trt_engine == nullptr)) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && profile_file) {
// Deserialize profile
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
engine_file.seekg(0, std::ios::end);
int engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
auto runtime_ = trt_state->runtime;
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine->get() == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (trt_state->context->get() == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (trt_state->context->get() == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
}
for (int i = 0, end = num_inputs; i < end; ++i) {