mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[TensorRT EP] Fix nullptr check (#16468)
### Description Fix the nullptr check so that it would check the actual existence of engine/context (Currently, it checks the address of unique_ptr, which is always not null. Thx @jslhcl for pointing that out) > A quick recall of struct [trt_state](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h#L104): > ``` > std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr; >std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr; >``` ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/15982 The incorrect check couldn't stop TRT EP from loading incompatible engine cache on purpose, which invokes unhandled exception
This commit is contained in:
parent
c55c6255e0
commit
efe0af3720
1 changed files with 6 additions and 6 deletions
|
|
@ -2257,7 +2257,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
engine_file.read((char*)engine_buf.get(), engine_size);
|
||||
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
|
||||
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
if (trt_state->engine == nullptr) {
|
||||
if (*(trt_state->engine) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
|
||||
}
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
|
||||
|
|
@ -2269,7 +2269,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
}
|
||||
if (trt_state->context == nullptr) {
|
||||
if (*(trt_state->context) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
|
|
@ -2291,7 +2291,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
trt_state->context->reset();
|
||||
trt_state->engine->reset();
|
||||
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
|
||||
if (trt_state->engine == nullptr) {
|
||||
if (*(trt_state->engine) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
|
||||
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
|
||||
}
|
||||
|
|
@ -2304,7 +2304,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
}
|
||||
if (trt_state->context == nullptr) {
|
||||
if (*(trt_state->context) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
|
|
@ -2430,7 +2430,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
|
||||
}
|
||||
}
|
||||
if (trt_state->engine == nullptr) {
|
||||
if (*(trt_state->engine) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
|
||||
}
|
||||
trt_engine = trt_state->engine->get();
|
||||
|
|
@ -2476,7 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
|
||||
trt_state->engine->get()->createExecutionContext());
|
||||
}
|
||||
if (trt_state->context == nullptr) {
|
||||
if (*(trt_state->context) == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
|
||||
}
|
||||
trt_context = trt_state->context->get();
|
||||
|
|
|
|||
Loading…
Reference in a new issue