mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Fixed node index remapping issue in TensorRT graph partitioning (#2155)
* Fixed node index mapping issue during graph partitioning * add test for node index mapping * Update BUILD.md * Update TensorRT-ExecutionProvider.md
This commit is contained in:
parent
7b18bd563f
commit
a9f01a5f29
4 changed files with 120 additions and 2 deletions
2
BUILD.md
2
BUILD.md
|
|
@ -187,7 +187,7 @@ See more information on the TensorRT Execution Provider [here](./docs/execution_
|
|||
* The path to the CUDA `bin` directory must be added to the PATH environment variable so that `nvcc` is found.
|
||||
* The path to the cuDNN installation (path to folder that contains libcudnn.so) must be provided via the cuDNN_PATH environment variable, or `--cudnn_home parameter`.
|
||||
* Install [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download)
|
||||
* The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5.
|
||||
* The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5 but validated with the feature set equivalent to TensorRT 5. Some TensorRT 6 new features such as dynamic shape is not available at this time.
|
||||
* The path to TensorRT installation must be provided via the `--tensorrt_home parameter`.
|
||||
|
||||
#### Build Instructions
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ With the TensorRT execution provider, the ONNX Runtime delivers better inferenci
|
|||
## Build
|
||||
For build instructions, please see the [BUILD page](../../BUILD.md#tensorrt).
|
||||
|
||||
The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5 but validated with the feature set equivalent to TensorRT 5. Some TensorRT 6 new features such as dynamic shape is not available as this time.
|
||||
|
||||
## Using the TensorRT execution provider
|
||||
### C/C++
|
||||
The TensortRT execution provider needs to be registered with ONNX Runtime to enable in the inference session.
|
||||
|
|
|
|||
|
|
@ -267,10 +267,11 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
|
||||
SubGraphCollection_t next_nodes_list;
|
||||
const onnxruntime::GraphViewer graph_viewer(graph_build);
|
||||
const std::vector<NodeIndex>& subgraph_node_index = graph_viewer.GetNodesInTopologicalOrder();
|
||||
next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph_viewer, early_termination);
|
||||
for (int i = 0, end = next_nodes_list.size(); i < end; ++i) {
|
||||
for (int j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) {
|
||||
next_nodes_list[i].first[j] = group.first[next_nodes_list[i].first[j]];
|
||||
next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]];
|
||||
}
|
||||
nodes_list_output.push_back(next_nodes_list[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,5 +102,120 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) {
|
|||
ASSERT_TRUE(status.IsOK());
|
||||
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
|
||||
}
|
||||
|
||||
TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
|
||||
onnxruntime::Model model("graph_1");
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor.
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
|
||||
|
||||
// BOOL tensor.
|
||||
ONNX_NAMESPACE::TypeProto bool_tensor;
|
||||
bool_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
|
||||
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
|
||||
|
||||
// UINT8 tensor.
|
||||
ONNX_NAMESPACE::TypeProto uint8_tensor;
|
||||
uint8_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
|
||||
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
|
||||
|
||||
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &bool_tensor);
|
||||
inputs.push_back(&input_arg_1);
|
||||
auto& output_arg_1 = graph.GetOrCreateNodeArg("node_1_out", &uint8_tensor);
|
||||
outputs.push_back(&output_arg_1);
|
||||
auto& cast_node = graph.AddNode("cast1", "Cast", "node 1.", inputs, outputs);
|
||||
AttributeProto attr_proto;
|
||||
attr_proto.set_name("to");
|
||||
attr_proto.set_type(AttributeProto_AttributeType_INT);
|
||||
attr_proto.set_i(2);
|
||||
cast_node.AddAttribute("to", attr_proto);
|
||||
|
||||
inputs.clear();
|
||||
inputs.push_back(&output_arg_1);
|
||||
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &bool_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_2);
|
||||
auto& cast_node_2 = graph.AddNode("cast2", "Cast", "node 2.", inputs, outputs);
|
||||
AttributeProto attr_proto_2;
|
||||
attr_proto_2.set_name("to");
|
||||
attr_proto_2.set_type(AttributeProto_AttributeType_INT);
|
||||
attr_proto_2.set_i(9);
|
||||
cast_node_2.AddAttribute("to", attr_proto_2);
|
||||
|
||||
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
|
||||
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
|
||||
inputs.clear();
|
||||
inputs.push_back(&input_arg_2);
|
||||
inputs.push_back(&input_arg_3);
|
||||
auto& output_arg_3 = graph.GetOrCreateNodeArg("N", &float_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_3);
|
||||
graph.AddNode("sub", "Sub", "node 3.", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
std::string model_file_name = "trt_execution_provider_NodeIndexMappingTest.onnx";
|
||||
status = onnxruntime::Model::Save(model, model_file_name);
|
||||
|
||||
std::vector<int64_t> dims_mul_x = {1, 3, 2};
|
||||
std::vector<bool> values_mul_x = {true, false, true, false, true, false};
|
||||
std::vector<int64_t> dims_mul_y = {1, 3, 2};
|
||||
std::vector<float> values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
OrtValue ml_value_x;
|
||||
CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
|
||||
OrtValue ml_value_y;
|
||||
CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
|
||||
OrtValue ml_value_z;
|
||||
CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
feeds.insert(std::make_pair("Y", ml_value_y));
|
||||
feeds.insert(std::make_pair("Z", ml_value_z));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("M");
|
||||
output_names.push_back("N");
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
|
||||
std::vector<bool> expected_values_mul_m = {true, false, true, false, true, false};
|
||||
std::vector<int64_t> expected_dims_mul_n = {1, 3, 2};
|
||||
std::vector<float> expected_values_mul_n = {0, 0, 0, 0, 0, 0};
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "TensorrtExecutionProviderTest.NodeIndexMappingTest";
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSession session_object{so};
|
||||
|
||||
TensorrtExecutionProviderInfo epi;
|
||||
epi.device_id = 0;
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());
|
||||
|
||||
status = session_object.Load(model_file_name);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
status = session_object.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
// Now run
|
||||
status = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
std::vector<OrtValue> fetche {fetches.back()};
|
||||
VerifyOutputs(fetche, expected_dims_mul_n, expected_values_mul_n);
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue