mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix bugs in TensorRT (#4780)
* fix bugs * Move -Wno-deprecated-declarations to target compile flag
This commit is contained in:
parent
aa993e95c9
commit
7acef875bb
4 changed files with 8 additions and 8 deletions
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -61,5 +61,4 @@
|
|||
url = https://github.com/google/libprotobuf-mutator.git
|
||||
[submodule "cmake/external/onnx-tensorrt"]
|
||||
path = cmake/external/onnx-tensorrt
|
||||
url = https://github.com/stevenlix/onnx-tensorrt.git
|
||||
branch = trt71
|
||||
url = https://github.com/onnx/onnx-tensorrt.git
|
||||
|
|
|
|||
2
cmake/external/onnx-tensorrt
vendored
2
cmake/external/onnx-tensorrt
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit 9d962f4510102e296944b9be046047a953be0ee1
|
||||
Subproject commit 088554a5fbee9ba183c05c09c1abe986034e9208
|
||||
|
|
@ -422,6 +422,7 @@ if (onnxruntime_USE_TENSORRT)
|
|||
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst")
|
||||
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp)
|
||||
elseif(UNIX)
|
||||
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
|
||||
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections")
|
||||
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp stdc++fs)
|
||||
elseif(WIN32)
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ bool FindCycleHelper(int i, const std::list<int>* adjacency_map,
|
|||
}
|
||||
|
||||
std::unique_ptr<Provider_IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, const onnxruntime::Provider_GraphViewer& graph) const {
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<size_t> node_set;
|
||||
node_set.reserve(graph_nodes_index.first.size());
|
||||
for (const auto& index : graph_nodes_index.first) {
|
||||
|
|
@ -915,11 +915,11 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
|
|||
// Get shape values for shape tensor input
|
||||
const auto& tensor_type = ort.GetTensorElementType(tensor_info);
|
||||
int shape_size = nb_dims == 0 ? 1 : tensor_shapes[0];
|
||||
tensor_shape_values[input_name].reserve(shape_size);
|
||||
tensor_shape_values[input_name].resize(shape_size);
|
||||
switch (tensor_type) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
|
||||
int32_t* input = new int32_t[shape_size];
|
||||
cudaMemcpy(input, ort.GetTensorData<int32_t>(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(input, ort.GetTensorData<int32_t>(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
|
||||
for (int j = 0; j < shape_size; ++j) {
|
||||
tensor_shape_values[input_name][j] = input[j];
|
||||
}
|
||||
|
|
@ -928,7 +928,7 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
|
|||
}
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
|
||||
int64_t* input = new int64_t[shape_size];
|
||||
cudaMemcpy(input, ort.GetTensorData<int64_t>(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(input, ort.GetTensorData<int64_t>(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost));
|
||||
for (int j = 0; j < shape_size; ++j) {
|
||||
tensor_shape_values[input_name][j] = static_cast<int32_t>(input[j]);
|
||||
}
|
||||
|
|
@ -1333,4 +1333,4 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue