Simplify cache implementation and avoid static variables that may carry over between models

This commit is contained in:
KeDengMS 2019-12-20 17:52:41 -08:00 committed by Changming Sun
parent da03ed4473
commit 4b900dc585
8 changed files with 117 additions and 123 deletions

View file

@ -113,18 +113,17 @@ tvm::Tensor Promote(const tvm::Expr& expr, const tvm::Array<tvm::Expr>& shape, c
name);
}
void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module) {
void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module) {
const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance();
if (!settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpModule))
return;
static int dump_module_cnt = 0;
// ISSUE: note that all option values are converted to lower case. It doesn't cause
// any issue currently, because all supported formats (i.e. file exts) are of lower case.
// Just keep in mind that we might have issue if somehow we started to support dump
// formats with upper case, although it's quite unlikely.
std::string format = settings.GetOptionValue(codegen::CodeGenSettings::kCodeGenDumpModule);
std::string module_filename = filename_prefix + "_" + std::to_string(dump_module_cnt++) + "." + format;
std::string module_filename = filename + "." + format;
module->SaveToFile(module_filename, format);
}

View file

@ -51,7 +51,7 @@ tvm::Tensor Promote(const tvm::Expr& expr,
tvm::Tensor MakeZeroTensor(const tvm::Array<tvm::Expr>& shape, HalideIR::Type type, const std::string& name);
void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module);
void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module);
bool BroadcastDim(const tvm::Array<tvm::Expr>& shape, size_t i, size_t output_rank, tvm::Expr& dim);

View file

@ -31,8 +31,8 @@ const std::string WeightLayout::GetKey(
ONNX_NAMESPACE::TensorProto_DataType proto_type,
int input_dim,
float pad_zero) {
std::ostringstream key(name);
key << "_type_" << static_cast<int>(proto_type);
std::ostringstream key;
key << name << "_type_" << static_cast<int>(proto_type);
key << "_dim_" << input_dim;
key << "_pad_zero_" << pad_zero;
return NormalizeCppName(key.str());

View file

@ -15,7 +15,7 @@
#include "gsl/gsl"
#include <topi/detail/extern.h>
#include <tvm/ir_pass.h>
#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019
#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019
#include <experimental/filesystem>
#undef _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING
#include <atomic>
@ -102,117 +102,102 @@ static void ParseVersion(const char* version, int* major, int* minor, int* patch
*patch = ver_num_fn(val);
}
static void VerifyCacheVersion(const std::string& so_path) {
static std::atomic<bool> cache_version_checked{false};
static std::mutex cache_version_mutex;
static bool VerifyCacheVersion(const std::string& so_path) {
// ensure we have _ORTInternal_GetCacheVersion_ function
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
if (f == nullptr)
return false;
// make sure we only check cache version once
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
std::lock_guard<std::mutex> lock(cache_version_mutex);
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
cache_version_checked.store(true, std::memory_order::memory_order_release);
// ensure we have _ORTInternal_GetCacheVersion_ function
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
ORT_ENFORCE(f, "NULL library function pointer!");
typedef const char* (*GetVersionFunc)();
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
const char* cache_version = func();
ORT_ENFORCE(cache_version, "Null cache version string!");
int cur_major, cur_minor, cur_patch;
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
int cache_major, cache_minor, cache_patch;
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
typedef const char* (*GetVersionFunc)();
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
const char* cache_version = func();
ORT_ENFORCE(cache_version, "Null cache version string!");
int cur_major, cur_minor, cur_patch;
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
int cache_major, cache_minor, cache_patch;
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
// make version check strict until we have thorough design for compatibility
ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor),
"Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__,
") doesn't match cached dll version (", cache_version, ")");
cache_version_checked = true;
}
// make version check strict until we have thorough design for compatibility
bool version_match = (cur_major == cache_major) && (cur_minor == cache_minor);
if (!version_match) {
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Current nuphar runtime version (" << __NUPHAR_CACHE_VERSION__ << ") doesn't match cached dll version (" << cache_version << ")";
}
return version_match;
}
static bool disable_caching_due_to_checksum_failure = false;
static bool VerifyTVMModuleChecksum(const std::string& so_path) {
static std::string last_so_path;
static bool last_checksum_validated = false;
static std::mutex checksum_mutex;
if (last_so_path != so_path) {
std::lock_guard<std::mutex> lock(checksum_mutex);
if (last_so_path != so_path) {
disable_caching_due_to_checksum_failure = false; // reset disabled caching for a new file
last_so_path = so_path;
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false);
if (f) {
typedef void (*GetChecksumFunc)(const char*&, size_t&);
GetChecksumFunc func = reinterpret_cast<GetChecksumFunc>(f);
const char* model_checksum;
size_t model_checksum_len;
func(model_checksum,
model_checksum_len);
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false);
if (f) {
typedef void (*GetChecksumFunc)(const char*&, size_t&);
GetChecksumFunc func = reinterpret_cast<GetChecksumFunc>(f);
const char* model_checksum;
size_t model_checksum_len;
func(model_checksum,
model_checksum_len);
codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance();
// When checksum is expected by dll/so, user must set environment variable
// NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model.
// User may choose to run with base model or simplified mode and any match
// would be regarded as validated.
// Note that checksum validation here is not designed as a security measurement,
// so checksum compute is not done inside ORT.
last_checksum_validated =
setting.OptionMatches(
kNupharCacheModelChecksum,
std::string(model_checksum, model_checksum_len));
if (!last_checksum_validated) {
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT...";
disable_caching_due_to_checksum_failure = true;
}
} else {
// do not validate checksum if dll didn't require it (usually during debugging)
// TODO: force checksum validation in final release
last_checksum_validated = true;
codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance();
// When checksum is expected by dll/so, user must set environment variable
// NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model.
// User may choose to run with base model or simplified mode and any match
// would be regarded as validated.
// Note that checksum validation here is not designed as a security measurement,
// so checksum compute is not done inside ORT.
if (setting.OptionMatches(
kNupharCacheModelChecksum,
std::string(model_checksum, model_checksum_len))) {
return true;
} else {
static std::mutex warn_mutex;
std::lock_guard<std::mutex> warn_lock(warn_mutex);
static std::string last_warned_so_path;
if (last_warned_so_path != so_path) {
// warning only once for each so_path
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT...";
}
return false;
}
} else {
// do not validate checksum if dll didn't require it (usually during debugging)
// TODO: force checksum validation in final release
return true;
}
return last_checksum_validated;
}
tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name) {
CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func) {
std::string so_path;
if (!GetCacheSoFilePath(so_path))
return nullptr;
if (!GetCacheSoFilePath(so_path)) {
if (codegen::CodeGenSettings::Instance().HasOption(kNupharCachePath)) {
return CacheStatus::Missing;
} else {
return CacheStatus::NotInUse;
}
}
VerifyCacheVersion(so_path);
if (!VerifyCacheVersion(so_path)) {
return CacheStatus::Mismatch;
}
if (!VerifyTVMModuleChecksum(so_path))
return nullptr;
return CacheStatus::Mismatch;
tvm::runtime::Module module = tvm::runtime::Module::LoadFromFile(so_path);
tvm::runtime::PackedFunc func = module.GetFunction(func_name);
func = module.GetFunction(func_name);
if (func == nullptr) {
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cannot find " << func_name << " in cache, using JIT...";
return CacheStatus::Missing;
}
return func;
return CacheStatus::Found;
}
void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module) {
fs::path path;
if (disable_caching_due_to_checksum_failure)
return;
static std::mutex save_cache_mutex;
static std::unordered_set<std::string> existing_files;
std::lock_guard<std::mutex> lock(save_cache_mutex);
if (existing_files.count(filename) == 0 &&
GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) {
existing_files.insert(filename);
if (GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) {
path.append(filename + ".o");
if (fs::exists(path)) {
LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving...";
//LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving...";
return;
}
module->SaveToFile(path.string(), "o");

View file

@ -19,8 +19,15 @@ struct NupharSubgraphUnit; //forward
// Helper functions to create or load from offline cached dll
// note after saving to obj file, we need to use tvm Python to create dll
// using script at onnxruntime/core/codegen/mti/scripts/create_shared.py
tvm::runtime::PackedFunc
LoadTVMPackedFuncFromCache(const std::string& func_name);
enum class CacheStatus {
NotInUse,
Mismatch,
Missing,
Found,
};
CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func);
void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module);
std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target, int64_t parallel_min_workloads);

View file

@ -56,13 +56,16 @@ static tvm::runtime::PackedFunc LowerLayoutFunc(const tvm_codegen::WeightLayout*
std::string func_name = layout->Name() + "_marshall";
tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name);
if (cached_func == nullptr) {
tvm::runtime::PackedFunc cached_func;
auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func);
if (cache_status != nuphar::CacheStatus::Found) {
ORT_ENFORCE(cached_func == nullptr);
auto lowered = tvm::lower(S, {inputs[0], outputs[0]}, func_name, {}, config);
auto module = tvm::build(lowered, tvm::target::llvm(), tvm::Target(), config);
tvm_codegen::DumpTVMModuleToFile(func_name, module);
nuphar::SaveTVMModuleToCache(func_name, module);
if (cache_status == nuphar::CacheStatus::Missing) {
nuphar::SaveTVMModuleToCache(func_name, module);
}
cached_func = module.GetFunction(func_name);
}
return cached_func;

View file

@ -151,8 +151,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc(
// JIT-caching and AOT are mutual exclusive.
// Change it by not always saving a compiled func unless it is in JIT-Caching model.
// In AOT, there should be another member func explicitly loading
tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name);
if (cached_func == nullptr) {
tvm::runtime::PackedFunc cached_func;
auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func);
if (cache_status != nuphar::CacheStatus::Found) {
codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance();
if (settings.HasOption(kNupharCacheForceNoJIT)) {
@ -180,7 +181,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc(
tvm::runtime::Module module = tvm::build(lowered, tvm_target, tvm_host_target, config);
tvm_codegen::DumpTVMModuleToFile(func_name, module);
nuphar::SaveTVMModuleToCache(func_name, module);
if (cache_status == nuphar::CacheStatus::Missing) {
nuphar::SaveTVMModuleToCache(func_name, module);
}
cached_func = module.GetFunction(func_name);
}

View file

@ -41,11 +41,13 @@ class TestNuphar(unittest.TestCase):
# run onnx_test_runner to verify results
# use -M to disable memory pattern
# use -j 1 -c 1 to run one model/session at a time when running multiple models
onnx_test_runner = os.path.join(cwd, 'onnx_test_runner')
subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-c', '1', '-j', '1', '-n', 'bidaf', cwd], check=True, cwd=cwd)
subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-n', 'bidaf', cwd], check=True, cwd=cwd)
# test AOT on the quantized model
if os.name not in ['nt', 'posix']:
return # don't run the rest of test if AOT is not supported
cache_dir = os.path.join(cwd, 'nuphar_cache')
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
@ -57,33 +59,28 @@ class TestNuphar(unittest.TestCase):
tp = onnx.load_tensor(os.path.join(bidaf_dir, 'test_data_set_0', 'input_{}.pb'.format(i)))
feed[tp.name] = numpy_helper.to_array(tp)
# force codegen_target to be avx
nuphar_settings = 'nuphar_codegen_target:avx'
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model)
assert 'NupharExecutionProvider' in sess.get_providers()
output = sess.run([], feed)
for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]:
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
for isa in ['avx', 'avx2', 'avx512']:
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa)
sess = onnxrt.InferenceSession(model) # JIT cache happens when initializing session
nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir)
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session
assert 'NupharExecutionProvider' in sess.get_providers()
output = sess.run([], feed)
cache_dir_content = os.listdir(cache_dir)
assert len(cache_dir_content) == 1
cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0])
so_name = 'bidaf.so'
if os.name in ['nt', 'posix'] : # Windows or Linux
cache_dir_content = os.listdir(cache_dir)
assert len(cache_dir_content) == 1
cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0])
so_name = os.path.basename(model) + '.so'
subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.create_shared', '--input_dir', cache_versioned_dir, '--output_name', so_name], check=True)
else:
return # don't run the rest of test if AOT is not supported
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on')
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session
assert 'NupharExecutionProvider' in sess.get_providers()
sess.run([], feed)
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on')
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(model)
sess.run([], feed)
# test avx
nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(cache_dir, so_name, 'on', 'avx')
onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings)
sess = onnxrt.InferenceSession(model)
sess.run([], feed)
def test_bert_squad(self):