Fix issues in TRT model ID generator (#13837)

There are some issues in
https://github.com/microsoft/onnxruntime/pull/13015,
1. Model name should be used rather than graph name in the model ID
generator.
2. Hash collision is observed in ID cache, which means different model
may have the same key and thus load same hash id from the cache.
3. For the class and function that generate model id, MetaDef in the
name is not appropriate.
4. Should reuse murmurhash3 rather than copy it over to TRT EP
This PR fixes those issues.
This commit is contained in:
stevenlix 2022-12-15 13:51:19 -08:00 committed by GitHub
parent b52e8bf718
commit c4ecbb96d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 96 additions and 434 deletions

View file

@ -9,6 +9,7 @@
#include "core/providers/shared/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/murmurhash3.h"
#include "core/framework/random_generator.h"
#include "core/providers/cpu/controlflow/if.h"
#include "core/providers/cpu/controlflow/loop.h"
@ -319,6 +320,10 @@ std::unique_ptr<IAllocator> CreateCUDAPinnedAllocator(int16_t device_id, const c
std::unique_ptr<IDataTransfer> CreateGPUDataTransfer() {
return g_host->CreateGPUDataTransfer();
}
void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) {
return g_host->MurmurHash3__x86_128(key, len, seed, out);
}
#endif
#ifdef USE_MIGRAPHX

View file

@ -668,6 +668,7 @@ struct ProviderHost {
virtual const Node* Graph__ParentNode(const Graph* p) const = 0;
virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
virtual const Path& Graph__ModelPath(const Graph* p) const = 0;
virtual const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
virtual bool Graph__IsSubgraph(const Graph* p) = 0;
@ -703,6 +704,8 @@ struct ProviderHost {
// Path
virtual PathString Path__ToPathString(const Path* p) noexcept = 0;
virtual const std::vector<PathString>& Path__GetComponents(const Path* p) noexcept = 0;
virtual bool Path__IsEmpty(const Path* p) noexcept = 0;
// OpKernel
virtual const Node& OpKernel__Node(const OpKernel* p) = 0;
@ -879,8 +882,13 @@ struct ProviderHost {
virtual RandomGenerator& RandomGenerator__Default() = 0;
#endif
#if defined(USE_TENSORRT)
virtual void MurmurHash3__x86_128(const void* key, int len, uint32_t seed, void* out) = 0;
#endif
virtual ProviderHostCPU& GetProviderHostCPU() = 0;
};
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif

View file

@ -699,6 +699,7 @@ struct Graph final {
const Node* ParentNode() const { return g_host->Graph__ParentNode(this); }
const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
const Path& ModelPath() const { return g_host->Graph__ModelPath(this); }
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }
@ -745,6 +746,8 @@ struct GraphViewer final {
struct Path final {
PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); }
const std::vector<PathString>& GetComponents() const noexcept { return g_host->Path__GetComponents(this); }
bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); }
PROVIDER_DISALLOW_ALL(Path)
};

View file

@ -1,349 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "murmurhash3.h"
// Original source: https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.
// Note - The x86 and x64 versions do _not_ produce the same results, as the
// algorithms are optimized for their respective platforms. You can still
// compile and run any of them on any platform, but your performance with the
// non-native version will be less than optimal.
/* Modifications Copyright (c) Microsoft. */
#include "core/framework/endian.h"
//-----------------------------------------------------------------------------
// Platform-specific functions and macros
// Microsoft Visual Studio
#if defined(_MSC_VER)
#define FORCE_INLINE __forceinline
#include <stdlib.h>
#define ROTL32(x, y) _rotl(x, y)
#define ROTL64(x, y) _rotl64(x, y)
#define BIG_CONSTANT(x) (x)
// Other compilers
#else // defined(_MSC_VER)
#define FORCE_INLINE inline __attribute__((always_inline))
inline uint32_t rotl32(uint32_t x, int8_t r) {
return (x << r) | (x >> (32 - r));
}
inline uint64_t rotl64(uint64_t x, int8_t r) {
return (x << r) | (x >> (64 - r));
}
#define ROTL32(x, y) rotl32(x, y)
#define ROTL64(x, y) rotl64(x, y)
#define BIG_CONSTANT(x) (x##LLU)
#endif // !defined(_MSC_VER)
#include <cstddef>
//-----------------------------------------------------------------------------
// Block read - on little-endian machines this is a single load,
// while on big-endian or unknown machines the byte accesses should
// still get optimized into the most efficient instruction.
//
// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/
// were manually applied to original murmurhash3 source code.
FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) {
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
const uint8_t* c = (const uint8_t*)&p[i];
return (uint32_t)c[0] |
(uint32_t)c[1] << 8 |
(uint32_t)c[2] << 16 |
(uint32_t)c[3] << 24;
}
}
FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) {
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
const uint8_t* c = (const uint8_t*)&p[i];
return (uint64_t)c[0] |
(uint64_t)c[1] << 8 |
(uint64_t)c[2] << 16 |
(uint64_t)c[3] << 24 |
(uint64_t)c[4] << 32 |
(uint64_t)c[5] << 40 |
(uint64_t)c[6] << 48 |
(uint64_t)c[7] << 56;
}
}
//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche
FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) {
h ^= h >> 16;
h *= 0x85ebca6b;
h ^= h >> 13;
h *= 0xc2b2ae35;
h ^= h >> 16;
return h;
}
//----------
FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) {
k ^= k >> 33;
k *= BIG_CONSTANT(0xff51afd7ed558ccd);
k ^= k >> 33;
k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
k ^= k >> 33;
return k;
}
//-----------------------------------------------------------------------------
namespace onnxruntime {
void MurmurHash3::x86_32(const void* key, int len,
uint32_t seed, void* out) {
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 4;
uint32_t h1 = seed;
constexpr uint32_t c1 = 0xcc9e2d51;
constexpr uint32_t c2 = 0x1b873593;
//----------
// body
const uint32_t* blocks = (const uint32_t*)(data + static_cast<ptrdiff_t>(nblocks) * 4);
for (int i = -nblocks; i; i++) {
uint32_t k1 = getblock32(blocks, i);
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
h1 = ROTL32(h1, 13);
h1 = h1 * 5 + 0xe6546b64;
}
//----------
// tail
const uint8_t* tail = (const uint8_t*)(data + static_cast<ptrdiff_t>(nblocks) * 4);
uint32_t k1 = 0;
switch (len & 3) {
case 3:
k1 ^= tail[2] << 16;
[[fallthrough]];
case 2:
k1 ^= tail[1] << 8;
[[fallthrough]];
case 1:
k1 ^= tail[0];
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
h1 ^= len;
h1 = fmix32(h1);
*(uint32_t*)out = h1;
}
//-----------------------------------------------------------------------------
void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) {
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 16;
uint32_t h1 = seed;
uint32_t h2 = seed;
uint32_t h3 = seed;
uint32_t h4 = seed;
constexpr uint32_t c1 = 0x239b961b;
constexpr uint32_t c2 = 0xab0e9789;
constexpr uint32_t c3 = 0x38b34ae5;
constexpr uint32_t c4 = 0xa1e38b93;
//----------
// body
const uint32_t* blocks = (const uint32_t*)(data + static_cast<ptrdiff_t>(nblocks) * 16);
for (int i = -nblocks; i; i++) {
uint32_t k1 = getblock32(blocks, i * 4 + 0);
uint32_t k2 = getblock32(blocks, i * 4 + 1);
uint32_t k3 = getblock32(blocks, i * 4 + 2);
uint32_t k4 = getblock32(blocks, i * 4 + 3);
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
h1 = ROTL32(h1, 19);
h1 += h2;
h1 = h1 * 5 + 0x561ccd1b;
k2 *= c2;
k2 = ROTL32(k2, 16);
k2 *= c3;
h2 ^= k2;
h2 = ROTL32(h2, 17);
h2 += h3;
h2 = h2 * 5 + 0x0bcaa747;
k3 *= c3;
k3 = ROTL32(k3, 17);
k3 *= c4;
h3 ^= k3;
h3 = ROTL32(h3, 15);
h3 += h4;
h3 = h3 * 5 + 0x96cd1c35;
k4 *= c4;
k4 = ROTL32(k4, 18);
k4 *= c1;
h4 ^= k4;
h4 = ROTL32(h4, 13);
h4 += h1;
h4 = h4 * 5 + 0x32ac3b17;
}
//----------
// tail
const uint8_t* tail = (const uint8_t*)(data + static_cast<ptrdiff_t>(nblocks) * 16);
uint32_t k1 = 0;
uint32_t k2 = 0;
uint32_t k3 = 0;
uint32_t k4 = 0;
switch (len & 15) {
case 15:
k4 ^= tail[14] << 16;
[[fallthrough]];
case 14:
k4 ^= tail[13] << 8;
[[fallthrough]];
case 13:
k4 ^= tail[12] << 0;
k4 *= c4;
k4 = ROTL32(k4, 18);
k4 *= c1;
h4 ^= k4;
[[fallthrough]];
case 12:
k3 ^= tail[11] << 24;
[[fallthrough]];
case 11:
k3 ^= tail[10] << 16;
[[fallthrough]];
case 10:
k3 ^= tail[9] << 8;
[[fallthrough]];
case 9:
k3 ^= tail[8] << 0;
k3 *= c3;
k3 = ROTL32(k3, 17);
k3 *= c4;
h3 ^= k3;
[[fallthrough]];
case 8:
k2 ^= tail[7] << 24;
[[fallthrough]];
case 7:
k2 ^= tail[6] << 16;
[[fallthrough]];
case 6:
k2 ^= tail[5] << 8;
[[fallthrough]];
case 5:
k2 ^= tail[4] << 0;
k2 *= c2;
k2 = ROTL32(k2, 16);
k2 *= c3;
h2 ^= k2;
[[fallthrough]];
case 4:
k1 ^= tail[3] << 24;
[[fallthrough]];
case 3:
k1 ^= tail[2] << 16;
[[fallthrough]];
case 2:
k1 ^= tail[1] << 8;
[[fallthrough]];
case 1:
k1 ^= tail[0] << 0;
k1 *= c1;
k1 = ROTL32(k1, 15);
k1 *= c2;
h1 ^= k1;
};
//----------
// finalization
h1 ^= len;
h2 ^= len;
h3 ^= len;
h4 ^= len;
h1 += h2;
h1 += h3;
h1 += h4;
h2 += h1;
h3 += h1;
h4 += h1;
h1 = fmix32(h1);
h2 = fmix32(h2);
h3 = fmix32(h3);
h4 = fmix32(h4);
h1 += h2;
h1 += h3;
h1 += h4;
h2 += h1;
h3 += h1;
h4 += h1;
((uint32_t*)out)[0] = h1;
((uint32_t*)out)[1] = h2;
((uint32_t*)out)[2] = h3;
((uint32_t*)out)[3] = h4;
}
} // namespace onnxruntime

View file

@ -1,16 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
namespace onnxruntime {
struct MurmurHash3 {
// generate 32-bit hash from input and write to 'out'
static void x86_32(const void* key, int len, uint32_t seed, void* out);
// generate 128-bit hash from input and write to 'out'.
static void x86_128(const void* key, int len, uint32_t seed, void* out);
};
} // namespace onnxruntime

View file

@ -743,7 +743,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
// Generate unique kernel name for TRT subgraph
HashValue model_hash = 0;
int id = TRTGenerateMetaDefId(graph, model_hash);
int id = TRTGenerateModelId(graph, model_hash);
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id);
auto meta_def = IndexedSubGraph_MetaDef::Create();
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";

View file

@ -8,9 +8,9 @@
#include <experimental/filesystem>
#include "flatbuffers/idl.h"
#include "ort_trt_int8_cal_table.fbs.h"
#include "murmurhash3.h"
#include <NvInferVersion.h>
#include "core/providers/cuda/cuda_pch.h"
#include "core/framework/murmurhash3.h"
namespace fs = std::experimental::filesystem;
@ -199,7 +199,7 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) {
}
// Helper class to generate engine id via model name/model content/env metadata
class TRTModelMetadefIdGenerator {
class TRTModelIdGenerator {
public:
int TRTGenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) {
model_hash = 0;
@ -210,92 +210,94 @@ class TRTModelMetadefIdGenerator {
cur_graph = cur_graph->ParentGraph();
}
uint32_t instance_hash[4] = {0, 0, 0, 0};
const Graph& main_graph = *cur_graph;
uint32_t hash[4] = {0, 0, 0, 0};
// hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use
// the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique
// fingerprint for the instance that can use used as the key to the hash of the graph name/inputs&outputs/metadata.
MurmurHash3::x86_128(&main_graph, gsl::narrow_cast<int32_t>(sizeof(Graph)), instance_hash[0], &instance_hash);
HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32);
auto hash_str = [&hash](const std::string& str) {
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
};
// if we've already hashed this main graph instance use the cached value
auto entry = trt_main_graph_hash_.find(graph_instance_hash);
if (entry != trt_main_graph_hash_.cend()) {
model_hash = entry->second;
// Use model name instead of path to avoid cache regeneration if path changes
const auto& model_path = main_graph.ModelPath();
if (!model_path.IsEmpty()) {
// Get model name
PathString path_string = model_path.GetComponents().back();
char arr[256];
#ifdef _WIN32
wcstombs_s(nullptr, arr, sizeof(arr), path_string.c_str(), sizeof(arr));
#else
strcpy(arr, path_string.c_str());
#endif
std::string model_name(arr);
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model name is " << model_name;
// Ensure enough characters are hashed in case model names are too short
int32_t model_name_length = gsl::narrow_cast<int32_t>(model_name.size());
constexpr int32_t hash_string_length = 500;
std::string repeat_model_name = model_name;
for (int i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) {
repeat_model_name += model_name;
}
hash_str(repeat_model_name);
} else {
uint32_t hash[4] = {0, 0, 0, 0};
LOGS_DEFAULT(INFO) << "[TensorRT EP] Model path is empty";
}
// Use graph name instead of path to avoid cache regeneration if path changes
const auto& model_name_str = main_graph.Name();
if (!model_name_str.empty()) {
MurmurHash3::x86_128(model_name_str.data(), gsl::narrow_cast<int32_t>(model_name_str.size()), hash[0], &hash);
}
// fingerprint the main graph by hashing graph inputs
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
hash_str(node_arg->Name());
}
auto hash_str = [&hash](const std::string& str) {
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
};
// fingerprint the main graph by hashing graph inputs
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
hash_str(node_arg->Name());
}
// hashing output of each node
const int number_of_ort_nodes = graph_viewer.NumberOfNodes();
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
const std::vector<NodeIndex>& node_index = graph_viewer.GetNodesInTopologicalOrder();
for (const auto& index : nodes_vector) {
const auto& node = graph_viewer.GetNode(node_index[index]);
for (const auto* node_arg : node->OutputDefs()) {
if (node_arg->Exists()) {
hash_str(node_arg->Name());
}
// hashing output of each node
const int number_of_ort_nodes = graph_viewer.NumberOfNodes();
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
const std::vector<NodeIndex>& node_index = graph_viewer.GetNodesInTopologicalOrder();
for (const auto& index : nodes_vector) {
const auto& node = graph_viewer.GetNode(node_index[index]);
for (const auto* node_arg : node->OutputDefs()) {
if (node_arg->Exists()) {
hash_str(node_arg->Name());
}
}
}
#ifdef __linux__
hash_str("LINUX");
hash_str("LINUX");
#elif defined(_WIN32)
hash_str("WINDOWS");
hash_str("WINDOWS");
#endif
#ifdef ORT_VERSION
hash_str(ORT_VERSION);
hash_str(ORT_VERSION);
#endif
#ifdef CUDA_VERSION
hash_str(std::to_string(CUDA_VERSION));
hash_str(std::to_string(CUDA_VERSION));
#endif
#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR)
std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR);
hash_str(TRT_VERSION);
std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR);
hash_str(TRT_VERSION);
#endif
model_hash = hash[0] | (uint64_t(hash[1]) << 32);
trt_main_graph_hash_[graph_instance_hash] = model_hash;
}
model_hash = hash[0] | (uint64_t(hash[1]) << 32);
// return the current unique id, and increment to update
return trt_model_metadef_id_[model_hash]++;
return trt_model_id_[model_hash]++;
}
private:
std::unordered_map<HashValue, HashValue> trt_main_graph_hash_; // map graph instance hash to model contents hash
std::unordered_map<HashValue, int> trt_model_metadef_id_; // current unique id for model
std::unordered_map<HashValue, int> trt_model_id_; // current unique id for model
};
std::unique_ptr<TRTModelMetadefIdGenerator> trt_metadef_id_generator_ = std::make_unique<TRTModelMetadefIdGenerator>();
std::unique_ptr<TRTModelIdGenerator> trt_model_id_generator_ = std::make_unique<TRTModelIdGenerator>();
// Calll TRTGenerateMetaDefId to generate hash id for TRT engine cache
int TRTGenerateMetaDefId(const GraphViewer& graph_viewer, HashValue& model_hash) {
// Calll TRTGenerateModelId to generate hash id for TRT engine cache
int TRTGenerateModelId(const GraphViewer& graph_viewer, HashValue& model_hash) {
// if the EP is shared across multiple sessions there's a very small potential for concurrency issues.
// use a lock when generating an id to be paranoid
static OrtMutex mutex;
std::lock_guard<OrtMutex> lock(mutex);
return trt_metadef_id_generator_->TRTGenerateId(graph_viewer, model_hash);
return trt_model_id_generator_->TRTGenerateId(graph_viewer, model_hash);
}
}

View file

@ -28,6 +28,7 @@
#include "core/util/math.h"
#include "core/framework/sparse_utils.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/framework/murmurhash3.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/common/string_helper.h"
@ -761,6 +762,7 @@ struct ProviderHostImpl : ProviderHost {
const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); }
const Graph* Graph__ParentGraph(const Graph* p) const override { return p->ParentGraph(); }
const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); }
const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); }
const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); }
bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); }
@ -806,6 +808,8 @@ struct ProviderHostImpl : ProviderHost {
// Path (wrapped)
PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); }
const std::vector<PathString>& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); }
bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); }
// OpKernel (direct)
const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); }
@ -1012,6 +1016,12 @@ struct ProviderHostImpl : ProviderHost {
RandomGenerator& RandomGenerator__Default() override { return RandomGenerator::Default(); }
#endif
#if defined(USE_TENSORRT)
void MurmurHash3__x86_128(const void* key, int len, uint32_t seed, void* out) {
MurmurHash3::x86_128(key, len, seed, out);
}
#endif
ProviderHostCPU& GetProviderHostCPU() override { return onnxruntime::GetProviderHostCPU(); }
} provider_host_;
#if defined(_MSC_VER) && !defined(__clang__)

View file

@ -277,7 +277,7 @@ TEST(TensorrtExecutionProviderTest, SessionCreationWithSingleThreadAndInferenceW
}
// Test loading same model in different way, when hash id is generated via model name/model content/env metadata
TEST(TensorrtExecutionProviderTest, TRTMetadefIdGeneratorUsingModelHashing) {
TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) {
auto model_path = ORT_TSTR("testdata/mnist.onnx");
std::shared_ptr<Model> model;
@ -288,7 +288,7 @@ TEST(TensorrtExecutionProviderTest, TRTMetadefIdGeneratorUsingModelHashing) {
// get the hash for the model when loaded from file
HashValue model_hash;
int id = TRTGenerateMetaDefId(viewer, model_hash);
int id = TRTGenerateModelId(viewer, model_hash);
ASSERT_EQ(id, 0);
ASSERT_NE(model_hash, 0);
@ -305,11 +305,10 @@ TEST(TensorrtExecutionProviderTest, TRTMetadefIdGeneratorUsingModelHashing) {
GraphViewer viewer2(graph2);
HashValue model_hash2;
int id2 = TRTGenerateMetaDefId(viewer2, model_hash2);
int id2 = TRTGenerateModelId(viewer2, model_hash2);
// test comparing model 1 & 2
ASSERT_EQ(model_hash, model_hash2) << "model1 has same graph name/nodes/env metadata as model2";
ASSERT_EQ(id2, 1) << "id2 should be 1 as model 1 & 2 have same hash";
ASSERT_EQ(id2, 0) << "id2 should be 0";
// Test loading same model from different path, see if hash values are same as well
model_path = ORT_TSTR("testdata/TRTEP_test_model/mnist.onnx");
@ -318,12 +317,12 @@ TEST(TensorrtExecutionProviderTest, TRTMetadefIdGeneratorUsingModelHashing) {
Graph& graph3 = model3->MainGraph();
GraphViewer viewer3(graph3);
HashValue model_hash3;
int id3 = TRTGenerateMetaDefId(viewer3, model_hash3);
int id3 = TRTGenerateModelId(viewer3, model_hash3);
ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
ASSERT_EQ(id3, 2) << "id3 should be 2 as model 1 & 2 & 3 have same hash";
ASSERT_EQ(id3, 1) << "id3 should be 1 as model 1 & 3 have same hash";
}
// Compare on TRT subgraph id when repeatedly calling TRTGenerateMetaDefId
// Compare on TRT subgraph id when repeatedly calling TRTGenerateModelId
TEST(TensorrtExecutionProviderTest, TRTSubgraphIdGeneratorUsingModelHashing) {
// Load model
auto model_path = ORT_TSTR("testdata/mnist.onnx");
@ -335,10 +334,10 @@ TEST(TensorrtExecutionProviderTest, TRTSubgraphIdGeneratorUsingModelHashing) {
HashValue model_hash;
// Graph id acquired
int graph_id = TRTGenerateMetaDefId(graph, model_hash);
int graph_id = TRTGenerateModelId(graph, model_hash);
int asserted_subgraph_id = graph_id + 1;
// mock fetching subgraphs and generate id by calling TRTGenerateMetaDefId repeatedly
// mock fetching subgraphs and generate id by calling TRTGenerateModelId repeatedly
const int number_of_ort_nodes = graph.NumberOfNodes();
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
@ -348,9 +347,9 @@ TEST(TensorrtExecutionProviderTest, TRTSubgraphIdGeneratorUsingModelHashing) {
const auto& node = graph.GetNode(node_index[index]);
std::cout << "->" << node->Name();
// Check if id increment each time TRTGenerateMetaDefId is called
int subgraph_id = TRTGenerateMetaDefId(graph, model_hash);
ASSERT_EQ(subgraph_id, asserted_subgraph_id) << "id will increment as TRTGenerateMetaDefId is repeatedly called";
// Check if id increment each time TRTGenerateModelId is called
int subgraph_id = TRTGenerateModelId(graph, model_hash);
ASSERT_EQ(subgraph_id, asserted_subgraph_id) << "id will increment as TRTGenerateModelId is repeatedly called";
asserted_subgraph_id++;
}
}