mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add doc… (#1623)
* Fix trtlogger segfault. re-enable SoftPlus unit test for TRT. add documentation for ORT_TENSORRT* env vars. * Update TensorRT-ExecutionProvider.md
This commit is contained in:
parent
09db1e06b5
commit
24d17f4353
3 changed files with 27 additions and 6 deletions
|
|
@ -22,3 +22,12 @@ When using the python wheel from the ONNX Runtime build with TensorRT execution
|
|||
|
||||
### Using onnxruntime_perf_test
|
||||
You can test the performance for your ONNX Model with the TensorRT execution provider. Use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test).
|
||||
|
||||
### Configuring Engine Max Batch Size and Workspace Size
|
||||
By default, the TensorRT execution provider builds an ICudaEngine with max batch size = 1 and max workspace size = 1 GB.
|
||||
One can override these defaults by setting the environment variables `ORT_TENSORRT_MAX_BATCH_SIZE` and `ORT_TENSORRT_MAX_WORKSPACE_SIZE`.
|
||||
For example, on Linux:
|
||||
#### Override default max batch size to 10
|
||||
export ORT_TENSORRT_MAX_BATCH_SIZE=10
|
||||
#### Override default max workspace size to 2 GB
|
||||
export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648
|
||||
|
|
|
|||
|
|
@ -25,6 +25,12 @@ using namespace ::onnxruntime::logging;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Per TensorRT documentation, the logger needs to be a singleton. GetTensorrtLogger() returns a reference to one shared TensorrtLogger held in a function-local static, so the object is constructed once (on first call) and every caller receives the same instance. It is constructed with Severity::kWARNING (presumably the minimum severity it reports; confirm against TensorrtLogger).
|
||||
TensorrtLogger& GetTensorrtLogger() {
|
||||
static TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
|
||||
return trt_logger;
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(call) \
|
||||
do { \
|
||||
cudaError_t status = call; \
|
||||
|
|
@ -197,7 +203,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
|
||||
// Get supported node list recursively
|
||||
SubGraphCollection_t parser_nodes_list;
|
||||
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
|
||||
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
|
||||
|
|
@ -255,7 +261,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
|
||||
// Get supported node list
|
||||
SubGraphCollection_t parser_nodes_vector;
|
||||
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
|
||||
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
|
||||
|
|
@ -323,7 +329,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
|
|||
model_proto.SerializeToString(&string_buf);
|
||||
|
||||
// Create TensorRT engine
|
||||
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
|
||||
TensorrtLogger& trt_logger = GetTensorrtLogger();
|
||||
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
|
||||
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
|
||||
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
|
||||
|
|
@ -490,7 +496,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
|
|||
|
||||
// Run TRT inference
|
||||
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
|
||||
trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
|
||||
bool ret = trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
|
||||
if (!ret) {
|
||||
if (trt_state->context->getEngine().getMaxBatchSize() < batch_size) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"TRT enqueue failed: Set ORT_TRT_MAX_BATCH_SIZE environment variable to at least " + to_string(batch_size));
|
||||
}
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to enqueue to TRT execution context.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -200,8 +200,7 @@ TEST(ActivationOpTest, Softplus) {
|
|||
return x + logf(expf(-x) + 1);
|
||||
else
|
||||
return logf(expf(x) + 1);
|
||||
},
|
||||
{}, false); // Disable TensorRT because result mismatches
|
||||
});
|
||||
}
|
||||
|
||||
TEST(ActivationOpTest, Softsign) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue