From f4a5d172941d8e2034a3e93cf60cb5fbb9dfab0f Mon Sep 17 00:00:00 2001 From: stevenlix <38092805+stevenlix@users.noreply.github.com> Date: Tue, 25 Feb 2020 05:36:01 -0800 Subject: [PATCH] Upgrade to CUDA10.2 for TensorRT (#3084) * Switch to CUDA10.2 * Update win-gpu-tensorrt-ci-pipeline.yml * Update win-gpu-tensorrt-ci-pipeline.yml * remove dynamic_shape * update onnx-tensorrt submodule * check if input shape is specified for TensorRT subgraph input and enable some TensorRT unit tests * fix format issue * add shape inference instruction for TensorRT * update according to the reviews * Update win-gpu-tensorrt-ci-pipeline.yml --- .gitmodules | 6 +- cmake/external/onnx-tensorrt | 2 +- .../TensorRT-ExecutionProvider.md | 5 +- .../tensorrt/tensorrt_execution_provider.cc | 92 ++++++++++--------- .../cpu/tensor/space_depth_ops_test.cc | 21 ++--- .../win-gpu-tensorrt-ci-pipeline.yml | 4 +- 6 files changed, 65 insertions(+), 65 deletions(-) diff --git a/.gitmodules b/.gitmodules index c3ea0ab7bd..6335586e0b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -46,9 +46,9 @@ [submodule "cmake/external/FeaturizersLibrary"] path = cmake/external/FeaturizersLibrary url = https://github.com/microsoft/FeaturizersLibrary.git -[submodule "cmake/external/onnx-tensorrt"] - path = cmake/external/onnx-tensorrt - url = https://github.com/stevenlix/onnx-tensorrt.git [submodule "cmake/external/SafeInt/safeint"] path = cmake/external/SafeInt/safeint url = https://github.com/dcleblanc/SafeInt.git +[submodule "cmake/external/onnx-tensorrt"] + path = cmake/external/onnx-tensorrt + url = https://github.com/stevenlix/onnx-tensorrt.git diff --git a/cmake/external/onnx-tensorrt b/cmake/external/onnx-tensorrt index 6769f66f3c..5a7cba1a76 160000 --- a/cmake/external/onnx-tensorrt +++ b/cmake/external/onnx-tensorrt @@ -1 +1 @@ -Subproject commit 6769f66f3cd5ea8f74f467617d3a272c911c057e +Subproject commit 5a7cba1a768c3bb01cbf323e3acdeb8e29e3beca diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index a1435de780..24cb76e114 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -19,8 +19,11 @@ status = session_object.Load(model_file_name); ``` The C API details are [here](../C_API.md#c-api). +#### Shape Inference for TensorRT Subgraphs +If some operators in the model are not supported by TensorRT, ONNX Runtime will partition the graph and only send supported subgraphs to TensorRT execution provider. Because TensorRT requires that all inputs of the subgraphs have shape specified, ONNX Runtime will throw error if there is no input shape info. In this case please run shape inference for the entire model first by running script [here](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py). + #### Sample -To run Faster R-CNN model on TensorRT execution provider, +This example shows how to run Faster R-CNN model on TensorRT execution provider, First, download Faster R-CNN onnx model from onnx model zoo [here](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/faster-rcnn). diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a91410eea1..c6c2f4443a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -116,7 +116,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv const std::string fp16_enable_env = env_instance.GetEnvironmentVar(tensorrt_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); - } + } } TensorrtExecutionProvider::~TensorrtExecutionProvider() {} @@ -184,8 +184,8 @@ bool FindCycleHelper(int i, const std::list* adjacency_map, // Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) { // Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models. - // TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around - const std::vector& node_index = graph.GetNodesInTopologicalOrder(); + // TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around + const std::vector& node_index = graph.GetNodesInTopologicalOrder(); const std::string exclude_dim_name1 = "NonZero"; const std::string exclude_dim_name2 = "NonMaxSuppression"; SubGraphCollection_t parser_nodes_vector = {{{}, false}}; @@ -202,8 +202,8 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph std::string dim_name = dim.dim_param(); if (!dim_name.empty()) { if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) { - exclude_node = true; - break; + exclude_node = true; + break; } } } @@ -216,9 +216,9 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph // Remove the node with empty input shape if (!exclude_node) { - parser_nodes_vector.back().first.push_back(index); + parser_nodes_vector.back().first.push_back(index); } else if (!parser_nodes_vector.back().first.empty()) { - parser_nodes_vector.push_back({{},false}); + parser_nodes_vector.push_back({{}, false}); } } @@ -407,6 +407,18 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect ORT_ENFORCE(graph_build.Resolve().IsOK()); + // Check if input tensors have shapes + if (iterations > 1) { + for (const auto* input_arg : graph_build.GetInputs()) { + if (input_arg->Shape() == nullptr) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "TensorRT input: " + input_arg->Name() + " has no shape specified. " + + "Please run shape inference on the onnx model first. Details can be found in " + + "https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/TensorRT-ExecutionProvider.md#shape-inference-for-tensorrt-subgraphs")); + } + } + } + // Serialize modelproto to string const onnxruntime::GraphViewer graph_viewer(graph_build); @@ -453,10 +465,10 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& std::unordered_map index_to_node_map; std::unordered_map> input_to_nodes_map, node_to_outputs_map; std::unordered_set non_trt_node_index(node_index.begin(), node_index.end()); - int counter = 0, id = 0; + int counter = 0, id = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - // Construct subgraph from node list + // Construct subgraph from node list std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); // Create node to inputs/outputs/index maps @@ -468,23 +480,23 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& } if (meta_def != nullptr) { - for (const auto& input: meta_def->inputs) { + for (const auto& input : meta_def->inputs) { input_to_nodes_map[input].insert(node_name); } - for (const auto& output: meta_def->outputs) { + for (const auto& output : meta_def->outputs) { node_to_outputs_map[node_name].insert(output); } } // Remove TensorRT nodes from node index list - for (const auto& index: group.first) { + for (const auto& index : group.first) { non_trt_node_index.erase(node_index[index]); } } } // Add non TensorRT nodes to the maps - for (const auto& index: non_trt_node_index) { + for (const auto& index : non_trt_node_index) { const auto& node = graph.GetNode(index); std::string node_name = node->Name(); if (node_to_index_map.find(node_name) == node_to_index_map.end()) { @@ -503,13 +515,13 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& // Create adjacency list int graph_size = node_to_index_map.size(); - std::list *adjacency_map = new std::list[graph_size]; - for (const auto& node: node_to_outputs_map) { + std::list* adjacency_map = new std::list[graph_size]; + for (const auto& node : node_to_outputs_map) { for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) { const auto& loc = input_to_nodes_map.find(*iter); if (loc != input_to_nodes_map.end()) { int parent_node_index = node_to_index_map.find(node.first)->second; - for (auto child_node: loc->second) { + for (auto child_node : loc->second) { int child_node_index = node_to_index_map.find(child_node)->second; adjacency_map[parent_node_index].push_back(child_node_index); } @@ -518,8 +530,8 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& } // Check cycle in the graph - bool *visited = new bool[graph_size]; - bool *st = new bool[graph_size]; + bool* visited = new bool[graph_size]; + bool* st = new bool[graph_size]; for (int i = 0; i < graph_size; ++i) { visited[i] = false; st[i] = false; @@ -529,19 +541,19 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t& bool has_cycle = false; for (int i = 0; i < graph_size; ++i) { if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) { - has_cycle = true; - break; + has_cycle = true; + break; } } - // Remove TensorRT subgraph from the supported node list if it's part of the cycle + // Remove TensorRT subgraph from the supported node list if it's part of the cycle if (has_cycle) { for (int i = 0; i < static_cast(cycles.size()); ++i) { auto loc = index_to_node_map.find(cycles[i]); if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) { int trt_node_index = std::stoi(loc->second.substr(10)); supported_nodes_vector.erase(supported_nodes_vector.begin() + trt_node_index); - trt_cycle = true; + trt_cycle = true; break; } } @@ -587,7 +599,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, result.push_back(onnxruntime::make_unique(std::move(sub_graph))); } } - + return result; } @@ -660,6 +672,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims); trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims); } else { // Execution tensor + bool is_dynamic_shape = false; for (int j = 0, end = nb_dims; j < end; ++j) { // For dynamic shape subgraph, a dummy engine is created at compile phase. // Real engine will be created at compute phase based on input data @@ -667,12 +680,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + + if (is_dynamic_shape) { + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + } } } @@ -764,8 +780,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(), engines_[context->node_name].get(), contexts_[context->node_name].get(), builders_[context->node_name].get(), networks_[context->node_name].get(), input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_, &fp16_enable_, - &max_workspace_size_}; + input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_, &fp16_enable_, + &max_workspace_size_}; *state = p.release(); return 0; }; @@ -790,14 +806,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector buffers(total_bindings); - bool dynamic_shape = false; - auto trt_context = trt_state->context; - if (!trt_context->allInputDimensionsSpecified() || !trt_context->allInputShapesSpecified()) { - dynamic_shape = true; - } - // Update shape ranges bool dimension_update = false; + auto trt_context = trt_state->context; auto trt_builder = trt_state->builder; nvinfer1::IOptimizationProfile* trt_profile = nullptr; for (int i = 0, end = num_binding_inputs; i < end; ++i) { @@ -857,20 +868,13 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); - } } // Regenerate engine and context // Only one profile is generated, so no need to explicitly set optimization profile if (dimension_update) { auto trt_config = unique_pointer(trt_builder->createBuilderConfig()); - trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); + trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); trt_config->addOptimizationProfile(trt_profile); if (*(trt_state->fp16_enable_ptr) && trt_builder->platformHasFastFp16()) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); @@ -985,4 +989,4 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); - // TensorRT has error: Expected output shape [{1,8,1,2}] did not match run output shape [{8,1,1,2}] for output - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, SpaceToDepthTest_2) { @@ -70,8 +69,7 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { 98., 101., 66., 69., 84., 87., 102., 105., 67., 70., 85., 88., 103., 106., 68., 71., 86., 89., 104., 107.}; test.AddOutput("output", {2, 27, 1, 2}, result); - // TensorRT has error: Expected output shape [{2,27,1,2}] did not match run output shape [{54,1,1,2}] for output - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, DepthToSpaceTest_1) { @@ -102,8 +100,7 @@ TEST(TensorOpTest, DepthToSpaceTest_1) { 2.0f, 2.1f, 2.2f, 2.3f, 3.0f, 3.1f, 3.2f, 3.3f}; test.AddOutput("output", {N, C / (blocksize * blocksize), H * blocksize, W * blocksize}, result); - // TensorRT output shape mismatches - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, DepthToSpaceTest_2) { @@ -146,8 +143,7 @@ TEST(TensorOpTest, DepthToSpaceTest_2) { 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; test.AddOutput("output", {2, 3, 6, 4}, result); - // TensorRT output shape mismatches - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, DepthToSpaceTest_3) { @@ -190,8 +186,7 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; test.AddOutput("output", {2, 3, 6, 4}, result); - // TensorRT output shape mismatches - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, DepthToSpaceTest_4) { @@ -235,8 +230,7 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; test.AddOutput("output", {2, 3, 6, 4}, result); - // TensorRT output shape mismatches - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } TEST(TensorOpTest, DepthToSpaceTest_5) { @@ -263,8 +257,7 @@ TEST(TensorOpTest, DepthToSpaceTest_5) { 21., 30., 22., 31., 23., 32.}; test.AddOutput("output", {1, 1, 4, 6}, result); - // TensorRT output shape mismatches - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(); } } // namespace test diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 1ba88d9ead..304fffafcc 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -42,7 +42,7 @@ jobs: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.0.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' + arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.2.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.2 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2" --cudnn_home="C:\local\cudnn-10.2-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' workingDirectory: '$(Build.BinariesDirectory)' - task: VSBuild@1 @@ -81,7 +81,7 @@ jobs: del wheel_filename_file python.exe -m pip install -q --upgrade %WHEEL_FILENAME% set PATH=$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig);%PATH% - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.0.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.2.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.2 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2" --cudnn_home="C:\local\cudnn-10.2-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run tests'