From efe0af37206ad8a0bf8617e4a9d45adf671fea98 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 26 Jun 2023 09:02:59 -0700 Subject: [PATCH] [TensorRT EP] Fix nullptr check (#16468) ### Description Fix the nullptr check so that it would check the actual existence of engine/context (Currently, it checks the address of unique_ptr, which is always not null. Thx @jslhcl for pointing that out) > A quick recall of struct [trt_state](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h#L104): > ``` > std::unique_ptr* engine = nullptr; >std::unique_ptr* context = nullptr; >``` ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/15982 The incorrect check couldn't stop TRT EP from loading incompatible engine cache on purpose, which invokes unhandled exception --- .../tensorrt/tensorrt_execution_provider.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7ec4a72e68..cb0a0ed432 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2257,7 +2257,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (trt_state->engine == nullptr) { + if (*(trt_state->engine) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; @@ -2269,7 +2269,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext) = std::unique_ptr( trt_state->engine->get()->createExecutionContext()); } - if (trt_state->context == nullptr) { + if (*(trt_state->context) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); @@ -2291,7 +2291,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext->reset(); trt_state->engine->reset(); *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (trt_state->engine == nullptr) { + if (*(trt_state->engine) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); } @@ -2304,7 +2304,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext) = std::unique_ptr( trt_state->engine->get()->createExecutionContext()); } - if (trt_state->context == nullptr) { + if (*(trt_state->context) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get(); @@ -2430,7 +2430,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectortrt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } } - if (trt_state->engine == nullptr) { + if (*(trt_state->engine) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } trt_engine = trt_state->engine->get(); @@ -2476,7 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext) = std::unique_ptr( trt_state->engine->get()->createExecutionContext()); } - if (trt_state->context == nullptr) { + if (*(trt_state->context) == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } trt_context = trt_state->context->get();