From 1dd920fa7cac2d2ce8a0b497f2f32bd4a2b099fc Mon Sep 17 00:00:00 2001 From: nietras Date: Mon, 8 Feb 2021 05:09:30 +0100 Subject: [PATCH] Fix TensorRT unnecessary file cache operations (#6601) * Fix TensorRT unnecessary file cache operations * fix compile --- .../tensorrt/tensorrt_execution_provider.cc | 56 ++++++++++--------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ea2d88a749..4e065fa6a4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1210,34 +1210,36 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& fuse const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); const std::string engine_cache_path = cache_path + ".engine"; const std::string profile_cache_path = cache_path + ".profile"; - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && profile_file && (trt_state->engine_cache_enable && trt_engine == nullptr)) { - // Deserialize profile - shape_ranges = DeserializeProfile(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Deserialize engine - trt_state->context->reset(); - trt_state->engine->reset(); - engine_file.seekg(0, std::ios::end); - int engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - auto runtime_ = trt_state->runtime; - *(trt_state->engine) = tensorrt_ptr::unique_pointer( - runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (trt_state->engine->get() == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + if ((trt_state->engine_cache_enable && trt_engine == nullptr)) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfile(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Deserialize engine + trt_state->context->reset(); + trt_state->engine->reset(); + engine_file.seekg(0, std::ios::end); + int engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + auto runtime_ = trt_state->runtime; + *(trt_state->engine) = tensorrt_ptr::unique_pointer( + runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (trt_state->engine->get() == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + *(trt_state->context) = tensorrt_ptr::unique_pointer( + trt_state->engine->get()->createExecutionContext()); + if (trt_state->context->get() == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); + } + trt_context = trt_state->context->get(); } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - *(trt_state->context) = tensorrt_ptr::unique_pointer( - trt_state->engine->get()->createExecutionContext()); - if (trt_state->context->get() == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); - } - trt_context = trt_state->context->get(); } for (int i = 0, end = num_inputs; i < end; ++i) {