Improve TensorRT engine caching (#5737)
* add profile caching to improve engine caching feature
* Add comments
* fix typo
* add decryption for engine caching
* Update tensorrt_execution_provider.cc
* Update tensorrt_execution_provider.cc
* Update tensorrt_execution_provider.cc
* Update tensorrt_execution_provider.cc
* Update tensorrt_execution_provider.cc
* update onnx-tensorrt submodule
* set opt profile to max value of the range
* add hash to engine/profile name
* Add calibration based INT8 quantization
* add an option to enable both FP16 and INT8
* Update tensorrt_execution_provider.cc
* add env variable to specify calibration file name
* clean up code
* Add comments and update TRT document
* enable tensorrt basic test and add EngineCachingTest
* clean up
* update environment variable in the test
* clean up
parent 2a87108431
commit 54de618c2e
5 changed files with 289 additions and 115 deletions
@@ -444,14 +444,14 @@ if (onnxruntime_USE_TENSORRT)
 source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs})
 add_library(onnxruntime_providers_tensorrt SHARED ${onnxruntime_providers_tensorrt_cc_srcs})
-onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx )
+onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers)
 add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
 if(WIN32)
 target_link_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDA_HOME}/x64/lib64)
 else()
 target_link_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDA_HOME}/lib64)
 endif()
-target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart onnxruntime_providers_shared protobuf::libprotobuf)
+target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart onnxruntime_providers_shared protobuf::libprotobuf flatbuffers)
 target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${onnxruntime_CUDNN_HOME}/include ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
 install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tensorrt DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
@@ -61,12 +61,12 @@ There are several environment variables for TensorRT execution provider.
 * ORT_TENSORRT_FP16_ENABLE: Enable FP16 mode in TensorRT
 
-* ORT_TENSORRT_ENGINE_CACHE_ENABLE: Enable TensorRT engine caching. The purpose of using engine caching is to save engine build time in the cases that TensorRT may take long time to optimize and build engine. Engine will be cached after it's built at the first time so that next time when inference session is created the engine can be loaded directly from cache. Note each engine is created for specific settings such as precision (FP32/FP16/INT8 etc), workspace, profiles etc, and specific GPUs and it's not portable, so it's essential to make sure those settings are not changing, otherwise the engines need to be rebuilt and cached again.
-**Warning: Please clean up any old engine cache files (.engine) if any of the following changes:**
+* ORT_TENSORRT_ENGINE_CACHE_ENABLE: Enable TensorRT engine caching. The purpose of using engine caching is to save engine build time in the cases that TensorRT may take long time to optimize and build engine. Engine will be cached after it's built at the first time so that next time when inference session is created the engine can be loaded directly from cache. In order to validate that the loaded engine is usable for current inference, engine profile is also cached and loaded along with engine. If current input shapes are in the range of the engine profile, that means the loaded engine can be safely used. Otherwise if input shapes are out of range, profile cache will be updated to cover the new shape and engine will be recreated based on the new profile (and also refreshed in the engine cache). Note each engine is created for specific settings such as precision (FP32/FP16/INT8 etc), workspace, profiles etc, and specific GPUs and it's not portable, so it's essential to make sure those settings are not changing, otherwise the engines need to be rebuilt and cached again.
+**Warning: Please clean up any old engine and profile cache files (.engine and .profile) if any of the following changes:**
 - Model changes (if there are any changes to the model topology, opset version etc.)
 - ORT version changes (i.e. moving from ORT version 1.4 to 1.5)
 - TensorRT version changes (i.e. moving from TensorRT 7.0 to 7.1)
-- Hardware changes. (Engine files are not portable and optimized for specific Nvidia hardware)
+- Hardware changes. (Engine and profile files are not portable and optimized for specific Nvidia hardware)
 
 * ORT_TENSORRT_ENGINE_CACHE_PATH: Specify path for TensorRT engine files if ORT_TENSORRT_ENGINE_CACHE_ENABLE is 1
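As a quick illustration of the workflow described in the documentation hunk above, here is a minimal sketch (not part of the commit) that enables engine and profile caching through these environment variables before the inference session is created. The cache directory /tmp/trt_cache is an assumed example value, and on Windows _putenv_s would be used in place of the POSIX setenv.

#include <cstdlib>

int main() {
  // Assumed example values; the TensorRT EP reads these variables when the
  // inference session is created, so set them before constructing the session.
  setenv("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "1", 1);            // turn on engine + profile caching
  setenv("ORT_TENSORRT_ENGINE_CACHE_PATH", "/tmp/trt_cache", 1); // directory for .engine/.profile files
  // ... create the ORT session with the TensorRT execution provider here ...
  return 0;
}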
@@ -20,6 +20,7 @@
 #include <limits>
 #include <map>
 #include <memory>
+#include "flatbuffers/idl.h"
 
 #define CUDA_RETURN_IF_ERROR(expr) \
 ORT_RETURN_IF_ERROR(CUDA_CALL(expr) \
@@ -30,23 +31,82 @@ using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::logging;
 namespace fs = std::experimental::filesystem;
 namespace {
-std::string GetEnginePath(const ::std::string& root, const std::string& name) {
+std::string GetCachePath(const std::string& root, const std::string& name) {
 if (root.empty()) {
-return name + ".engine";
+return name;
 } else {
 fs::path path = root;
-path.append(name + ".engine");
+path.append(name);
 return path.string();
 }
 }
 
-std::string GetVecHash(const ::std::vector<int>& vec) {
+std::string GetVecHash(const std::string& vec) {
 std::size_t ret = vec.size();
 for (auto i : vec) {
-ret ^= i + 0x9e3779b9 + (ret << 6) + (ret >> 2);
+ret ^= static_cast<int>(i) + 0x9e3779b9 + (ret << 6) + (ret >> 2);
 }
 return std::to_string(ret);
 }
 
+/*
+ * Seralize engine profile
+ * The profile contains min/max shape ranges of every dynamic shape dimension for each input tensor
+ * For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
+ * has one dynamic shape dimension: dim_1. The data in profile will be,
+ * key: tensor_a, value: dim_0 min_shape max_shape dim_2 min_shape max_shape
+ * key: tensor_b, value: dim_1 min_shape max_shape
+ */
+void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>>& shape_ranges) {
+// Serialize profile
+flexbuffers::Builder builder;
+auto profile_start = builder.StartMap();
+for (auto outer_it = shape_ranges.begin(); outer_it != shape_ranges.end(); ++outer_it) {
+builder.TypedVector(outer_it->first.c_str(), [&] {
+for (auto inner_it = outer_it->second.begin(); inner_it != outer_it->second.end(); ++inner_it) {
+builder.Int(inner_it->first);
+builder.Int(inner_it->second.first);
+builder.Int(inner_it->second.second);
+}
+});
+}
+builder.EndMap(profile_start);
+builder.Finish();
+
+// Save flexbuffer
+std::ofstream file(file_name, std::ios::binary | std::ios::out);
+auto buf = builder.GetBuffer();
+size_t size = builder.GetSize();
+file.write(reinterpret_cast<const char*>(&buf[0]), size);
+file.close();
+}
+
+// Deserialize engine profile
+std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>> DeserializeProfile(std::ifstream& infile) {
+// Load flexbuffer
+infile.seekg(0, std::ios::end);
+int length = infile.tellg();
+infile.seekg(0, std::ios::beg);
+std::unique_ptr<char[]> data{new char[length]};
+infile.read((char*)data.get(), length);
+infile.close();
+
+// Deserialize profile
+std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>> shape_ranges;
+auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap();
+auto keys = tensors_range_entries.Keys();
+auto values = tensors_range_entries.Values();
+for (size_t i = 0, end = keys.size(); i < end; ++i) {
+auto dim_range_vectors = values[i].AsTypedVector();
+std::unordered_map<int, std::pair<int64_t, int64_t>> inner_map;
+for (size_t j = 0, end = dim_range_vectors.size() / 3; j < end; ++j) {
+size_t idx = 3 * j;
+inner_map[dim_range_vectors[idx].AsInt64()] = std::make_pair(dim_range_vectors[idx + 1].AsInt64(), dim_range_vectors[idx + 2].AsInt64());
+}
+shape_ranges[keys[i].AsString().c_str()] = inner_map;
+}
+return shape_ranges;
+}
 } // namespace
 
 namespace google {
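The comment block above spells out the flexbuffers layout of the cached profile. The following is a minimal sketch (not part of the commit) of how a shape-range map round-trips through the SerializeProfile and DeserializeProfile helpers introduced in this hunk; the file name example.profile and the shape values are assumptions for illustration only.

#include <fstream>
#include <string>
#include <unordered_map>
#include <utility>

void ProfileRoundTripExample() {
  // tensor_a has two dynamic dimensions (dim_0 and dim_2), as in the comment above.
  std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>> shape_ranges;
  shape_ranges["tensor_a"][0] = {1, 8};    // dim_0: min 1, max 8 (assumed values)
  shape_ranges["tensor_a"][2] = {16, 64};  // dim_2: min 16, max 64 (assumed values)

  // Write the profile cache file, then read it back the way the provider does
  // when it validates a cached engine against the current input shapes.
  SerializeProfile("example.profile", shape_ranges);
  std::ifstream profile_file("example.profile", std::ios::binary | std::ios::in);
  auto restored = DeserializeProfile(profile_file);
  // restored now equals shape_ranges:
  //   restored["tensor_a"][0] == {1, 8} and restored["tensor_a"][2] == {16, 64}
}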
@@ -589,7 +649,7 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
 // Add non TensorRT nodes to the maps
 for (const auto& index : non_trt_node_index) {
 const auto& node = graph.GetNode(index);
-std::string node_name = node->Name();
+const std::string node_name = node->Name();
 if (node_to_index_map.find(node_name) == node_to_index_map.end()) {
 index_to_node_map[id] = node_name;
 node_to_index_map[node_name] = id++;
@@ -729,7 +789,8 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
 }
 const Provider_Graph& graph_body = func_body->Body();
-auto model = graph_body.CreateGraphViewer()->CreateModel(*GetLogger());
+auto graph_body_viewer = graph_body.CreateGraphViewer();
+auto model = graph_body_viewer->CreateModel(*GetLogger());
 auto model_proto = model->ToProto();
 
 *model_proto->mutable_graph() = *graph_body.ToGraphProto();
@@ -762,7 +823,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 
 // Initialize shape range for dynamic shape tensors
 bool has_dynamic_shape = false;
-std::vector<int> input_shapes;
 for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
 auto input = trt_network->getInput(i);
 const std::string& input_name = input->getName();
@@ -772,12 +832,9 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 // Shape tensor
 input_shape_ranges[input_name][0] = std::make_pair(INT_MAX, INT_MIN);
 has_dynamic_shape = true;
-for (int i = 0; i < nb_dims; i++)
-input_shapes.push_back(1); // dummy value
 } else {
 // Execution tensor
 for (int j = 0, end = nb_dims; j < end; ++j) {
-input_shapes.push_back(dims.d[j]); // could be neg.
 if (dims.d[j] == -1) {
 input_shape_ranges[input_name][j] = std::make_pair(INT_MAX, INT_MIN);
 has_dynamic_shape = true;
@@ -792,23 +849,25 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 trt_node_name_with_precision += "_fp16";
 LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled.";
 }
 
+int num_nodes = graph_body_viewer->NumberOfNodes();
+trt_node_name_with_precision += "_" + GetVecHash(trt_node_name_with_precision + std::to_string(num_nodes));
 
 // Build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will
 // be built at runtime
 tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine> trt_engine;
 tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext> trt_context;
 if (!has_dynamic_shape) {
-std::string trt_node_name_with_precision_shape = trt_node_name_with_precision + "_" + GetVecHash(input_shapes);
-std::string cached_path = GetEnginePath(engine_cache_path_, trt_node_name_with_precision_shape);
-std::ifstream plan_file(cached_path, std::ios::binary | std::ios::in);
-if (plan_file && engine_cache_enable_) {
-plan_file.seekg(0, std::ios::end);
-int engine_size = plan_file.tellg();
-plan_file.seekg(0, std::ios::beg);
+const std::string cache_path = GetCachePath(engine_cache_path_, trt_node_name_with_precision);
+const std::string engine_cache_path = cache_path + ".engine";
+std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+if (engine_cache_enable_ && engine_file) {
+engine_file.seekg(0, std::ios::end);
+int engine_size = engine_file.tellg();
+engine_file.seekg(0, std::ios::beg);
 std::unique_ptr<char[]> engine_buf{new char[engine_size]};
-plan_file.read((char*)engine_buf.get(), engine_size);
+engine_file.read((char*)engine_buf.get(), engine_size);
 trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
-LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + cached_path;
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
 } else {
 trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
 if (trt_engine == nullptr) {
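For orientation, the cache look-up above composes file names from the fused node name, a precision suffix, and a hash. The sketch below (not part of the commit) illustrates the resulting names with assumed values; the directory, node name, and hash digits are hypothetical.

#include <iostream>
#include <string>

int main() {
  // Assumed values for illustration only.
  const std::string cache_dir = "/tmp/trt_cache";  // ORT_TENSORRT_ENGINE_CACHE_PATH
  const std::string node_name_with_precision = "TRTKernel_graph_0_fp16_1234567890";  // name + precision + hash suffix
  const std::string cache_path = cache_dir + "/" + node_name_with_precision;  // cf. GetCachePath above
  std::cout << cache_path + ".engine" << "\n";   // serialized TensorRT engine
  std::cout << cache_path + ".profile" << "\n";  // serialized shape-range profile (written for dynamic-shape graphs)
  return 0;
}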
@@ -817,10 +876,10 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 }
 if (engine_cache_enable_) {
 nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
-std::ofstream file(cached_path, std::ios::binary | std::ios::out);
+std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
 file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
 serializedModel->destroy();
-LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + cached_path;
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
 }
 }
 trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
@@ -868,11 +927,12 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 NodeComputeInfo compute_info;
 compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
 std::unique_ptr<TensorrtFuncState> p = onnxruntime::make_unique<TensorrtFuncState>();
-*p = {context->allocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(),
-&engines_[context->node_name], &contexts_[context->node_name], builders_[context->node_name].get(),
-networks_[context->node_name].get(), input_info_[context->node_name], output_info_[context->node_name],
-input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_,
-&max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, engine_cache_path_, runtime_, allocator_};
+*p = {context->allocate_func, context->release_func, context->allocator_handle, &parsers_[context->node_name],
+&engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
+&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
+input_shape_ranges_[context->node_name], &tensorrt_mu_, &fp16_enable_, &max_workspace_size_,
+trt_node_name_with_precision, engine_cache_enable_, engine_cache_path_, runtime_,
+allocator_};
 *state = p.release();
 return 0;
 };
@@ -892,7 +952,7 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 const std::unordered_map<std::string, int>& output_indexes = (trt_state->output_info)[0];
 const std::unordered_map<std::string, int>& output_types = (trt_state->output_info)[1];
 auto& shape_ranges = trt_state->input_shape_ranges;
-auto trt_builder = trt_state->builder;
+auto trt_builder = trt_state->builder->get();
 auto trt_engine = trt_state->engine->get();
 auto trt_context = trt_state->context->get();
 auto alloc = trt_state->scratch_allocator;
@@ -902,9 +962,43 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 std::unordered_map<std::string, bool> dimension_update;
 std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
 nvinfer1::IOptimizationProfile* trt_profile = nullptr;
-std::vector<int> input_shapes;
+
+// Load serialized engine
+const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
+const std::string engine_cache_path = cache_path + ".engine";
+const std::string profile_cache_path = cache_path + ".profile";
+std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
+if (engine_file && profile_file && (trt_state->engine_cache_enable && trt_engine == nullptr)) {
+// Deserialize profile
+shape_ranges = DeserializeProfile(profile_file);
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
+// Deserialize engine
+trt_state->context->reset();
+trt_state->engine->reset();
+engine_file.seekg(0, std::ios::end);
+int engine_size = engine_file.tellg();
+engine_file.seekg(0, std::ios::beg);
+std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+engine_file.read((char*)engine_buf.get(), engine_size);
+auto runtime_ = trt_state->runtime;
+*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
+runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
+if (trt_state->engine->get() == nullptr) {
+return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
+}
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
+trt_engine = trt_state->engine->get();
+*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
+trt_state->engine->get()->createExecutionContext());
+if (trt_state->context->get() == nullptr) {
+return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
+}
+trt_context = trt_state->context->get();
+}
+
 for (int i = 0, end = num_inputs; i < end; ++i) {
-auto input = trt_state->network->getInput(i);
+auto input = trt_state->network->get()->getInput(i);
 const std::string& input_name = input->getName();
 nvinfer1::Dims dims = input->getDimensions();
 int nb_dims = dims.nbDims;
@@ -917,7 +1011,7 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 if (iter != input_indexes.end()) {
 input_index = iter->second;
 }
 
 const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
 auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
 const auto& tensor_shapes = ort.GetTensorShape(tensor_info);
@@ -934,7 +1028,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 int32_t* input = new int32_t[shape_size];
 CUDA_RETURN_IF_ERROR(cudaMemcpy(input, ort.GetTensorData<int32_t>(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost));
 for (int j = 0; j < shape_size; ++j) {
-input_shapes.push_back(input[j]);
 tensor_shape_values[input_name][j] = input[j];
 }
 delete[] input;
@@ -944,7 +1037,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 int64_t* input = new int64_t[shape_size];
 CUDA_RETURN_IF_ERROR(cudaMemcpy(input, ort.GetTensorData<int64_t>(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost));
 for (int j = 0; j < shape_size; ++j) {
-input_shapes.push_back(input[j]);
 tensor_shape_values[input_name][j] = static_cast<int32_t>(input[j]);
 }
 delete[] input;
@@ -971,7 +1063,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 if (tensor_shape_value < shape_range[j].first) {
 shape_range[j].first = tensor_shape_value;
 shapes_min[j] = tensor_shape_value;
-shapes_opt[j] = tensor_shape_value;
 dimension_update[input_name] = true;
 }
 // Update shape range upper bound
@@ -1005,7 +1096,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
 for (int j = 0, end = nb_dims; j < end; ++j) {
 const auto& tensor_shape = tensor_shapes[j];
-input_shapes.push_back(tensor_shape);
 if (shape_range.find(j) != shape_range.end()) {
 dims_min.d[j] = shape_range[j].first;
 dims_opt.d[j] = shape_range[j].second;
@@ -1015,7 +1105,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 if (tensor_shape < shape_range[j].first) {
 shape_range[j].first = tensor_shape;
 dims_min.d[j] = tensor_shape;
-dims_opt.d[j] = tensor_shape;
 dimension_update[input_name] = true;
 }
 // Update maximum dimension
@@ -1046,47 +1135,33 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 // Regenerate engine
 // Only one profile is generated, so no need to explicitly set optimization profile
 if (engine_update) {
-std::string trt_node_name_with_precision_shape = trt_state->trt_node_name_with_precision + "_" + GetVecHash(input_shapes);
-std::string cached_path = GetEnginePath(trt_state->engine_cache_path, trt_node_name_with_precision_shape);
-std::ifstream plan_file(cached_path, std::ios::binary | std::ios::in);
 trt_state->context->reset();
 trt_state->engine->reset();
-if (plan_file && trt_state->engine_cache_enable) {
-plan_file.seekg(0, std::ios::end);
-int engine_size = plan_file.tellg();
-plan_file.seekg(0, std::ios::beg);
-std::unique_ptr<char[]> engine_buf{new char[engine_size]};
-plan_file.read((char*)engine_buf.get(), engine_size);
-auto runtime_ = trt_state->runtime;
-*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
-runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
-if (trt_state->engine->get() == nullptr) {
-return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
-}
-LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + cached_path;
-trt_engine = trt_state->engine->get();
-} else {
-auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
-trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
-trt_config->addOptimizationProfile(trt_profile);
-if (*(trt_state->fp16_enable_ptr) && trt_builder->platformHasFastFp16()) {
-trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
-}
-*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
-trt_builder->buildEngineWithConfig(*trt_state->network, *trt_config));
-
-if (trt_state->engine->get() == nullptr) {
-return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
-}
-trt_engine = trt_state->engine->get();
-if (trt_state->engine_cache_enable) {
-nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
-std::ofstream file(cached_path, std::ios::binary | std::ios::out);
-file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
-serializedModel->destroy();
-LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + cached_path;
-}
+auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
+trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
+trt_config->addOptimizationProfile(trt_profile);
+if (*(trt_state->fp16_enable_ptr) && trt_builder->platformHasFastFp16()) {
+trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+}
+
+*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
+trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
+if (trt_state->engine->get() == nullptr) {
+return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
+}
+trt_engine = trt_state->engine->get();
+if (trt_state->engine_cache_enable) {
+// Serialize engine profile
+SerializeProfile(profile_cache_path, shape_ranges);
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
+// Serialize engine
+nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
+std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
+file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
+serializedModel->destroy();
+LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
+}
+
 *(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
 trt_state->engine->get()->createExecutionContext());
 if (trt_state->context->get() == nullptr) {
@@ -1359,7 +1434,6 @@ common::Status TensorrtExecutionProvider::Provider_Compile(const std::vector<onn
 
 node_compute_funcs.push_back(compute_info);
 }
 
 return Status::OK();
 }
-} // namespace onnxruntime
+} // namespace onnxruntime
@@ -25,7 +25,7 @@ class TensorrtLogger : public nvinfer1::ILogger {
 public:
 TensorrtLogger(Severity verbosity = Severity::kWARNING)
 : verbosity_(verbosity) {}
-void log(Severity severity, const char* msg) override {
+void log(Severity severity, const char* msg) noexcept override {
 if (severity <= verbosity_) {
 time_t rawtime = std::time(0);
 char buf[256];
@@ -69,11 +69,11 @@ struct TensorrtFuncState {
 AllocateFunc test_allocate_func = nullptr;
 DestroyFunc test_release_func = nullptr;
 AllocatorHandle allocator = nullptr;
-nvonnxparser::IParser* parser = nullptr;
+tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
 tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>* engine = nullptr;
 tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>* context = nullptr;
-nvinfer1::IBuilder* builder = nullptr;
-nvinfer1::INetworkDefinition* network = nullptr;
+tensorrt_ptr::unique_pointer<nvinfer1::IBuilder>* builder = nullptr;
+tensorrt_ptr::unique_pointer<nvinfer1::INetworkDefinition>* network = nullptr;
 std::vector<std::unordered_map<std::string, int>> input_info;
 std::vector<std::unordered_map<std::string, int>> output_info;
 std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>> input_shape_ranges;
@@ -147,4 +147,4 @@ class TensorrtExecutionProvider : public Provider_IExecutionProvider {
 AllocatorPtr allocator_;
 };
 
-} // namespace onnxruntime
+} // namespace onnxruntime
@@ -5,10 +5,8 @@
 #include "test/providers/provider_test_utils.h"
 #include "test/framework/test_utils.h"
 #include "gtest/gtest.h"
 
-#if 0 // TODO: Make this work with TensorRT as a shared library
 
 #include "core/providers/tensorrt/tensorrt_execution_provider.h"
 #include "test/util/include/default_providers.h"
+#include "test/util/include/scoped_env_vars.h"
 
 using namespace std;
 using namespace ONNX_NAMESPACE;
@@ -28,8 +26,116 @@ void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64
 ASSERT_EQ(expected_values, found);
 }
 
+TEST(TensorrtExecutionProviderTest, EngineCachingTest) {
+ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_TENSORRT_ENGINE_CACHE_ENABLE", {"1"}},}};
+onnxruntime::Model model("enginecachingtest", false, DefaultLoggingManager().DefaultLogger());
+auto& graph = model.MainGraph();
+std::vector<onnxruntime::NodeArg*> inputs;
+std::vector<onnxruntime::NodeArg*> outputs;
+
+// FLOAT tensor
+ONNX_NAMESPACE::TypeProto float_tensor;
+float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("sym1");
+float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("sym2");
+
+auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
+auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
+inputs.push_back(&input_arg_1);
+inputs.push_back(&input_arg_2);
+auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
+outputs.push_back(&output_arg);
+graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
+
+auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
+inputs.clear();
+inputs.push_back(&output_arg);
+inputs.push_back(&input_arg_3);
+auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
+outputs.clear();
+outputs.push_back(&output_arg_2);
+graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
+
+auto status = graph.Resolve();
+ASSERT_TRUE(status.IsOK());
+std::string model_file_name = "trt_execution_provider_enginecaching_test.onnx";
+status = onnxruntime::Model::Save(model, model_file_name);
+
+// First run with input shape {1, 3, 2}
+// TRT engine and profile will be created and cached
+// Data in profile,
+// X: 1, 3, 3, 2, 2, 2
+// Y: 1, 3, 3, 2, 2, 2
+// Z: 1, 3, 3, 2, 2, 2
+std::vector<int64_t> dims_mul_x = {1, 3, 2};
+std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+OrtValue ml_value_x;
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
+OrtValue ml_value_y;
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_y);
+OrtValue ml_value_z;
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_z);
+NameMLValMap feeds;
+feeds.insert(std::make_pair("X", ml_value_x));
+feeds.insert(std::make_pair("Y", ml_value_y));
+feeds.insert(std::make_pair("Z", ml_value_z));
+
+// prepare outputs
+std::vector<std::string> output_names;
+output_names.push_back("M");
+std::vector<OrtValue> fetches;
+
+// prepare expected inputs and outputs
+std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
+std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
+
+SessionOptions so;
+so.session_logid = "TensorrtExecutionProviderTest.EngineCachingTest";
+RunOptions run_options;
+run_options.run_tag = so.session_logid;
+InferenceSession session_object{so, GetEnvironment()};
+std::unique_ptr<IExecutionProvider> execution_provider = DefaultTensorrtExecutionProvider();
+EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+status = session_object.Load(model_file_name);
+ASSERT_TRUE(status.IsOK());
+status = session_object.Initialize();
+ASSERT_TRUE(status.IsOK());
+
+// Now run
+status = session_object.Run(run_options, feeds, output_names, &fetches);
+ASSERT_TRUE(status.IsOK());
+VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
+
+// Second run with input shape {1, 1, 6}
+// TRT engine and profile will be updated
+// Data in profile,
+// X: 1, 1, 3, 2, 2, 6
+// Y: 1, 1, 3, 2, 2, 6
+// Z: 1, 1, 3, 2, 2, 6
+dims_mul_x = {1, 1, 6};
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_y);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_z);
+feeds.clear();
+feeds.insert(std::make_pair("X", ml_value_x));
+feeds.insert(std::make_pair("Y", ml_value_y));
+feeds.insert(std::make_pair("Z", ml_value_z));
+
+// prepare outputs
+fetches.clear();
+
+// prepare expected inputs and outputs
+expected_dims_mul_m = {1, 1, 6};
+
+// Now run
+status = session_object.Run(run_options, feeds, output_names, &fetches);
+ASSERT_TRUE(status.IsOK());
+VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
+}
+
 TEST(TensorrtExecutionProviderTest, FunctionTest) {
-onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
+onnxruntime::Model model("functiontest", false, DefaultLoggingManager().DefaultLogger());
 auto& graph = model.MainGraph();
 std::vector<onnxruntime::NodeArg*> inputs;
 std::vector<onnxruntime::NodeArg*> outputs;
@@ -66,11 +172,11 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) {
 std::vector<int64_t> dims_mul_x = {1, 3, 2};
 std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
 OrtValue ml_value_x;
-CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
 OrtValue ml_value_y;
-CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_y);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_y);
 OrtValue ml_value_z;
-CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_z);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_z);
 NameMLValMap feeds;
 feeds.insert(std::make_pair("X", ml_value_x));
 feeds.insert(std::make_pair("Y", ml_value_y));
@@ -92,9 +198,8 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) {
 
 InferenceSession session_object{so, GetEnvironment()};
 
-TensorrtExecutionProviderInfo epi;
-epi.device_id = 0;
-EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());
+std::unique_ptr<IExecutionProvider> execution_provider = DefaultTensorrtExecutionProvider();
+EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
 
 status = session_object.Load(model_file_name);
 ASSERT_TRUE(status.IsOK());
@@ -108,7 +213,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) {
 }
 
 TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
-onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
+onnxruntime::Model model("nodeindexmappingtest", false, DefaultLoggingManager().DefaultLogger());
 auto& graph = model.MainGraph();
 std::vector<onnxruntime::NodeArg*> inputs;
 std::vector<onnxruntime::NodeArg*> outputs;
@@ -177,11 +282,11 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
 std::vector<int64_t> dims_mul_y = {1, 3, 2};
 std::vector<float> values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
 OrtValue ml_value_x;
-CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
+CreateMLValue<bool>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
 OrtValue ml_value_y;
-CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
 OrtValue ml_value_z;
-CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
+CreateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
 NameMLValMap feeds;
 feeds.insert(std::make_pair("X", ml_value_x));
 feeds.insert(std::make_pair("Y", ml_value_y));
@@ -206,9 +311,8 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
 
 InferenceSession session_object{so, GetEnvironment()};
 
-TensorrtExecutionProviderInfo epi;
-epi.device_id = 0;
-EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());
+std::unique_ptr<IExecutionProvider> execution_provider = DefaultTensorrtExecutionProvider();
+EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
 
 ASSERT_STATUS_OK(session_object.Load(model_file_name));
 ASSERT_STATUS_OK(session_object.Initialize());
@@ -220,7 +324,7 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
 }
 
 TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
-onnxruntime::Model model("graph_removecycleTest", false, DefaultLoggingManager().DefaultLogger());
+onnxruntime::Model model("removecycletest", false, DefaultLoggingManager().DefaultLogger());
 auto& graph = model.MainGraph();
 std::vector<onnxruntime::NodeArg*> inputs;
 std::vector<onnxruntime::NodeArg*> outputs;
@@ -291,11 +395,11 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
 std::vector<bool> values_mul_z = {true, false, true, false, true, false};
 
 OrtValue ml_value_x;
-CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
+CreateMLValue<bool>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
 OrtValue ml_value_y;
-CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
+CreateMLValue<bool>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
 OrtValue ml_value_z;
-CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
+CreateMLValue<bool>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
 NameMLValMap feeds;
 feeds.insert(std::make_pair("X", ml_value_x));
 feeds.insert(std::make_pair("Y", ml_value_y));
@@ -317,9 +421,8 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
 
 InferenceSession session_object{so, GetEnvironment()};
 
-TensorrtExecutionProviderInfo epi;
-epi.device_id = 0;
-EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());
+std::unique_ptr<IExecutionProvider> execution_provider = DefaultTensorrtExecutionProvider();
+EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
 
 ASSERT_STATUS_OK(session_object.Load(model_file_name));
 ASSERT_STATUS_OK(session_object.Initialize());
@@ -328,8 +431,5 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
 ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
 VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
 
 } // namespace test
 } // namespace onnxruntime
 
-#endif