mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix subgraph index issue in TRT (#14305)
Subgraph index in TRT engine name keeps increasing when multiple sessions are created for the same model, which causes TRT engine not being reused and new engine is created again. The issue is because trt_model_id_generator_ is defined globally. This PR made following changes and improvements, 1. Define subgraph index as local variable thus it won't be shared across sessions. 2. Decouple subgraph index from hash id generator 3. Call hash id generator once at the beginning of GetCapability since hash id is shared between TRT subgraphs and there is no need to call it for each subgraph fix https://github.com/microsoft/onnxruntime/issues/14269
This commit is contained in:
parent
6d60dc24fe
commit
49cfb56cc3
4 changed files with 120 additions and 166 deletions
|
|
@ -623,7 +623,7 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(SubGraphCollection_t su
|
|||
return number_of_trt_nodes == number_of_ort_nodes;
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const {
|
||||
std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph, const HashValue& model_hash, int subgraph_index) const {
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<size_t> node_set;
|
||||
node_set.reserve(graph_nodes_index.first.size());
|
||||
|
|
@ -742,12 +742,11 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
|
|||
}
|
||||
|
||||
// Generate unique kernel name for TRT subgraph
|
||||
HashValue model_hash = 0;
|
||||
int id = TRTGenerateModelId(graph, model_hash);
|
||||
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id);
|
||||
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index);
|
||||
auto meta_def = IndexedSubGraph_MetaDef::Create();
|
||||
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";
|
||||
meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + subgraph_id;
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT subgraph MetaDef name " + meta_def->name();
|
||||
|
||||
// Assign inputs and outputs to subgraph's meta_def
|
||||
for (const auto& input : inputs) {
|
||||
|
|
@ -945,7 +944,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
}
|
||||
|
||||
// Detect and remove cycles from supported node list
|
||||
bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph, bool remove_cycles) const {
|
||||
bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph, const HashValue& model_hash, bool remove_cycles) const {
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
bool trt_cycle = true, cycle_detected = false;
|
||||
while (trt_cycle) {
|
||||
|
|
@ -955,10 +954,11 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
|
|||
std::unordered_map<std::string, std::unordered_set<std::string>> input_to_nodes_map, node_to_outputs_map;
|
||||
std::unordered_set<size_t> non_trt_node_index(node_index.begin(), node_index.end());
|
||||
size_t id = 0;
|
||||
int subgraph_index = 0;
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
if (!group.first.empty()) {
|
||||
// Construct subgraph from node list
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph);
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index);
|
||||
|
||||
// Create node to inputs/outputs/index maps
|
||||
const auto& meta_def = sub_graph->GetMetaDef();
|
||||
|
|
@ -981,6 +981,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
|
|||
for (const auto& index : group.first) {
|
||||
non_trt_node_index.erase(node_index[index]);
|
||||
}
|
||||
subgraph_index++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1070,6 +1071,9 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
strcpy(model_path_, path_string.c_str());
|
||||
#endif
|
||||
|
||||
// Generate unique kernel name for TRT graph
|
||||
HashValue model_hash = TRTGenerateId(graph);
|
||||
|
||||
// Get supported node list from TensorRT parser
|
||||
const int number_of_ort_nodes = graph.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
|
|
@ -1124,7 +1128,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
}
|
||||
|
||||
// Detect and remove cycles from supported node list
|
||||
DetectTensorRTGraphCycles(supported_nodes_vector, graph);
|
||||
DetectTensorRTGraphCycles(supported_nodes_vector, graph, model_hash);
|
||||
|
||||
// Consolidate supported node list
|
||||
if (supported_nodes_vector.size() > 1) {
|
||||
|
|
@ -1135,7 +1139,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
}
|
||||
}
|
||||
SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}};
|
||||
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, false)) {
|
||||
if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation";
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] TensorRT nodes are consolidated into one subgraph";
|
||||
|
|
@ -1191,12 +1195,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
}
|
||||
}
|
||||
|
||||
int number_of_trt_nodes = 0;
|
||||
int number_of_trt_nodes = 0, subgraph_index = 0;
|
||||
for (const auto& group : supported_nodes_vector) {
|
||||
if (!group.first.empty()) {
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph);
|
||||
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index);
|
||||
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
|
||||
number_of_trt_nodes += static_cast<int>(group.first.size());
|
||||
subgraph_index++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1264,7 +1269,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast<nvinfer1::IElementWiseLayer*>(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow";
|
||||
layer->setPrecision(nvinfer1::DataType::kFLOAT);
|
||||
next_layer->setPrecision(nvinfer1::DataType::kFLOAT);
|
||||
next_layer->setPrecision(nvinfer1::DataType::kFLOAT);
|
||||
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
|
||||
next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@ static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"
|
|||
static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
|
||||
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
|
||||
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
|
||||
static const std::string kForceSequentialEngineBuild= "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD";
|
||||
static const std::string kContextMemorySharingEnable= "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE";
|
||||
static const std::string kLayerNormFP32Fallback= "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK";
|
||||
static const std::string kForceSequentialEngineBuild = "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD";
|
||||
static const std::string kContextMemorySharingEnable = "ORT_TENSORRT_CONTEXT_MEMORY_SHARING_ENABLE";
|
||||
static const std::string kLayerNormFP32Fallback = "ORT_TENSORRT_LAYER_NORM_FP32_FALLBACK";
|
||||
// Old env variable for backward compatibility
|
||||
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
|
||||
} // namespace tensorrt_env_vars
|
||||
|
|
@ -193,7 +193,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
|
||||
/**Get IndexedSubGraph based on node list of the subgraph*/
|
||||
std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index,
|
||||
const GraphViewer& graph) const;
|
||||
const GraphViewer& graph, const HashValue& model_hash, int subgraph_index) const;
|
||||
|
||||
/**
|
||||
Get TensorRT supported node lists by calling Onnx-TensorRT parser recursively. Since each time the parser
|
||||
|
|
@ -205,7 +205,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations,
|
||||
const GraphViewer& graph, bool* early_termination) const;
|
||||
|
||||
bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph, bool remove_cycles = true) const;
|
||||
bool DetectTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph, const HashValue& model_hash, bool remove_cycles = true) const;
|
||||
|
||||
/**
|
||||
Get a unique_lock object to control the concurrency behavior.
|
||||
|
|
|
|||
|
|
@ -28,23 +28,23 @@ float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) {
|
|||
}
|
||||
|
||||
/*
|
||||
* Read calibration table for INT8 quantization
|
||||
* Two kind of calibration tables are supported,
|
||||
* 1. ORT generated calibration table
|
||||
* The table is pre-serialized by flatbuffers.
|
||||
* Each entry in the table is a key-value pair,
|
||||
* key: tensor name, value: maximum absolute value in floating point
|
||||
* For example,
|
||||
* data_0 2.008338
|
||||
* ...
|
||||
* 2. Native TensorRT generated calibration table
|
||||
* Data format is defined by TensorRT as,
|
||||
* tensor name : scale in 32-bit single precision IEEE754 format
|
||||
* For example,
|
||||
* TRT-7103-EntropyCalibration2
|
||||
* data_0: 4000889d
|
||||
* ...
|
||||
*/
|
||||
* Read calibration table for INT8 quantization
|
||||
* Two kind of calibration tables are supported,
|
||||
* 1. ORT generated calibration table
|
||||
* The table is pre-serialized by flatbuffers.
|
||||
* Each entry in the table is a key-value pair,
|
||||
* key: tensor name, value: maximum absolute value in floating point
|
||||
* For example,
|
||||
* data_0 2.008338
|
||||
* ...
|
||||
* 2. Native TensorRT generated calibration table
|
||||
* Data format is defined by TensorRT as,
|
||||
* tensor name : scale in 32-bit single precision IEEE754 format
|
||||
* For example,
|
||||
* TRT-7103-EntropyCalibration2
|
||||
* data_0: 4000889d
|
||||
* ...
|
||||
*/
|
||||
bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration_table, std::unordered_map<std::string, float>& dynamic_range_map) {
|
||||
std::ifstream infile(file_name, std::ios::binary | std::ios::in);
|
||||
if (!infile) {
|
||||
|
|
@ -95,13 +95,13 @@ bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration
|
|||
}
|
||||
|
||||
/*
|
||||
* Seralize engine profile
|
||||
* The profile contains min/max shape ranges of dynamic shape dimensions of each input tensor
|
||||
* For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
|
||||
* has one dynamic shape dimension: dim_1. The data in profile will be,
|
||||
* key: tensor_a, value: dim_0 min_shape max_shape dim_2 min_shape max_shape
|
||||
* key: tensor_b, value: dim_1 min_shape max_shape
|
||||
*/
|
||||
* Seralize engine profile
|
||||
* The profile contains min/max shape ranges of dynamic shape dimensions of each input tensor
|
||||
* For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
|
||||
* has one dynamic shape dimension: dim_1. The data in profile will be,
|
||||
* key: tensor_a, value: dim_0 min_shape max_shape dim_2 min_shape max_shape
|
||||
* key: tensor_b, value: dim_1 min_shape max_shape
|
||||
*/
|
||||
void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>& shape_ranges) {
|
||||
// Serialize profile
|
||||
flexbuffers::Builder builder;
|
||||
|
|
@ -170,15 +170,15 @@ std::string GetCachePath(const std::string& root, const std::string& name) {
|
|||
/*
|
||||
* Get cache by type
|
||||
*
|
||||
* \param root root path of the cache
|
||||
* \param root root path of the cache
|
||||
* \param file_extension It could be ".engine", ".profile" or ".timing"
|
||||
*/
|
||||
*/
|
||||
std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_extension) {
|
||||
std::vector<fs::path> cache_files;
|
||||
for (const auto & entry : fs::directory_iterator(root)) {
|
||||
if (fs::path(file_extension) == fs::path(entry).extension()) {
|
||||
cache_files.push_back(fs::path(entry));
|
||||
}
|
||||
for (const auto& entry : fs::directory_iterator(root)) {
|
||||
if (fs::path(file_extension) == fs::path(entry).extension()) {
|
||||
cache_files.push_back(fs::path(entry));
|
||||
}
|
||||
}
|
||||
return cache_files;
|
||||
}
|
||||
|
|
@ -186,118 +186,106 @@ std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_
|
|||
bool IsCacheExistedByType(const std::string& root, std::string file_extension) {
|
||||
auto cache_files = GetCachesByType(root, file_extension);
|
||||
if (cache_files.size() == 0) {
|
||||
return false;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RemoveCachesByType(const std::string& root, std::string file_extension) {
|
||||
auto cache_files = GetCachesByType(root, file_extension);
|
||||
for (const auto & entry : cache_files) {
|
||||
for (const auto& entry : cache_files) {
|
||||
fs::remove(entry);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper class to generate engine id via model name/model content/env metadata
|
||||
class TRTModelIdGenerator {
|
||||
public:
|
||||
int TRTGenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) {
|
||||
model_hash = 0;
|
||||
HashValue TRTGenerateId(const GraphViewer& graph_viewer) {
|
||||
HashValue model_hash = 0;
|
||||
|
||||
// find the top level graph
|
||||
const Graph* cur_graph = &graph_viewer.GetGraph();
|
||||
while (cur_graph->IsSubgraph()) {
|
||||
cur_graph = cur_graph->ParentGraph();
|
||||
}
|
||||
// find the top level graph
|
||||
const Graph* cur_graph = &graph_viewer.GetGraph();
|
||||
while (cur_graph->IsSubgraph()) {
|
||||
cur_graph = cur_graph->ParentGraph();
|
||||
}
|
||||
|
||||
const Graph& main_graph = *cur_graph;
|
||||
uint32_t hash[4] = {0, 0, 0, 0};
|
||||
const Graph& main_graph = *cur_graph;
|
||||
uint32_t hash[4] = {0, 0, 0, 0};
|
||||
|
||||
auto hash_str = [&hash](const std::string& str) {
|
||||
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
|
||||
};
|
||||
auto hash_str = [&hash](const std::string& str) {
|
||||
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
|
||||
};
|
||||
|
||||
// Use model name instead of path to avoid cache regeneration if path changes
|
||||
const auto& model_path = main_graph.ModelPath();
|
||||
if (!model_path.IsEmpty()) {
|
||||
// Get model name
|
||||
PathString path_string = model_path.GetComponents().back();
|
||||
char arr[256];
|
||||
// Use model name instead of path to avoid cache regeneration if path changes
|
||||
const auto& model_path = main_graph.ModelPath();
|
||||
if (!model_path.IsEmpty()) {
|
||||
// Get model name
|
||||
PathString path_string = model_path.GetComponents().back();
|
||||
char arr[256];
|
||||
#ifdef _WIN32
|
||||
wcstombs_s(nullptr, arr, sizeof(arr), path_string.c_str(), sizeof(arr));
|
||||
wcstombs_s(nullptr, arr, sizeof(arr), path_string.c_str(), sizeof(arr));
|
||||
#else
|
||||
strcpy(arr, path_string.c_str());
|
||||
strcpy(arr, path_string.c_str());
|
||||
#endif
|
||||
std::string model_name(arr);
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name;
|
||||
// Ensure enough characters are hashed in case model names are too short
|
||||
int32_t model_name_length = gsl::narrow_cast<int32_t>(model_name.size());
|
||||
constexpr int32_t hash_string_length = 500;
|
||||
std::string repeat_model_name = model_name;
|
||||
for (int i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) {
|
||||
repeat_model_name += model_name;
|
||||
}
|
||||
hash_str(repeat_model_name);
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty";
|
||||
std::string model_name(arr);
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name;
|
||||
// Ensure enough characters are hashed in case model names are too short
|
||||
int32_t model_name_length = gsl::narrow_cast<int32_t>(model_name.size());
|
||||
constexpr int32_t hash_string_length = 500;
|
||||
std::string repeat_model_name = model_name;
|
||||
for (int i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) {
|
||||
repeat_model_name += model_name;
|
||||
}
|
||||
hash_str(repeat_model_name);
|
||||
} else {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty";
|
||||
}
|
||||
|
||||
// fingerprint the main graph by hashing graph inputs
|
||||
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
// fingerprint the main graph by hashing graph inputs
|
||||
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
|
||||
// hashing output of each node
|
||||
const int number_of_ort_nodes = graph_viewer.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
const std::vector<NodeIndex>& node_index = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto& index : nodes_vector) {
|
||||
const auto& node = graph_viewer.GetNode(node_index[index]);
|
||||
for (const auto* node_arg : node->OutputDefs()) {
|
||||
if (node_arg->Exists()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
// fingerprint current graph by hashing graph inputs
|
||||
for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
|
||||
// hashing output of each node
|
||||
const int number_of_ort_nodes = graph_viewer.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
const std::vector<NodeIndex>& node_index = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto& index : nodes_vector) {
|
||||
const auto& node = graph_viewer.GetNode(node_index[index]);
|
||||
for (const auto* node_arg : node->OutputDefs()) {
|
||||
if (node_arg->Exists()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
hash_str("LINUX");
|
||||
hash_str("LINUX");
|
||||
#elif defined(_WIN32)
|
||||
hash_str("WINDOWS");
|
||||
hash_str("WINDOWS");
|
||||
#endif
|
||||
|
||||
#ifdef ORT_VERSION
|
||||
hash_str(ORT_VERSION);
|
||||
hash_str(ORT_VERSION);
|
||||
#endif
|
||||
|
||||
#ifdef CUDA_VERSION
|
||||
hash_str(std::to_string(CUDA_VERSION));
|
||||
hash_str(std::to_string(CUDA_VERSION));
|
||||
#endif
|
||||
|
||||
#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR)
|
||||
std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR);
|
||||
hash_str(TRT_VERSION);
|
||||
std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR);
|
||||
hash_str(TRT_VERSION);
|
||||
#endif
|
||||
|
||||
model_hash = hash[0] | (uint64_t(hash[1]) << 32);
|
||||
model_hash = hash[0] | (uint64_t(hash[1]) << 32);
|
||||
|
||||
// return the current unique id, and increment to update
|
||||
return trt_model_id_[model_hash]++;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<HashValue, int> trt_model_id_; // current unique id for model
|
||||
};
|
||||
|
||||
std::unique_ptr<TRTModelIdGenerator> trt_model_id_generator_ = std::make_unique<TRTModelIdGenerator>();
|
||||
|
||||
// Calll TRTGenerateModelId to generate hash id for TRT engine cache
|
||||
int TRTGenerateModelId(const GraphViewer& graph_viewer, HashValue& model_hash) {
|
||||
// if the EP is shared across multiple sessions there's a very small potential for concurrency issues.
|
||||
// use a lock when generating an id to be paranoid
|
||||
static OrtMutex mutex;
|
||||
std::lock_guard<OrtMutex> lock(mutex);
|
||||
return trt_model_id_generator_->TRTGenerateId(graph_viewer, model_hash);
|
||||
}
|
||||
// return the current unique id
|
||||
return model_hash;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -287,9 +287,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
|
|||
GraphViewer viewer(graph);
|
||||
|
||||
// get the hash for the model when loaded from file
|
||||
HashValue model_hash;
|
||||
int id = TRTGenerateModelId(viewer, model_hash);
|
||||
ASSERT_EQ(id, 0);
|
||||
HashValue model_hash = TRTGenerateId(viewer);
|
||||
ASSERT_NE(model_hash, 0);
|
||||
|
||||
// now load the model from bytes and check the hash differs
|
||||
|
|
@ -301,14 +299,11 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
|
|||
ASSERT_STATUS_OK(Model::Load(std::move(model_proto), PathString(), model2, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
// Test loading same model from file and byte steam. Hash values should be different
|
||||
Graph& graph2 = model2->MainGraph();
|
||||
GraphViewer viewer2(graph2);
|
||||
|
||||
HashValue model_hash2;
|
||||
int id2 = TRTGenerateModelId(viewer2, model_hash2);
|
||||
|
||||
// test comparing model 1 & 2
|
||||
ASSERT_EQ(id2, 0) << "id2 should be 0";
|
||||
HashValue model_hash2= TRTGenerateId(viewer2);
|
||||
ASSERT_NE(model_hash, model_hash2);
|
||||
|
||||
// Test loading same model from different path, see if hash values are same as well
|
||||
model_path = ORT_TSTR("testdata/TRTEP_test_model/mnist.onnx");
|
||||
|
|
@ -316,42 +311,8 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
|
|||
ASSERT_TRUE(Model::Load(model_path, model3, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
Graph& graph3 = model3->MainGraph();
|
||||
GraphViewer viewer3(graph3);
|
||||
HashValue model_hash3;
|
||||
int id3 = TRTGenerateModelId(viewer3, model_hash3);
|
||||
HashValue model_hash3 = TRTGenerateId(viewer3);
|
||||
ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
|
||||
ASSERT_EQ(id3, 1) << "id3 should be 1 as model 1 & 3 have same hash";
|
||||
}
|
||||
|
||||
// Compare on TRT subgraph id when repeatedly calling TRTGenerateModelId
|
||||
TEST(TensorrtExecutionProviderTest, TRTSubgraphIdGeneratorUsingModelHashing) {
|
||||
// Load model
|
||||
auto model_path = ORT_TSTR("testdata/mnist.onnx");
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_path, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
Graph& main_graph = model->MainGraph();
|
||||
GraphViewer graph(main_graph);
|
||||
HashValue model_hash;
|
||||
|
||||
// Graph id acquired
|
||||
int graph_id = TRTGenerateModelId(graph, model_hash);
|
||||
int asserted_subgraph_id = graph_id + 1;
|
||||
|
||||
// mock fetching subgraphs and generate id by calling TRTGenerateModelId repeatedly
|
||||
const int number_of_ort_nodes = graph.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
|
||||
for (const auto& index : nodes_vector) {
|
||||
const auto& node = graph.GetNode(node_index[index]);
|
||||
std::cout << "->" << node->Name();
|
||||
|
||||
// Check if id increment each time TRTGenerateModelId is called
|
||||
int subgraph_id = TRTGenerateModelId(graph, model_hash);
|
||||
ASSERT_EQ(subgraph_id, asserted_subgraph_id) << "id will increment as TRTGenerateModelId is repeatedly called";
|
||||
asserted_subgraph_id++;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue