mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Simplify cache implementation and avoid static variables that may carry over between models
This commit is contained in:
parent
da03ed4473
commit
4b900dc585
8 changed files with 117 additions and 123 deletions
|
|
@ -113,18 +113,17 @@ tvm::Tensor Promote(const tvm::Expr& expr, const tvm::Array<tvm::Expr>& shape, c
|
|||
name);
|
||||
}
|
||||
|
||||
void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module) {
|
||||
void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module) {
|
||||
const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance();
|
||||
if (!settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpModule))
|
||||
return;
|
||||
|
||||
static int dump_module_cnt = 0;
|
||||
// ISSUE: note that all option values are converted to lower case. It doesn't cause
|
||||
// any issue currently, because all supported formats (i.e. file exts) are of lower case.
|
||||
// Just keep in mind that we might have issue if somehow we started to support dump
|
||||
// formats with upper case, although it's quite unlikely.
|
||||
std::string format = settings.GetOptionValue(codegen::CodeGenSettings::kCodeGenDumpModule);
|
||||
std::string module_filename = filename_prefix + "_" + std::to_string(dump_module_cnt++) + "." + format;
|
||||
std::string module_filename = filename + "." + format;
|
||||
module->SaveToFile(module_filename, format);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ tvm::Tensor Promote(const tvm::Expr& expr,
|
|||
|
||||
tvm::Tensor MakeZeroTensor(const tvm::Array<tvm::Expr>& shape, HalideIR::Type type, const std::string& name);
|
||||
|
||||
void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module);
|
||||
void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module);
|
||||
|
||||
bool BroadcastDim(const tvm::Array<tvm::Expr>& shape, size_t i, size_t output_rank, tvm::Expr& dim);
|
||||
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ const std::string WeightLayout::GetKey(
|
|||
ONNX_NAMESPACE::TensorProto_DataType proto_type,
|
||||
int input_dim,
|
||||
float pad_zero) {
|
||||
std::ostringstream key(name);
|
||||
key << "_type_" << static_cast<int>(proto_type);
|
||||
std::ostringstream key;
|
||||
key << name << "_type_" << static_cast<int>(proto_type);
|
||||
key << "_dim_" << input_dim;
|
||||
key << "_pad_zero_" << pad_zero;
|
||||
return NormalizeCppName(key.str());
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
#include "gsl/gsl"
|
||||
#include <topi/detail/extern.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019
|
||||
#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019
|
||||
#include <experimental/filesystem>
|
||||
#undef _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING
|
||||
#include <atomic>
|
||||
|
|
@ -102,117 +102,102 @@ static void ParseVersion(const char* version, int* major, int* minor, int* patch
|
|||
*patch = ver_num_fn(val);
|
||||
}
|
||||
|
||||
static void VerifyCacheVersion(const std::string& so_path) {
|
||||
static std::atomic<bool> cache_version_checked{false};
|
||||
static std::mutex cache_version_mutex;
|
||||
static bool VerifyCacheVersion(const std::string& so_path) {
|
||||
// ensure we have _ORTInternal_GetCacheVersion_ function
|
||||
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
|
||||
if (f == nullptr)
|
||||
return false;
|
||||
|
||||
// make sure we only check cache version once
|
||||
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
|
||||
std::lock_guard<std::mutex> lock(cache_version_mutex);
|
||||
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
|
||||
cache_version_checked.store(true, std::memory_order::memory_order_release);
|
||||
// ensure we have _ORTInternal_GetCacheVersion_ function
|
||||
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
|
||||
ORT_ENFORCE(f, "NULL library function pointer!");
|
||||
typedef const char* (*GetVersionFunc)();
|
||||
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
|
||||
const char* cache_version = func();
|
||||
ORT_ENFORCE(cache_version, "Null cache version string!");
|
||||
int cur_major, cur_minor, cur_patch;
|
||||
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
|
||||
int cache_major, cache_minor, cache_patch;
|
||||
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
|
||||
|
||||
typedef const char* (*GetVersionFunc)();
|
||||
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
|
||||
const char* cache_version = func();
|
||||
ORT_ENFORCE(cache_version, "Null cache version string!");
|
||||
int cur_major, cur_minor, cur_patch;
|
||||
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
|
||||
int cache_major, cache_minor, cache_patch;
|
||||
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
|
||||
|
||||
// make version check strict until we have thorough design for compatibility
|
||||
ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor),
|
||||
"Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__,
|
||||
") doesn't match cached dll version (", cache_version, ")");
|
||||
|
||||
cache_version_checked = true;
|
||||
}
|
||||
// make version check strict until we have thorough design for compatibility
|
||||
bool version_match = (cur_major == cache_major) && (cur_minor == cache_minor);
|
||||
if (!version_match) {
|
||||
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Current nuphar runtime version (" << __NUPHAR_CACHE_VERSION__ << ") doesn't match cached dll version (" << cache_version << ")";
|
||||
}
|
||||
return version_match;
|
||||
}
|
||||
|
||||
static bool disable_caching_due_to_checksum_failure = false;
|
||||
|
||||
static bool VerifyTVMModuleChecksum(const std::string& so_path) {
|
||||
static std::string last_so_path;
|
||||
static bool last_checksum_validated = false;
|
||||
static std::mutex checksum_mutex;
|
||||
if (last_so_path != so_path) {
|
||||
std::lock_guard<std::mutex> lock(checksum_mutex);
|
||||
if (last_so_path != so_path) {
|
||||
disable_caching_due_to_checksum_failure = false; // reset disabled caching for a new file
|
||||
last_so_path = so_path;
|
||||
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false);
|
||||
if (f) {
|
||||
typedef void (*GetChecksumFunc)(const char*&, size_t&);
|
||||
GetChecksumFunc func = reinterpret_cast<GetChecksumFunc>(f);
|
||||
const char* model_checksum;
|
||||
size_t model_checksum_len;
|
||||
func(model_checksum,
|
||||
model_checksum_len);
|
||||
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false);
|
||||
if (f) {
|
||||
typedef void (*GetChecksumFunc)(const char*&, size_t&);
|
||||
GetChecksumFunc func = reinterpret_cast<GetChecksumFunc>(f);
|
||||
const char* model_checksum;
|
||||
size_t model_checksum_len;
|
||||
func(model_checksum,
|
||||
model_checksum_len);
|
||||
|
||||
codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance();
|
||||
// When checksum is expected by dll/so, user must set environment variable
|
||||
// NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model.
|
||||
// User may choose to run with base model or simplified mode and any match
|
||||
// would be regarded as validated.
|
||||
// Note that checksum validation here is not designed as a security measurement,
|
||||
// so checksum compute is not done inside ORT.
|
||||
last_checksum_validated =
|
||||
setting.OptionMatches(
|
||||
kNupharCacheModelChecksum,
|
||||
std::string(model_checksum, model_checksum_len));
|
||||
|
||||
if (!last_checksum_validated) {
|
||||
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT...";
|
||||
disable_caching_due_to_checksum_failure = true;
|
||||
}
|
||||
} else {
|
||||
// do not validate checksum if dll didn't require it (usually during debugging)
|
||||
// TODO: force checksum validation in final release
|
||||
last_checksum_validated = true;
|
||||
codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance();
|
||||
// When checksum is expected by dll/so, user must set environment variable
|
||||
// NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model.
|
||||
// User may choose to run with base model or simplified mode and any match
|
||||
// would be regarded as validated.
|
||||
// Note that checksum validation here is not designed as a security measurement,
|
||||
// so checksum compute is not done inside ORT.
|
||||
if (setting.OptionMatches(
|
||||
kNupharCacheModelChecksum,
|
||||
std::string(model_checksum, model_checksum_len))) {
|
||||
return true;
|
||||
} else {
|
||||
static std::mutex warn_mutex;
|
||||
std::lock_guard<std::mutex> warn_lock(warn_mutex);
|
||||
static std::string last_warned_so_path;
|
||||
if (last_warned_so_path != so_path) {
|
||||
// warning only once for each so_path
|
||||
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT...";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// do not validate checksum if dll didn't require it (usually during debugging)
|
||||
// TODO: force checksum validation in final release
|
||||
return true;
|
||||
}
|
||||
return last_checksum_validated;
|
||||
}
|
||||
|
||||
tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name) {
|
||||
CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func) {
|
||||
std::string so_path;
|
||||
if (!GetCacheSoFilePath(so_path))
|
||||
return nullptr;
|
||||
if (!GetCacheSoFilePath(so_path)) {
|
||||
if (codegen::CodeGenSettings::Instance().HasOption(kNupharCachePath)) {
|
||||
return CacheStatus::Missing;
|
||||
} else {
|
||||
return CacheStatus::NotInUse;
|
||||
}
|
||||
}
|
||||
|
||||
VerifyCacheVersion(so_path);
|
||||
if (!VerifyCacheVersion(so_path)) {
|
||||
return CacheStatus::Mismatch;
|
||||
}
|
||||
|
||||
if (!VerifyTVMModuleChecksum(so_path))
|
||||
return nullptr;
|
||||
return CacheStatus::Mismatch;
|
||||
|
||||
tvm::runtime::Module module = tvm::runtime::Module::LoadFromFile(so_path);
|
||||
tvm::runtime::PackedFunc func = module.GetFunction(func_name);
|
||||
func = module.GetFunction(func_name);
|
||||
if (func == nullptr) {
|
||||
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cannot find " << func_name << " in cache, using JIT...";
|
||||
return CacheStatus::Missing;
|
||||
}
|
||||
return func;
|
||||
return CacheStatus::Found;
|
||||
}
|
||||
|
||||
void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module) {
|
||||
fs::path path;
|
||||
|
||||
if (disable_caching_due_to_checksum_failure)
|
||||
return;
|
||||
|
||||
static std::mutex save_cache_mutex;
|
||||
static std::unordered_set<std::string> existing_files;
|
||||
std::lock_guard<std::mutex> lock(save_cache_mutex);
|
||||
if (existing_files.count(filename) == 0 &&
|
||||
GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) {
|
||||
existing_files.insert(filename);
|
||||
if (GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) {
|
||||
path.append(filename + ".o");
|
||||
if (fs::exists(path)) {
|
||||
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving...";
|
||||
//LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving...";
|
||||
return;
|
||||
}
|
||||
module->SaveToFile(path.string(), "o");
|
||||
|
|
|
|||
|
|
@ -19,8 +19,15 @@ struct NupharSubgraphUnit; //forward
|
|||
// Helper functions to create or load from offline cached dll
|
||||
// note after saving to obj file, we need to use tvm Python to create dll
|
||||
// using script at onnxruntime/core/codegen/mti/scripts/create_shared.py
|
||||
tvm::runtime::PackedFunc
|
||||
LoadTVMPackedFuncFromCache(const std::string& func_name);
|
||||
|
||||
enum class CacheStatus {
|
||||
NotInUse,
|
||||
Mismatch,
|
||||
Missing,
|
||||
Found,
|
||||
};
|
||||
|
||||
CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func);
|
||||
void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module);
|
||||
|
||||
std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target, int64_t parallel_min_workloads);
|
||||
|
|
|
|||
|
|
@ -56,13 +56,16 @@ static tvm::runtime::PackedFunc LowerLayoutFunc(const tvm_codegen::WeightLayout*
|
|||
|
||||
std::string func_name = layout->Name() + "_marshall";
|
||||
|
||||
tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name);
|
||||
|
||||
if (cached_func == nullptr) {
|
||||
tvm::runtime::PackedFunc cached_func;
|
||||
auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func);
|
||||
if (cache_status != nuphar::CacheStatus::Found) {
|
||||
ORT_ENFORCE(cached_func == nullptr);
|
||||
auto lowered = tvm::lower(S, {inputs[0], outputs[0]}, func_name, {}, config);
|
||||
auto module = tvm::build(lowered, tvm::target::llvm(), tvm::Target(), config);
|
||||
tvm_codegen::DumpTVMModuleToFile(func_name, module);
|
||||
nuphar::SaveTVMModuleToCache(func_name, module);
|
||||
if (cache_status == nuphar::CacheStatus::Missing) {
|
||||
nuphar::SaveTVMModuleToCache(func_name, module);
|
||||
}
|
||||
cached_func = module.GetFunction(func_name);
|
||||
}
|
||||
return cached_func;
|
||||
|
|
|
|||
|
|
@ -151,8 +151,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc(
|
|||
// JIT-caching and AOT are mutual exclusive.
|
||||
// Change it by not always saving a compiled func unless it is in JIT-Caching model.
|
||||
// In AOT, there should be another member func explicitly loading
|
||||
tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name);
|
||||
if (cached_func == nullptr) {
|
||||
tvm::runtime::PackedFunc cached_func;
|
||||
auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func);
|
||||
if (cache_status != nuphar::CacheStatus::Found) {
|
||||
codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance();
|
||||
|
||||
if (settings.HasOption(kNupharCacheForceNoJIT)) {
|
||||
|
|
@ -180,7 +181,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc(
|
|||
|
||||
tvm::runtime::Module module = tvm::build(lowered, tvm_target, tvm_host_target, config);
|
||||
tvm_codegen::DumpTVMModuleToFile(func_name, module);
|
||||
nuphar::SaveTVMModuleToCache(func_name, module);
|
||||
if (cache_status == nuphar::CacheStatus::Missing) {
|
||||
nuphar::SaveTVMModuleToCache(func_name, module);
|
||||
}
|
||||
cached_func = module.GetFunction(func_name);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -41,11 +41,13 @@ class TestNuphar(unittest.TestCase):
|
|||
|
||||
# run onnx_test_runner to verify results
|
||||
# use -M to disable memory pattern
|
||||
# use -j 1 -c 1 to run one model/session at a time when running multiple models
|
||||
onnx_test_runner = os.path.join(cwd, 'onnx_test_runner')
|
||||
subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-c', '1', '-j', '1', '-n', 'bidaf', cwd], check=True, cwd=cwd)
|
||||
subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-n', 'bidaf', cwd], check=True, cwd=cwd)
|
||||
|
||||
# test AOT on the quantized model
|
||||
if os.name not in ['nt', 'posix']:
|
||||
return # don't run the rest of test if AOT is not supported
|
||||
|
||||
cache_dir = os.path.join(cwd, 'nuphar_cache')
|
||||
if os.path.exists(cache_dir):
|
||||
shutil.rmtree(cache_dir)
|
||||
|
|
@ -57,33 +59,28 @@ class TestNuphar(unittest.TestCase):
|
|||
tp = onnx.load_tensor(os.path.join(bidaf_dir, 'test_data_set_0', 'input_{}.pb'.format(i)))
|
||||
feed[tp.name] = numpy_helper.to_array(tp)
|
||||
|
||||
# force codegen_target to be avx
|
||||
nuphar_settings = 'nuphar_codegen_target:avx'
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model)
|
||||
assert 'NupharExecutionProvider' in sess.get_providers()
|
||||
output = sess.run([], feed)
|
||||
for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]:
|
||||
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
|
||||
for isa in ['avx', 'avx2', 'avx512']:
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa)
|
||||
sess = onnxrt.InferenceSession(model) # JIT cache happens when initializing session
|
||||
|
||||
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session
|
||||
assert 'NupharExecutionProvider' in sess.get_providers()
|
||||
output = sess.run([], feed)
|
||||
|
||||
cache_dir_content = os.listdir(cache_dir)
|
||||
assert len(cache_dir_content) == 1
|
||||
cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0])
|
||||
so_name = 'bidaf.so'
|
||||
if os.name in ['nt', 'posix'] : # Windows or Linux
|
||||
cache_dir_content = os.listdir(cache_dir)
|
||||
assert len(cache_dir_content) == 1
|
||||
cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0])
|
||||
so_name = os.path.basename(model) + '.so'
|
||||
subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.create_shared', '--input_dir', cache_versioned_dir, '--output_name', so_name], check=True)
|
||||
else:
|
||||
return # don't run the rest of test if AOT is not supported
|
||||
|
||||
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on')
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session
|
||||
assert 'NupharExecutionProvider' in sess.get_providers()
|
||||
sess.run([], feed)
|
||||
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on')
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(model)
|
||||
sess.run([], feed)
|
||||
|
||||
# test avx
|
||||
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(cache_dir, so_name, 'on', 'avx')
|
||||
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
|
||||
sess = onnxrt.InferenceSession(model)
|
||||
sess.run([], feed)
|
||||
|
||||
|
||||
def test_bert_squad(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue