From 443aeb851c7f941237ffecb04d36462216eaeaf3 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Fri, 12 Jan 2024 18:10:05 -0800 Subject: [PATCH] [TensorRT EP] Customizable engine cache prefix (#19083) ### Description Add new option `trt_engine_cache_prefix` to customize TRTEP engine cache prefix. i.e: - If user specifies `trt_engine_cache_prefix|FRCNN trt_engine_cache_enable|true` when running FRCNN model - the cache will be saved/loaded: `FRCNN_2068723788287043730_*_sm80.engine`. Engine profile follows same pattern. - If skipping this option, the engine will be saved/loaded: `TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_*_*_sm80.engine` as default case. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/16708 --------- Co-authored-by: Chi Lo Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> --- .../tensorrt/tensorrt_provider_options.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 31 +++++++++--- .../tensorrt/tensorrt_execution_provider.h | 4 ++ .../tensorrt_execution_provider_info.cc | 6 +++ .../tensorrt_execution_provider_info.h | 1 + .../tensorrt_execution_provider_utils.h | 47 +++++++++++++++++++ .../tensorrt/tensorrt_provider_factory.cc | 1 + .../core/session/provider_bridge_ort.cc | 2 + .../python/onnxruntime_pybind_state.cc | 9 +++- .../test/perftest/command_args_parser.cc | 1 + .../providers/tensorrt/tensorrt_basic_test.cc | 25 ++++++++++ 11 files changed, 120 insertions(+), 8 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index daa4089061..60196d0c80 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -49,4 +49,5 @@ struct OrtTensorRTProviderOptionsV2 { int trt_dump_ep_context_model{0}; // Dump EP context node model int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 1d4ead019d..aa02d8384a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1352,6 +1352,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv detailed_build_log_ = info.detailed_build_log; if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; + cache_prefix_ = info.engine_cache_prefix; } // use a more global cache if given if (timing_cache_enable_) { @@ -1463,6 +1464,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); + cache_prefix_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePrefix); if (!engine_cache_path.empty() && cache_path_.empty()) { cache_path_ = engine_cache_path; LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; @@ -1578,7 +1580,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv dla_core_ = 0; } - if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) { if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { if (!fs::create_directory(cache_path_)) { throw std::runtime_error("Failed to create directory " + cache_path_); @@ -1689,7 +1691,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_profile_min_shapes: " << profile_min_shapes << ", trt_profile_max_shapes: " << profile_max_shapes << ", trt_profile_opt_shapes: " << profile_opt_shapes - << ", trt_cuda_graph_enable: " << cuda_graph_enable_; + << ", trt_cuda_graph_enable: " << cuda_graph_enable_ + << ", trt_cache_prefix: " << cache_prefix_; } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -2026,7 +2029,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect bool has_control_flow_op = false; // Add node and node args - // If node output is also parent graph output, the output will be added to the + // If node output is also parent graph output, the output will be added to the // subgraph's output list std::vector subgraph_output_names; for (const auto& index : group.first) { @@ -2774,7 +2777,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; } - #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 if (build_heuristics_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); @@ -2831,7 +2833,16 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); + std::string cache_suffix = ""; + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision); + cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; + } else { + cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); + } const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; const std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; @@ -3072,7 +3083,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, - builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics}; + builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix}; *state = p.release(); return 0; }; @@ -3124,7 +3135,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity // Prepare cache name - const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + } else { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); + } const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; const std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 9b8798e0fc..401a8da119 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -49,6 +49,7 @@ static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE"; static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL"; static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE"; static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE"; +static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX"; // Old env variable for backward compatibility static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH"; } // namespace tensorrt_env_vars @@ -178,6 +179,8 @@ struct TensorrtFuncState { bool filter_tactic_sources = false; nvinfer1::TacticSources tactic_sources; bool cuda_graph_enable = 0; + std::string cache_prefix; + std::string cache_suffix; }; // Minimum information to construct kernel function state for direct engine load code path @@ -290,6 +293,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool force_timing_cache_match_ = false; bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; + std::string cache_prefix_; // The OrtAllocator object will be get during ep compute time // and should be kept for the lifetime of TRT EP object. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index f7820ac8a0..28f6e1720f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -26,6 +26,7 @@ constexpr const char* kDLACore = "trt_dla_core"; constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; constexpr const char* kEngineCachePath = "trt_engine_cache_path"; +constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix"; constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; @@ -81,6 +82,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) @@ -124,6 +126,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, + {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, @@ -155,6 +158,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; }; const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); + const std::string kEngineCachePrefix_ = empty_if_null(info.trt_engine_cache_prefix); const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); @@ -178,6 +182,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, + {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, @@ -267,6 +272,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); + trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix); trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 76223b7847..a133ef45af 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -54,6 +54,7 @@ struct TensorrtExecutionProviderInfo { bool dump_ep_context_model{false}; int ep_context_embed_mode{0}; bool ep_context_compute_capability_enable{1}; + std::string engine_cache_prefix{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index 07f6f8eb34..a8e3ae3ddf 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include #include @@ -695,4 +697,49 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +std::string join(const std::vector& vec, const std::string& delimiter) { + std::string result; + for (size_t i = 0; i < vec.size(); ++i) { + result += vec[i]; + if (i < vec.size() - 1) { + result += delimiter; + } + } + return result; +} + +/* + * Parse engine cache name suffix when user customizes prefix for engine cache name + * + * For example: + * When default subgraph name is "TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_189_189_fp16" + * This func will generate the suffix "2068723788287043730_189_fp16" + * + */ +std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) { + std::vector split_fused_node_name = split(fused_node_name, '_'); + if (split_fused_node_name.size() >= 3) { + // Get index of model hash from fused_node_name + std::string model_hash = split_fused_node_name[split_fused_node_name.size() - 3]; + size_t index = fused_node_name.find(model_hash); + // Parse suffix from trt_node_name_with_precision, as it has additional precision info + std::vector suffix_group = split(trt_node_name_with_precision.substr(index), '_'); + if (suffix_group.size() > 2) { + suffix_group.erase(suffix_group.begin() + 2); + } + return join(suffix_group, "_"); + } + return ""; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 0e29df72f0..62f124afbd 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -119,6 +119,7 @@ struct Tensorrt_Provider : Provider { info.dump_ep_context_model = options.trt_dump_ep_context_model != 0; info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0; + info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; return std::make_shared(info); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index b9fd79997a..45d8006e6b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1419,6 +1419,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_profile_max_shapes = ""; trt_options_converted.trt_profile_opt_shapes = ""; trt_options_converted.trt_cuda_graph_enable = 0; + trt_options_converted.trt_engine_cache_prefix = ""; return trt_options_converted; } @@ -1982,6 +1983,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor if (ptr != nullptr) { delete[] ptr->trt_int8_calibration_table_name; delete[] ptr->trt_engine_cache_path; + delete[] ptr->trt_engine_cache_prefix; delete[] ptr->trt_timing_cache_path; delete[] ptr->trt_engine_decryption_lib_path; delete[] ptr->trt_tactic_sources; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 06eb2afdf8..d2cd6140b8 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -475,7 +475,7 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) - std::string calibration_table, cache_path, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; + std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -572,6 +572,13 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a path string i.e. 'engine_cache'.\n"); } + } else if (option.first == "trt_engine_cache_prefix") { + if (!option.second.empty()) { + cache_prefix = option.second; + params.trt_engine_cache_prefix = cache_prefix.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n"); + } } else if (option.first == "trt_engine_decryption_enable") { if (option.second == "True" || option.second == "true") { params.trt_engine_decryption_enable = true; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index f8d6296d2d..f1b9f05a21 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -93,6 +93,7 @@ namespace perftest { "\t [TensorRT only] [trt_dump_subgraphs]: Dump TRT subgraph to onnx model.\n" "\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n" "\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n" + "\t [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n" "\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" "\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" "\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n" diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index d9f917f6d1..508739ae1d 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -122,6 +122,19 @@ void CreateBaseModel(std::string model_name, status = onnxruntime::Model::Save(model, model_name); } +bool HasCacheFileWithPrefix(const std::string& prefix) { + const std::filesystem::path current_dir = std::filesystem::current_path(); + for (const auto& entry : std::filesystem::directory_iterator(current_dir)) { + if (entry.is_regular_file()) { + std::string filename = entry.path().filename().string(); + if (filename.rfind(prefix, 0) == 0) { + return true; + } + } + } + return false; +} + void RunSession(InferenceSession& session_object, RunOptions& run_options, NameMLValMap& feeds, @@ -177,6 +190,7 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string OrtTensorRTProviderOptionsV2 params; params.trt_engine_cache_enable = 1; + params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -192,6 +206,9 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string // Y: 1, 3, 3, 2, 2, 2 // Z: 1, 3, 3, 2, 2, 2 RunSession(session_object, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + // Verify on cache with customized prefix + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); } void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string sess_log_id, bool has_non_zero_node = false) { @@ -227,6 +244,7 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string OrtTensorRTProviderOptionsV2 params; params.trt_engine_cache_enable = 1; + params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -253,6 +271,9 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string for (auto& th : threads) th.join(); + + // Verify on cache with customized prefix + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); } TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) { @@ -426,6 +447,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { */ params.trt_engine_cache_enable = 1; + params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -551,6 +573,9 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { status = session_object2.Run(run_options, feeds, output_names, &fetches); + // Verify on cache with customized prefix + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); + if (input_type.compare("static") == 0) { // Can't run inference since input shape changes but the engine is built with static input ASSERT_FALSE(status.IsOK());