Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add doc… (#1623)

* Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add documentation for ORT_TENSORRT* env vars.

* Update TensorRT-ExecutionProvider.md
This commit is contained in:
jywu-msft 2019-08-14 16:34:39 -07:00 committed by GitHub
parent 09db1e06b5
commit 24d17f4353
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 6 deletions

View file

@ -22,3 +22,12 @@ When using the python wheel from the ONNX Runtime build with TensorRT execution
### Using onnxruntime_perf_test
You can test the performance for your ONNX Model with the TensorRT execution provider. Use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test).
### Configuring Engine Max Batch Size and Workspace Size
By default, the TensorRT execution provider builds an ICudaEngine with max batch size = 1 and max workspace size = 1 GB.
One can override these defaults by setting the environment variables ORT_TENSORRT_MAX_BATCH_SIZE and ORT_TENSORRT_MAX_WORKSPACE_SIZE.
For example, on Linux:
#### override default batch size to 10
export ORT_TENSORRT_MAX_BATCH_SIZE=10
#### override default max workspace size to 2GB
export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648

View file

@ -25,6 +25,12 @@ using namespace ::onnxruntime::logging;
namespace onnxruntime {
// Per TensorRT documentation, logger needs to be a singleton.
TensorrtLogger& GetTensorrtLogger() {
  // Function-local static: a single logger instance is constructed on the
  // first call and the same object is handed back on every later call.
  static TensorrtLogger shared_logger(nvinfer1::ILogger::Severity::kWARNING);
  return shared_logger;
}
#define CHECK_CUDA(call) \
do { \
cudaError_t status = call; \
@ -197,7 +203,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
// Get supported node list recursively
SubGraphCollection_t parser_nodes_list;
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@ -255,7 +261,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// Get supported node list
SubGraphCollection_t parser_nodes_vector;
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@ -323,7 +329,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
model_proto.SerializeToString(&string_buf);
// Create TensorRT engine
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@ -490,7 +496,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
// Run TRT inference
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
// Capture the boolean result of enqueue() so a failed launch is reported to
// the caller as an error Status.
bool ret = trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
if (!ret) {
  // If the engine's max batch size is smaller than the requested batch,
  // point the user at the setting to raise. The name in the message must be
  // ORT_TENSORRT_MAX_BATCH_SIZE, the variable the provider documentation
  // tells users to set; the previous text named ORT_TRT_MAX_BATCH_SIZE.
  if (trt_state->context->getEngine().getMaxBatchSize() < batch_size) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                          "TRT enqueue failed: Set ORT_TENSORRT_MAX_BATCH_SIZE environment variable to at least " + to_string(batch_size));
  }
  // Any other enqueue failure is reported generically.
  return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to enqueue to TRT execution context.");
}
return Status::OK();
};

View file

@ -200,8 +200,7 @@ TEST(ActivationOpTest, Softplus) {
return x + logf(expf(-x) + 1);
else
return logf(expf(x) + 1);
},
{}, false); // Disable TensorRT because result mismatches
});
}
TEST(ActivationOpTest, Softsign) {