trt provider status return cleanup (#2032)

* status and code cleanup.

* Revert change. It seems a bug in TRT may cause intermittent failure returns.
This commit is contained in:
George Wu 2019-10-07 18:34:48 -07:00 committed by GitHub
parent b2c1937523
commit 0bd807f3b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -21,20 +21,11 @@
#include "core/graph/model.h"
#include "core/providers/cuda/gpu_data_transfer.h"
using namespace std;
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace onnxruntime {
// Evaluates the CUDA runtime expression `call` once and stores its result in a
// local cudaError_t. If the result is not cudaSuccess, the macro executes a
// `return` of common::Status(common::ONNXRUNTIME, common::FAIL), so it exits
// the function in which it is expanded; that function must therefore return
// something constructible from common::Status.
// The specific CUDA error value is not included in the returned status.
// The do { ... } while (0) wrapper lets the expansion be used as a single
// statement followed by a semicolon.
#define CHECK_CUDA(call) \
do { \
cudaError_t status = call; \
if (status != cudaSuccess) { \
return common::Status(common::ONNXRUNTIME, common::FAIL); \
} \
} while (0)
ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
kOnnxDomain,
@ -252,7 +243,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
// Serialize modelproto to string
ONNX_NAMESPACE::ModelProto model_proto = model_build.ToProto();
string string_buf;
std::string string_buf;
model_proto.SerializeToString(&string_buf);
// Get supported node list recursively
@ -312,7 +303,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
// Serialize modelproto to string
string string_buf;
std::string string_buf;
model_proto.SerializeToString(&string_buf);
// Get supported node list
@ -382,7 +373,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
*(model_proto.mutable_graph()) = graph_body.ToGraphProto();
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
string string_buf;
std::string string_buf;
model_proto.SerializeToString(&string_buf);
// Create TensorRT engine
@ -436,11 +427,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
}
auto trt_engine = unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
ORT_ENFORCE(trt_engine != nullptr);
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build Engine for fused node: " + fused_node->Name());
}
// Build TensorRT context
auto trt_context = unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
ORT_ENFORCE(trt_context != nullptr);
if (trt_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build Execution Context for fused node: " + fused_node->Name());
}
// Get input shape and binding index
int num_inputs = trt_network->getNbInputs();
@ -565,10 +562,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
for (int j = 0, end = dimensions.nbDims; j < end; ++j) {
input_dim_size *= tensor_shape[j];
}
CHECK_CUDA(cudaMalloc(&buffers[i], input_dim_size * sizeof(int32_t)));
CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i], input_dim_size * sizeof(int32_t)));
cuda::Impl_Cast<int64_t, int32_t>(ort.GetTensorData<int64_t>(input_tensor), reinterpret_cast<int32_t*>(buffers[i]), input_dim_size);
} else {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
}
}
@ -596,10 +594,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
for (int j = 0, end = dimensions.nbDims; j < end; ++j) {
output_dim_size[i] *= dimensions.d[j];
}
CHECK_CUDA(cudaMalloc(&buffers[i + num_binding_inputs], output_dim_size[i] * sizeof(int32_t)));
CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i + num_binding_inputs], output_dim_size[i] * sizeof(int32_t)));
} else {
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP output onnx tensor data type: " + std::to_string(output_types[i]) + " not supported.");
}
}
@ -623,4 +622,3 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
return Status::OK();
}
} // namespace onnxruntime