mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-18 01:54:05 +00:00
[TensorRT EP] Customizable engine cache prefix (#19083)
### Description <!-- Describe your changes. --> Add new option `trt_engine_cache_prefix` to customize TRTEP engine cache prefix. i.e: - If user specifies `trt_engine_cache_prefix|FRCNN trt_engine_cache_enable|true` when running FRCNN model - the cache will be saved/loaded: `FRCNN_2068723788287043730_*_sm80.engine`. Engine profile follows same pattern. - If skipping this option, the engine will be saved/loaded: `TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_*_*_sm80.engine` as default case. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/16708 --------- Co-authored-by: Chi Lo <Chi.Lo@microsoft.com> Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
This commit is contained in:
parent
150c4cb8fe
commit
443aeb851c
11 changed files with 120 additions and 8 deletions
|
|
@ -49,4 +49,5 @@ struct OrtTensorRTProviderOptionsV2 {
|
|||
int trt_dump_ep_context_model{0}; // Dump EP context node model
|
||||
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
|
||||
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
|
||||
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1352,6 +1352,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
detailed_build_log_ = info.detailed_build_log;
|
||||
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
|
||||
cache_path_ = info.engine_cache_path;
|
||||
cache_prefix_ = info.engine_cache_prefix;
|
||||
}
|
||||
// use a more global cache if given
|
||||
if (timing_cache_enable_) {
|
||||
|
|
@ -1463,6 +1464,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
|
||||
const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath);
|
||||
cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath);
|
||||
cache_prefix_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePrefix);
|
||||
if (!engine_cache_path.empty() && cache_path_.empty()) {
|
||||
cache_path_ = engine_cache_path;
|
||||
LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path";
|
||||
|
|
@ -1578,7 +1580,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
dla_core_ = 0;
|
||||
}
|
||||
|
||||
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
|
||||
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) {
|
||||
if (!cache_path_.empty() && !fs::is_directory(cache_path_)) {
|
||||
if (!fs::create_directory(cache_path_)) {
|
||||
throw std::runtime_error("Failed to create directory " + cache_path_);
|
||||
|
|
@ -1689,7 +1691,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
<< ", trt_profile_min_shapes: " << profile_min_shapes
|
||||
<< ", trt_profile_max_shapes: " << profile_max_shapes
|
||||
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
|
||||
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_;
|
||||
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_
|
||||
<< ", trt_cache_prefix: " << cache_prefix_;
|
||||
}
|
||||
|
||||
TensorrtExecutionProvider::~TensorrtExecutionProvider() {
|
||||
|
|
@ -2026,7 +2029,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
bool has_control_flow_op = false;
|
||||
|
||||
// Add node and node args
|
||||
// If node output is also parent graph output, the output will be added to the
|
||||
// If node output is also parent graph output, the output will be added to the
|
||||
// subgraph's output list
|
||||
std::vector<std::string> subgraph_output_names;
|
||||
for (const auto& index : group.first) {
|
||||
|
|
@ -2774,7 +2777,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
|
|||
trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
|
||||
}
|
||||
|
||||
#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
|
||||
if (build_heuristics_enable_) {
|
||||
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
|
||||
|
|
@ -2831,7 +2833,16 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
|
|||
|
||||
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
|
||||
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
|
||||
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
|
||||
std::string cache_suffix = "";
|
||||
std::string cache_path = "";
|
||||
// Customize cache prefix if assigned
|
||||
if (!cache_prefix_.empty()) {
|
||||
// Generate cache suffix in case user would like to customize cache prefix
|
||||
cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision);
|
||||
cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix;
|
||||
} else {
|
||||
cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
|
||||
}
|
||||
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
|
||||
const std::string engine_cache_path = cache_path_prefix + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
|
|
@ -3072,7 +3083,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
|
|||
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
|
||||
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
|
||||
global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_,
|
||||
builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics};
|
||||
builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix};
|
||||
*state = p.release();
|
||||
return 0;
|
||||
};
|
||||
|
|
@ -3124,7 +3135,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
|
|||
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
|
||||
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
|
||||
// Prepare cache name
|
||||
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
|
||||
std::string cache_path = "";
|
||||
// Customize cache prefix if assigned
|
||||
if (!cache_prefix_.empty()) {
|
||||
cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix;
|
||||
} else {
|
||||
cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
|
||||
}
|
||||
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
|
||||
const std::string engine_cache_path = cache_path_prefix + ".engine";
|
||||
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
|
|||
static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
|
||||
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
|
||||
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
|
||||
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
|
||||
// Old env variable for backward compatibility
|
||||
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
|
||||
} // namespace tensorrt_env_vars
|
||||
|
|
@ -178,6 +179,8 @@ struct TensorrtFuncState {
|
|||
bool filter_tactic_sources = false;
|
||||
nvinfer1::TacticSources tactic_sources;
|
||||
bool cuda_graph_enable = 0;
|
||||
std::string cache_prefix;
|
||||
std::string cache_suffix;
|
||||
};
|
||||
|
||||
// Minimum information to construct kernel function state for direct engine load code path
|
||||
|
|
@ -290,6 +293,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
bool force_timing_cache_match_ = false;
|
||||
bool detailed_build_log_ = false;
|
||||
bool cuda_graph_enable_ = false;
|
||||
std::string cache_prefix_;
|
||||
|
||||
// The OrtAllocator object will be get during ep compute time
|
||||
// and should be kept for the lifetime of TRT EP object.
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ constexpr const char* kDLACore = "trt_dla_core";
|
|||
constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs";
|
||||
constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable";
|
||||
constexpr const char* kEngineCachePath = "trt_engine_cache_path";
|
||||
constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix";
|
||||
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
|
||||
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
|
||||
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
|
||||
|
|
@ -81,6 +82,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
|
|||
.AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
|
||||
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
|
||||
|
|
@ -124,6 +126,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
|
|||
{tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)},
|
||||
{tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)},
|
||||
{tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)},
|
||||
{tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)},
|
||||
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
|
||||
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
|
||||
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
|
||||
|
|
@ -155,6 +158,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
|
|||
auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; };
|
||||
const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name);
|
||||
const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path);
|
||||
const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix);
|
||||
const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path);
|
||||
const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources);
|
||||
const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path);
|
||||
|
|
@ -178,6 +182,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
|
|||
{tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)},
|
||||
{tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)},
|
||||
{tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_},
|
||||
{tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_},
|
||||
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)},
|
||||
{tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_},
|
||||
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)},
|
||||
|
|
@ -267,6 +272,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
|
|||
trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable;
|
||||
|
||||
trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path);
|
||||
trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix);
|
||||
trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path);
|
||||
|
||||
trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable;
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ struct TensorrtExecutionProviderInfo {
|
|||
bool dump_ep_context_model{false};
|
||||
int ep_context_embed_mode{0};
|
||||
bool ep_context_compute_capability_enable{1};
|
||||
std::string engine_cache_prefix{""};
|
||||
|
||||
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#include <fstream>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <filesystem>
|
||||
#include <experimental/filesystem>
|
||||
|
|
@ -695,4 +697,49 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<st
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::string> split(const std::string& str, char delimiter) {
|
||||
std::vector<std::string> tokens;
|
||||
std::string token;
|
||||
std::istringstream tokenStream(str);
|
||||
while (std::getline(tokenStream, token, delimiter)) {
|
||||
tokens.push_back(token);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
std::string join(const std::vector<std::string>& vec, const std::string& delimiter) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
result += vec[i];
|
||||
if (i < vec.size() - 1) {
|
||||
result += delimiter;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Parse engine cache name suffix when user customizes prefix for engine cache name
|
||||
*
|
||||
* For example:
|
||||
* When default subgraph name is "TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_189_189_fp16"
|
||||
* This func will generate the suffix "2068723788287043730_189_fp16"
|
||||
*
|
||||
*/
|
||||
std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) {
|
||||
std::vector<std::string> split_fused_node_name = split(fused_node_name, '_');
|
||||
if (split_fused_node_name.size() >= 3) {
|
||||
// Get index of model hash from fused_node_name
|
||||
std::string model_hash = split_fused_node_name[split_fused_node_name.size() - 3];
|
||||
size_t index = fused_node_name.find(model_hash);
|
||||
// Parse suffix from trt_node_name_with_precision, as it has additional precision info
|
||||
std::vector<std::string> suffix_group = split(trt_node_name_with_precision.substr(index), '_');
|
||||
if (suffix_group.size() > 2) {
|
||||
suffix_group.erase(suffix_group.begin() + 2);
|
||||
}
|
||||
return join(suffix_group, "_");
|
||||
}
|
||||
return "";
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ struct Tensorrt_Provider : Provider {
|
|||
info.dump_ep_context_model = options.trt_dump_ep_context_model != 0;
|
||||
info.ep_context_embed_mode = options.trt_ep_context_embed_mode;
|
||||
info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0;
|
||||
info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix;
|
||||
|
||||
return std::make_shared<TensorrtProviderFactory>(info);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1419,6 +1419,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
|
|||
trt_options_converted.trt_profile_max_shapes = "";
|
||||
trt_options_converted.trt_profile_opt_shapes = "";
|
||||
trt_options_converted.trt_cuda_graph_enable = 0;
|
||||
trt_options_converted.trt_engine_cache_prefix = "";
|
||||
|
||||
return trt_options_converted;
|
||||
}
|
||||
|
|
@ -1982,6 +1983,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
|
|||
if (ptr != nullptr) {
|
||||
delete[] ptr->trt_int8_calibration_table_name;
|
||||
delete[] ptr->trt_engine_cache_path;
|
||||
delete[] ptr->trt_engine_cache_prefix;
|
||||
delete[] ptr->trt_timing_cache_path;
|
||||
delete[] ptr->trt_engine_decryption_lib_path;
|
||||
delete[] ptr->trt_tactic_sources;
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
// So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance.
|
||||
// (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance
|
||||
// and TRT EP instance, so it won't be released.)
|
||||
std::string calibration_table, cache_path, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile;
|
||||
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile;
|
||||
auto it = provider_options_map.find(type);
|
||||
if (it != provider_options_map.end()) {
|
||||
OrtTensorRTProviderOptionsV2 params;
|
||||
|
|
@ -572,6 +572,13 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
} else {
|
||||
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a path string i.e. 'engine_cache'.\n");
|
||||
}
|
||||
} else if (option.first == "trt_engine_cache_prefix") {
|
||||
if (!option.second.empty()) {
|
||||
cache_prefix = option.second;
|
||||
params.trt_engine_cache_prefix = cache_prefix.c_str();
|
||||
} else {
|
||||
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n");
|
||||
}
|
||||
} else if (option.first == "trt_engine_decryption_enable") {
|
||||
if (option.second == "True" || option.second == "true") {
|
||||
params.trt_engine_decryption_enable = true;
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ namespace perftest {
|
|||
"\t [TensorRT only] [trt_dump_subgraphs]: Dump TRT subgraph to onnx model.\n"
|
||||
"\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n"
|
||||
"\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n"
|
||||
"\t [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n"
|
||||
"\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n"
|
||||
"\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n"
|
||||
"\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n"
|
||||
|
|
|
|||
|
|
@ -122,6 +122,19 @@ void CreateBaseModel(std::string model_name,
|
|||
status = onnxruntime::Model::Save(model, model_name);
|
||||
}
|
||||
|
||||
bool HasCacheFileWithPrefix(const std::string& prefix) {
|
||||
const std::filesystem::path current_dir = std::filesystem::current_path();
|
||||
for (const auto& entry : std::filesystem::directory_iterator(current_dir)) {
|
||||
if (entry.is_regular_file()) {
|
||||
std::string filename = entry.path().filename().string();
|
||||
if (filename.rfind(prefix, 0) == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RunSession(InferenceSession& session_object,
|
||||
RunOptions& run_options,
|
||||
NameMLValMap& feeds,
|
||||
|
|
@ -177,6 +190,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
|
|||
|
||||
OrtTensorRTProviderOptionsV2 params;
|
||||
params.trt_engine_cache_enable = 1;
|
||||
params.trt_engine_cache_prefix = "TRTEP_Cache_Test";
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(¶ms);
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
auto status = session_object.Load(model_name);
|
||||
|
|
@ -192,6 +206,9 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string
|
|||
// Y: 1, 3, 3, 2, 2, 2
|
||||
// Z: 1, 3, 3, 2, 2, 2
|
||||
RunSession(session_object, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
|
||||
|
||||
// Verify on cache with customized prefix
|
||||
ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix));
|
||||
}
|
||||
|
||||
void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string sess_log_id, bool has_non_zero_node = false) {
|
||||
|
|
@ -227,6 +244,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
|
|||
|
||||
OrtTensorRTProviderOptionsV2 params;
|
||||
params.trt_engine_cache_enable = 1;
|
||||
params.trt_engine_cache_prefix = "TRTEP_Cache_Test";
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(¶ms);
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
auto status = session_object.Load(model_name);
|
||||
|
|
@ -253,6 +271,9 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string
|
|||
|
||||
for (auto& th : threads)
|
||||
th.join();
|
||||
|
||||
// Verify on cache with customized prefix
|
||||
ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix));
|
||||
}
|
||||
|
||||
TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
|
||||
|
|
@ -426,6 +447,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
|||
*/
|
||||
|
||||
params.trt_engine_cache_enable = 1;
|
||||
params.trt_engine_cache_prefix = "TRTEP_Cache_Test";
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(¶ms);
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
auto status = session_object.Load(model_name);
|
||||
|
|
@ -551,6 +573,9 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
|||
|
||||
status = session_object2.Run(run_options, feeds, output_names, &fetches);
|
||||
|
||||
// Verify on cache with customized prefix
|
||||
ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix));
|
||||
|
||||
if (input_type.compare("static") == 0) {
|
||||
// Can't run inference since input shape changes but the engine is built with static input
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
|
|
|
|||
Loading…
Reference in a new issue