diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 1f5d4da2fd..9dfc218eab 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -21,20 +21,11 @@ #include "core/graph/model.h" #include "core/providers/cuda/gpu_data_transfer.h" -using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; namespace onnxruntime { -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status = call; \ - if (status != cudaSuccess) { \ - return common::Status(common::ONNXRUNTIME, common::FAIL); \ - } \ - } while (0) - ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, @@ -252,7 +243,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // Serialize modelproto to string ONNX_NAMESPACE::ModelProto model_proto = model_build.ToProto(); - string string_buf; + std::string string_buf; model_proto.SerializeToString(&string_buf); // Get supported node list recursively @@ -312,7 +303,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); // Serialize modelproto to string - string string_buf; + std::string string_buf; model_proto.SerializeToString(&string_buf); // Get supported node list @@ -382,7 +373,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); - ORT_ENFORCE(trt_engine != nullptr); + if (trt_engine == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build Engine for fused node: " + fused_node->Name()); + } // Build TensorRT context auto trt_context = unique_pointer(trt_engine->createExecutionContext()); - ORT_ENFORCE(trt_context != nullptr); + if (trt_context == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not build Execution Context for fused node: " + fused_node->Name()); + } // Get input shape and binding index int num_inputs = trt_network->getNbInputs(); @@ -565,10 +562,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(ort.GetTensorData(input_tensor), reinterpret_cast(buffers[i]), input_dim_size); } else { - return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); } } @@ -596,10 +594,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector