mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Update engine hash id generator with model name/model content/metadata (#13015)
**Update engine hash id generator with model name/model
content/metadata**
**Description**:
* Updated engine id generator, which use model name/model inputs &
outputs/env metadata (instead of model path) to generate hash
* New bridged API were introduced in order to enable id generator in the
TRTEP utility
**Motivation and Context**
- Why is this change required? What problem does it solve? To fix this
[issue](https://github.com/triton-inference-server/server/issues/4587)
caused by id generator using model path
How to use:
* Call [TRTGenerateMetaDefId(const GraphViewer& graph_viewer, HashValue&
model_hash)](0fcce74a56/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc (L715))
to generate hash id for TRT engine cache
How to test:
* On WIndows, run:
* .\onnxruntime_test_all.exe
--gtest_filter=TensorrtExecutionProviderTest.TRTMetadefIdGeneratorUsingModelHashing
* .\onnxruntime_test_all.exe
--gtest_filter=TensorrtExecutionProviderTest.TRTSubgraphIdGeneratorUsingModelHashing
**Appendix**
* [Existing engine id generator that uses model
path](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/execution_provider.cc#L112-L182)
This commit is contained in:
parent
39e20686a0
commit
240aeadf1a
9 changed files with 561 additions and 1 deletions
|
|
@ -663,6 +663,10 @@ struct ProviderHost {
|
|||
virtual bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) = 0;
|
||||
|
||||
virtual const Node* Graph__ParentNode(const Graph* p) const = 0;
|
||||
virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
|
||||
virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
|
||||
virtual const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
|
||||
virtual bool Graph__IsSubgraph(const Graph* p) = 0;
|
||||
|
||||
// GraphViewer
|
||||
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
|
||||
|
|
|
|||
|
|
@ -695,6 +695,10 @@ struct Graph final {
|
|||
bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { return g_host->Graph__GetInitializedTensor(this, tensor_name, value); }
|
||||
|
||||
const Node* ParentNode() const { return g_host->Graph__ParentNode(this); }
|
||||
const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
|
||||
const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
|
||||
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
|
||||
bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(Graph)
|
||||
};
|
||||
|
|
|
|||
349
onnxruntime/core/providers/tensorrt/murmurhash3.cc
Normal file
349
onnxruntime/core/providers/tensorrt/murmurhash3.cc
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "murmurhash3.h"
|
||||
|
||||
// Original source: https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
|
||||
//-----------------------------------------------------------------------------
|
||||
// MurmurHash3 was written by Austin Appleby, and is placed in the public
|
||||
// domain. The author hereby disclaims copyright to this source code.
|
||||
|
||||
// Note - The x86 and x64 versions do _not_ produce the same results, as the
|
||||
// algorithms are optimized for their respective platforms. You can still
|
||||
// compile and run any of them on any platform, but your performance with the
|
||||
// non-native version will be less than optimal.
|
||||
|
||||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
#include "core/framework/endian.h"
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Platform-specific functions and macros
|
||||
|
||||
// Microsoft Visual Studio
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
#define FORCE_INLINE __forceinline
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#define ROTL32(x, y) _rotl(x, y)
|
||||
#define ROTL64(x, y) _rotl64(x, y)
|
||||
|
||||
#define BIG_CONSTANT(x) (x)
|
||||
|
||||
// Other compilers
|
||||
|
||||
#else // defined(_MSC_VER)
|
||||
|
||||
#define FORCE_INLINE inline __attribute__((always_inline))
|
||||
|
||||
inline uint32_t rotl32(uint32_t x, int8_t r) {
|
||||
return (x << r) | (x >> (32 - r));
|
||||
}
|
||||
|
||||
inline uint64_t rotl64(uint64_t x, int8_t r) {
|
||||
return (x << r) | (x >> (64 - r));
|
||||
}
|
||||
|
||||
#define ROTL32(x, y) rotl32(x, y)
|
||||
#define ROTL64(x, y) rotl64(x, y)
|
||||
|
||||
#define BIG_CONSTANT(x) (x##LLU)
|
||||
|
||||
#endif // !defined(_MSC_VER)
|
||||
#include <cstddef>
|
||||
//-----------------------------------------------------------------------------
|
||||
// Block read - on little-endian machines this is a single load,
|
||||
// while on big-endian or unknown machines the byte accesses should
|
||||
// still get optimized into the most efficient instruction.
|
||||
//
|
||||
// Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/
|
||||
// were manually applied to original murmurhash3 source code.
|
||||
FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) {
|
||||
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
|
||||
return p[i];
|
||||
} else {
|
||||
const uint8_t* c = (const uint8_t*)&p[i];
|
||||
return (uint32_t)c[0] |
|
||||
(uint32_t)c[1] << 8 |
|
||||
(uint32_t)c[2] << 16 |
|
||||
(uint32_t)c[3] << 24;
|
||||
}
|
||||
}
|
||||
|
||||
FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) {
|
||||
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
|
||||
return p[i];
|
||||
} else {
|
||||
const uint8_t* c = (const uint8_t*)&p[i];
|
||||
return (uint64_t)c[0] |
|
||||
(uint64_t)c[1] << 8 |
|
||||
(uint64_t)c[2] << 16 |
|
||||
(uint64_t)c[3] << 24 |
|
||||
(uint64_t)c[4] << 32 |
|
||||
(uint64_t)c[5] << 40 |
|
||||
(uint64_t)c[6] << 48 |
|
||||
(uint64_t)c[7] << 56;
|
||||
}
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Finalization mix - force all bits of a hash block to avalanche
|
||||
|
||||
FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) {
|
||||
h ^= h >> 16;
|
||||
h *= 0x85ebca6b;
|
||||
h ^= h >> 13;
|
||||
h *= 0xc2b2ae35;
|
||||
h ^= h >> 16;
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
//----------
|
||||
|
||||
FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) {
|
||||
k ^= k >> 33;
|
||||
k *= BIG_CONSTANT(0xff51afd7ed558ccd);
|
||||
k ^= k >> 33;
|
||||
k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
|
||||
k ^= k >> 33;
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
namespace onnxruntime {
|
||||
void MurmurHash3::x86_32(const void* key, int len,
|
||||
uint32_t seed, void* out) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
const int nblocks = len / 4;
|
||||
|
||||
uint32_t h1 = seed;
|
||||
|
||||
constexpr uint32_t c1 = 0xcc9e2d51;
|
||||
constexpr uint32_t c2 = 0x1b873593;
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
const uint32_t* blocks = (const uint32_t*)(data + static_cast<ptrdiff_t>(nblocks) * 4);
|
||||
|
||||
for (int i = -nblocks; i; i++) {
|
||||
uint32_t k1 = getblock32(blocks, i);
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
|
||||
h1 ^= k1;
|
||||
h1 = ROTL32(h1, 13);
|
||||
h1 = h1 * 5 + 0xe6546b64;
|
||||
}
|
||||
|
||||
//----------
|
||||
// tail
|
||||
|
||||
const uint8_t* tail = (const uint8_t*)(data + static_cast<ptrdiff_t>(nblocks) * 4);
|
||||
|
||||
uint32_t k1 = 0;
|
||||
|
||||
switch (len & 3) {
|
||||
case 3:
|
||||
k1 ^= tail[2] << 16;
|
||||
[[fallthrough]];
|
||||
case 2:
|
||||
k1 ^= tail[1] << 8;
|
||||
[[fallthrough]];
|
||||
case 1:
|
||||
k1 ^= tail[0];
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
};
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= len;
|
||||
|
||||
h1 = fmix32(h1);
|
||||
|
||||
*(uint32_t*)out = h1;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
void MurmurHash3::x86_128(const void* key, int len, uint32_t seed, void* out) {
|
||||
const uint8_t* data = (const uint8_t*)key;
|
||||
const int nblocks = len / 16;
|
||||
|
||||
uint32_t h1 = seed;
|
||||
uint32_t h2 = seed;
|
||||
uint32_t h3 = seed;
|
||||
uint32_t h4 = seed;
|
||||
|
||||
constexpr uint32_t c1 = 0x239b961b;
|
||||
constexpr uint32_t c2 = 0xab0e9789;
|
||||
constexpr uint32_t c3 = 0x38b34ae5;
|
||||
constexpr uint32_t c4 = 0xa1e38b93;
|
||||
|
||||
//----------
|
||||
// body
|
||||
|
||||
const uint32_t* blocks = (const uint32_t*)(data + static_cast<ptrdiff_t>(nblocks) * 16);
|
||||
|
||||
for (int i = -nblocks; i; i++) {
|
||||
uint32_t k1 = getblock32(blocks, i * 4 + 0);
|
||||
uint32_t k2 = getblock32(blocks, i * 4 + 1);
|
||||
uint32_t k3 = getblock32(blocks, i * 4 + 2);
|
||||
uint32_t k4 = getblock32(blocks, i * 4 + 3);
|
||||
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
|
||||
h1 = ROTL32(h1, 19);
|
||||
h1 += h2;
|
||||
h1 = h1 * 5 + 0x561ccd1b;
|
||||
|
||||
k2 *= c2;
|
||||
k2 = ROTL32(k2, 16);
|
||||
k2 *= c3;
|
||||
h2 ^= k2;
|
||||
|
||||
h2 = ROTL32(h2, 17);
|
||||
h2 += h3;
|
||||
h2 = h2 * 5 + 0x0bcaa747;
|
||||
|
||||
k3 *= c3;
|
||||
k3 = ROTL32(k3, 17);
|
||||
k3 *= c4;
|
||||
h3 ^= k3;
|
||||
|
||||
h3 = ROTL32(h3, 15);
|
||||
h3 += h4;
|
||||
h3 = h3 * 5 + 0x96cd1c35;
|
||||
|
||||
k4 *= c4;
|
||||
k4 = ROTL32(k4, 18);
|
||||
k4 *= c1;
|
||||
h4 ^= k4;
|
||||
|
||||
h4 = ROTL32(h4, 13);
|
||||
h4 += h1;
|
||||
h4 = h4 * 5 + 0x32ac3b17;
|
||||
}
|
||||
|
||||
//----------
|
||||
// tail
|
||||
|
||||
const uint8_t* tail = (const uint8_t*)(data + static_cast<ptrdiff_t>(nblocks) * 16);
|
||||
|
||||
uint32_t k1 = 0;
|
||||
uint32_t k2 = 0;
|
||||
uint32_t k3 = 0;
|
||||
uint32_t k4 = 0;
|
||||
|
||||
switch (len & 15) {
|
||||
case 15:
|
||||
k4 ^= tail[14] << 16;
|
||||
[[fallthrough]];
|
||||
case 14:
|
||||
k4 ^= tail[13] << 8;
|
||||
[[fallthrough]];
|
||||
case 13:
|
||||
k4 ^= tail[12] << 0;
|
||||
k4 *= c4;
|
||||
k4 = ROTL32(k4, 18);
|
||||
k4 *= c1;
|
||||
h4 ^= k4;
|
||||
[[fallthrough]];
|
||||
case 12:
|
||||
k3 ^= tail[11] << 24;
|
||||
[[fallthrough]];
|
||||
case 11:
|
||||
k3 ^= tail[10] << 16;
|
||||
[[fallthrough]];
|
||||
case 10:
|
||||
k3 ^= tail[9] << 8;
|
||||
[[fallthrough]];
|
||||
case 9:
|
||||
k3 ^= tail[8] << 0;
|
||||
k3 *= c3;
|
||||
k3 = ROTL32(k3, 17);
|
||||
k3 *= c4;
|
||||
h3 ^= k3;
|
||||
[[fallthrough]];
|
||||
case 8:
|
||||
k2 ^= tail[7] << 24;
|
||||
[[fallthrough]];
|
||||
case 7:
|
||||
k2 ^= tail[6] << 16;
|
||||
[[fallthrough]];
|
||||
case 6:
|
||||
k2 ^= tail[5] << 8;
|
||||
[[fallthrough]];
|
||||
case 5:
|
||||
k2 ^= tail[4] << 0;
|
||||
k2 *= c2;
|
||||
k2 = ROTL32(k2, 16);
|
||||
k2 *= c3;
|
||||
h2 ^= k2;
|
||||
[[fallthrough]];
|
||||
case 4:
|
||||
k1 ^= tail[3] << 24;
|
||||
[[fallthrough]];
|
||||
case 3:
|
||||
k1 ^= tail[2] << 16;
|
||||
[[fallthrough]];
|
||||
case 2:
|
||||
k1 ^= tail[1] << 8;
|
||||
[[fallthrough]];
|
||||
case 1:
|
||||
k1 ^= tail[0] << 0;
|
||||
k1 *= c1;
|
||||
k1 = ROTL32(k1, 15);
|
||||
k1 *= c2;
|
||||
h1 ^= k1;
|
||||
};
|
||||
|
||||
//----------
|
||||
// finalization
|
||||
|
||||
h1 ^= len;
|
||||
h2 ^= len;
|
||||
h3 ^= len;
|
||||
h4 ^= len;
|
||||
|
||||
h1 += h2;
|
||||
h1 += h3;
|
||||
h1 += h4;
|
||||
h2 += h1;
|
||||
h3 += h1;
|
||||
h4 += h1;
|
||||
|
||||
h1 = fmix32(h1);
|
||||
h2 = fmix32(h2);
|
||||
h3 = fmix32(h3);
|
||||
h4 = fmix32(h4);
|
||||
|
||||
h1 += h2;
|
||||
h1 += h3;
|
||||
h1 += h4;
|
||||
h2 += h1;
|
||||
h3 += h1;
|
||||
h4 += h1;
|
||||
|
||||
((uint32_t*)out)[0] = h1;
|
||||
((uint32_t*)out)[1] = h2;
|
||||
((uint32_t*)out)[2] = h3;
|
||||
((uint32_t*)out)[3] = h4;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
16
onnxruntime/core/providers/tensorrt/murmurhash3.h
Normal file
16
onnxruntime/core/providers/tensorrt/murmurhash3.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnxruntime {
|
||||
struct MurmurHash3 {
|
||||
// generate 32-bit hash from input and write to 'out'
|
||||
static void x86_32(const void* key, int len, uint32_t seed, void* out);
|
||||
|
||||
// generate 128-bit hash from input and write to 'out'.
|
||||
static void x86_128(const void* key, int len, uint32_t seed, void* out);
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -712,7 +712,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
|
|||
|
||||
// Generate unique kernel name for TRT subgraph
|
||||
HashValue model_hash = 0;
|
||||
int id = GenerateMetaDefId(graph, model_hash);
|
||||
int id = TRTGenerateMetaDefId(graph, model_hash);
|
||||
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id);
|
||||
auto meta_def = IndexedSubGraph_MetaDef::Create();
|
||||
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@
|
|||
#include <experimental/filesystem>
|
||||
#include "flatbuffers/idl.h"
|
||||
#include "ort_trt_int8_cal_table.fbs.h"
|
||||
#include "murmurhash3.h"
|
||||
#include <NvInfer.h>
|
||||
#include "core/providers/cuda/cuda_pch.h"
|
||||
|
||||
namespace fs = std::experimental::filesystem;
|
||||
|
||||
|
|
@ -194,4 +197,105 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) {
|
|||
fs::remove(entry);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper class to generate engine id via model name/model content/env metadata
|
||||
class TRTModelMetadefIdGenerator {
|
||||
public:
|
||||
int TRTGenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) {
|
||||
model_hash = 0;
|
||||
|
||||
// find the top level graph
|
||||
const Graph* cur_graph = &graph_viewer.GetGraph();
|
||||
while (cur_graph->IsSubgraph()) {
|
||||
cur_graph = cur_graph->ParentGraph();
|
||||
}
|
||||
|
||||
uint32_t instance_hash[4] = {0, 0, 0, 0};
|
||||
|
||||
const Graph& main_graph = *cur_graph;
|
||||
|
||||
// hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use
|
||||
// the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique
|
||||
// fingerprint for the instance that can use used as the key to the hash of the graph name/inputs&outputs/metadata.
|
||||
MurmurHash3::x86_128(&main_graph, gsl::narrow_cast<int32_t>(sizeof(Graph)), instance_hash[0], &instance_hash);
|
||||
HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32);
|
||||
|
||||
// if we've already hashed this main graph instance use the cached value
|
||||
auto entry = trt_main_graph_hash_.find(graph_instance_hash);
|
||||
if (entry != trt_main_graph_hash_.cend()) {
|
||||
model_hash = entry->second;
|
||||
} else {
|
||||
uint32_t hash[4] = {0, 0, 0, 0};
|
||||
|
||||
// Use graph name instead of path to avoid cache regeneration if path changes
|
||||
const auto& model_name_str = main_graph.Name();
|
||||
if (!model_name_str.empty()) {
|
||||
MurmurHash3::x86_128(model_name_str.data(), gsl::narrow_cast<int32_t>(model_name_str.size()), hash[0], &hash);
|
||||
}
|
||||
|
||||
auto hash_str = [&hash](const std::string& str) {
|
||||
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
|
||||
};
|
||||
|
||||
// fingerprint the main graph by hashing graph inputs
|
||||
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
|
||||
// hashing output of each node
|
||||
const int number_of_ort_nodes = graph_viewer.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
const std::vector<NodeIndex>& node_index = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto& index : nodes_vector) {
|
||||
const auto& node = graph_viewer.GetNode(node_index[index]);
|
||||
for (const auto* node_arg : node->OutputDefs()) {
|
||||
if (node_arg->Exists()) {
|
||||
hash_str(node_arg->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
hash_str("LINUX");
|
||||
#elif defined(_WIN32)
|
||||
hash_str("WINDOWS");
|
||||
#endif
|
||||
|
||||
#ifdef ORT_VERSION
|
||||
hash_str(ORT_VERSION);
|
||||
#endif
|
||||
|
||||
#ifdef CUDA_VERSION
|
||||
hash_str(std::to_string(CUDA_VERSION));
|
||||
#endif
|
||||
|
||||
#if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR)
|
||||
std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR);
|
||||
hash_str(TRT_VERSION);
|
||||
#endif
|
||||
|
||||
model_hash = hash[0] | (uint64_t(hash[1]) << 32);
|
||||
trt_main_graph_hash_[graph_instance_hash] = model_hash;
|
||||
}
|
||||
|
||||
// return the current unique id, and increment to update
|
||||
return trt_model_metadef_id_[model_hash]++;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<HashValue, HashValue> trt_main_graph_hash_; // map graph instance hash to model contents hash
|
||||
std::unordered_map<HashValue, int> trt_model_metadef_id_; // current unique id for model
|
||||
};
|
||||
|
||||
std::unique_ptr<TRTModelMetadefIdGenerator> trt_metadef_id_generator_ = std::make_unique<TRTModelMetadefIdGenerator>();
|
||||
|
||||
// Calll TRTGenerateMetaDefId to generate hash id for TRT engine cache
|
||||
int TRTGenerateMetaDefId(const GraphViewer& graph_viewer, HashValue& model_hash) {
|
||||
// if the EP is shared across multiple sessions there's a very small potential for concurrency issues.
|
||||
// use a lock when generating an id to be paranoid
|
||||
static OrtMutex mutex;
|
||||
std::lock_guard<OrtMutex> lock(mutex);
|
||||
return trt_metadef_id_generator_->TRTGenerateId(graph_viewer, model_hash);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -753,6 +753,10 @@ struct ProviderHostImpl : ProviderHost {
|
|||
bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) override { return p->GetInitializedTensor(tensor_name, value); }
|
||||
|
||||
const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); }
|
||||
const Graph* Graph__ParentGraph(const Graph* p) const override { return p->ParentGraph(); }
|
||||
const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); }
|
||||
const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); }
|
||||
bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); }
|
||||
|
||||
// GraphViewer (wrapped)
|
||||
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }
|
||||
|
|
|
|||
|
|
@ -274,6 +274,85 @@ TEST(TensorrtExecutionProviderTest, MultiThreadsTestWithOneSessionMultiThreadsIn
|
|||
RunWithOneSessionMultiThreadsInference(model_name, sess_log_id);
|
||||
}
|
||||
|
||||
// Test loading same model in different way, when hash id is generated via model name/model content/env metadata
|
||||
TEST(TensorrtExecutionProviderTest, TRTMetadefIdGeneratorUsingModelHashing) {
|
||||
auto model_path = ORT_TSTR("testdata/mnist.onnx");
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_path, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
GraphViewer viewer(graph);
|
||||
|
||||
// get the hash for the model when loaded from file
|
||||
HashValue model_hash;
|
||||
int id = TRTGenerateMetaDefId(viewer, model_hash);
|
||||
ASSERT_EQ(id, 0);
|
||||
ASSERT_NE(model_hash, 0);
|
||||
|
||||
// now load the model from bytes and check the hash differs
|
||||
std::ifstream model_file_stream(model_path, std::ios::in | std::ios::binary);
|
||||
|
||||
std::shared_ptr<Model> model2;
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
ASSERT_STATUS_OK(Model::Load(model_file_stream, &model_proto));
|
||||
ASSERT_STATUS_OK(Model::Load(std::move(model_proto), PathString(), model2, nullptr,
|
||||
DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
Graph& graph2 = model2->MainGraph();
|
||||
GraphViewer viewer2(graph2);
|
||||
|
||||
HashValue model_hash2;
|
||||
int id2 = TRTGenerateMetaDefId(viewer2, model_hash2);
|
||||
|
||||
// test comparing model 1 & 2
|
||||
ASSERT_EQ(model_hash, model_hash2) << "model1 has same graph name/nodes/env metadata as model2";
|
||||
ASSERT_EQ(id2, 1) << "id2 should be 1 as model 1 & 2 have same hash";
|
||||
|
||||
// Test loading same model from different path, see if hash values are same as well
|
||||
model_path = ORT_TSTR("testdata/TRTEP_test_model/mnist.onnx");
|
||||
std::shared_ptr<Model> model3;
|
||||
ASSERT_TRUE(Model::Load(model_path, model3, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
Graph& graph3 = model3->MainGraph();
|
||||
GraphViewer viewer3(graph3);
|
||||
HashValue model_hash3;
|
||||
int id3 = TRTGenerateMetaDefId(viewer3, model_hash3);
|
||||
ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded";
|
||||
ASSERT_EQ(id3, 2) << "id3 should be 2 as model 1 & 2 & 3 have same hash";
|
||||
}
|
||||
|
||||
// Compare on TRT subgraph id when repeatedly calling TRTGenerateMetaDefId
|
||||
TEST(TensorrtExecutionProviderTest, TRTSubgraphIdGeneratorUsingModelHashing) {
|
||||
// Load model
|
||||
auto model_path = ORT_TSTR("testdata/mnist.onnx");
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_path, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
|
||||
|
||||
Graph& main_graph = model->MainGraph();
|
||||
GraphViewer graph(main_graph);
|
||||
HashValue model_hash;
|
||||
|
||||
// Graph id acquired
|
||||
int graph_id = TRTGenerateMetaDefId(graph, model_hash);
|
||||
int asserted_subgraph_id = graph_id + 1;
|
||||
|
||||
// mock fetching subgraphs and generate id by calling TRTGenerateMetaDefId repeatedly
|
||||
const int number_of_ort_nodes = graph.NumberOfNodes();
|
||||
std::vector<size_t> nodes_vector(number_of_ort_nodes);
|
||||
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
|
||||
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
|
||||
|
||||
for (const auto& index : nodes_vector) {
|
||||
const auto& node = graph.GetNode(node_index[index]);
|
||||
std::cout << "->" << node->Name();
|
||||
|
||||
// Check if id increment each time TRTGenerateMetaDefId is called
|
||||
int subgraph_id = TRTGenerateMetaDefId(graph, model_hash);
|
||||
ASSERT_EQ(subgraph_id, asserted_subgraph_id) << "id will increment as TRTGenerateMetaDefId is repeatedly called";
|
||||
asserted_subgraph_id++;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
|
||||
// GetParam() returns the parameter of following format:
|
||||
// ##cache type##_##input shape type##
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/TRTEP_test_model/mnist.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/TRTEP_test_model/mnist.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue