[TensorRT EP] Fix nullptr check (#16468)

### Description
Fix the nullptr check so that it would check the actual existence of
engine/context
(Currently, it checks the address of unique_ptr, which is always not
null. Thx @jslhcl for pointing that out)

> A quick recall of struct
[trt_state](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h#L104):
> ```
> std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
>std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
>```


### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/15982
The incorrect check couldn't stop TRT EP from loading incompatible
engine cache on purpose, which invokes unhandled exception
This commit is contained in:
Yifan Li 2023-06-26 09:02:59 -07:00 committed by GitHub
parent c55c6255e0
commit efe0af3720
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2257,7 +2257,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
engine_file.read((char*)engine_buf.get(), engine_size);
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine == nullptr) {
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
@ -2269,7 +2269,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
}
if (trt_state->context == nullptr) {
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
@ -2291,7 +2291,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_state->context->reset();
trt_state->engine->reset();
*(trt_state->engine) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (trt_state->engine == nullptr) {
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
@ -2304,7 +2304,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
}
if (trt_state->context == nullptr) {
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
@ -2430,7 +2430,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast<std::chrono::milliseconds>(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
}
}
if (trt_state->engine == nullptr) {
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
trt_engine = trt_state->engine->get();
@ -2476,7 +2476,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
}
if (trt_state->context == nullptr) {
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();