From 4b900dc585045601f04326bb0ea5e4b565cd6241 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Fri, 20 Dec 2019 17:52:41 -0800 Subject: [PATCH] Simplify cache implementation and avoid static variables that may carry over between models --- onnxruntime/core/codegen/mti/mti_tvm_utils.cc | 5 +- onnxruntime/core/codegen/mti/mti_tvm_utils.h | 2 +- .../passes/weight_layout/weight_layout.cc | 4 +- .../nuphar/common/nuphar_tvm_utils.cc | 149 ++++++++---------- .../nuphar/common/nuphar_tvm_utils.h | 11 +- .../nuphar/compiler/nuphar_codegen_ctx.cc | 11 +- .../nuphar/compiler/nuphar_compiler.cc | 9 +- .../python/onnxruntime_test_python_nuphar.py | 49 +++--- 8 files changed, 117 insertions(+), 123 deletions(-) diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc index 562834bf95..0afc25d627 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc @@ -113,18 +113,17 @@ tvm::Tensor Promote(const tvm::Expr& expr, const tvm::Array& shape, c name); } -void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module) { +void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module) { const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); if (!settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpModule)) return; - static int dump_module_cnt = 0; // ISSUE: note that all option values are converted to lower case. It doesn't cause // any issue currently, because all supported formats (i.e. file exts) are of lower case. // Just keep in mind that we might have issue if somehow we started to support dump // formats with upper case, although it's quite unlikely. std::string format = settings.GetOptionValue(codegen::CodeGenSettings::kCodeGenDumpModule); - std::string module_filename = filename_prefix + "_" + std::to_string(dump_module_cnt++) + "." + format; + std::string module_filename = filename + "." + format; module->SaveToFile(module_filename, format); } diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h index 709269aacf..487c6415a7 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.h @@ -51,7 +51,7 @@ tvm::Tensor Promote(const tvm::Expr& expr, tvm::Tensor MakeZeroTensor(const tvm::Array& shape, HalideIR::Type type, const std::string& name); -void DumpTVMModuleToFile(const std::string& filename_prefix, tvm::runtime::Module& module); +void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module); bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_rank, tvm::Expr& dim); diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc index 7cbb5692f7..ab3e647fd2 100644 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc +++ b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc @@ -31,8 +31,8 @@ const std::string WeightLayout::GetKey( ONNX_NAMESPACE::TensorProto_DataType proto_type, int input_dim, float pad_zero) { - std::ostringstream key(name); - key << "_type_" << static_cast(proto_type); + std::ostringstream key; + key << name << "_type_" << static_cast(proto_type); key << "_dim_" << input_dim; key << "_pad_zero_" << pad_zero; return NormalizeCppName(key.str()); diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc index f0fcf46558..7e27842f1e 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc @@ -15,7 +15,7 @@ #include "gsl/gsl" #include #include -#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019 +#define _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING // required by VS 2019 #include #undef _SILENCE_EXPERIMENTAL_FILESYSTEM_DEPRECATION_WARNING #include @@ -102,117 +102,102 @@ static void ParseVersion(const char* version, int* major, int* minor, int* patch *patch = ver_num_fn(val); } -static void VerifyCacheVersion(const std::string& so_path) { - static std::atomic cache_version_checked{false}; - static std::mutex cache_version_mutex; +static bool VerifyCacheVersion(const std::string& so_path) { + // ensure we have _ORTInternal_GetCacheVersion_ function + void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true); + if (f == nullptr) + return false; - // make sure we only check cache version once - if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { - std::lock_guard lock(cache_version_mutex); - if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { - cache_version_checked.store(true, std::memory_order::memory_order_release); - // ensure we have _ORTInternal_GetCacheVersion_ function - void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true); - ORT_ENFORCE(f, "NULL library function pointer!"); + typedef const char* (*GetVersionFunc)(); + GetVersionFunc func = reinterpret_cast(f); + const char* cache_version = func(); + ORT_ENFORCE(cache_version, "Null cache version string!"); + int cur_major, cur_minor, cur_patch; + ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch); + int cache_major, cache_minor, cache_patch; + ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch); - typedef const char* (*GetVersionFunc)(); - GetVersionFunc func = reinterpret_cast(f); - const char* cache_version = func(); - ORT_ENFORCE(cache_version, "Null cache version string!"); - int cur_major, cur_minor, cur_patch; - ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch); - int cache_major, cache_minor, cache_patch; - ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch); - - // make version check strict until we have thorough design for compatibility - ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor), - "Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__, - ") doesn't match cached dll version (", cache_version, ")"); - - cache_version_checked = true; - } + // make version check strict until we have thorough design for compatibility + bool version_match = (cur_major == cache_major) && (cur_minor == cache_minor); + if (!version_match) { + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Current nuphar runtime version (" << __NUPHAR_CACHE_VERSION__ << ") doesn't match cached dll version (" << cache_version << ")"; } + return version_match; } -static bool disable_caching_due_to_checksum_failure = false; - static bool VerifyTVMModuleChecksum(const std::string& so_path) { - static std::string last_so_path; - static bool last_checksum_validated = false; - static std::mutex checksum_mutex; - if (last_so_path != so_path) { - std::lock_guard lock(checksum_mutex); - if (last_so_path != so_path) { - disable_caching_due_to_checksum_failure = false; // reset disabled caching for a new file - last_so_path = so_path; - void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false); - if (f) { - typedef void (*GetChecksumFunc)(const char*&, size_t&); - GetChecksumFunc func = reinterpret_cast(f); - const char* model_checksum; - size_t model_checksum_len; - func(model_checksum, - model_checksum_len); + void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCheckSum", /*throw_if_not_found*/ false); + if (f) { + typedef void (*GetChecksumFunc)(const char*&, size_t&); + GetChecksumFunc func = reinterpret_cast(f); + const char* model_checksum; + size_t model_checksum_len; + func(model_checksum, + model_checksum_len); - codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance(); - // When checksum is expected by dll/so, user must set environment variable - // NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model. - // User may choose to run with base model or simplified mode and any match - // would be regarded as validated. - // Note that checksum validation here is not designed as a security measurement, - // so checksum compute is not done inside ORT. - last_checksum_validated = - setting.OptionMatches( - kNupharCacheModelChecksum, - std::string(model_checksum, model_checksum_len)); - - if (!last_checksum_validated) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT..."; - disable_caching_due_to_checksum_failure = true; - } - } else { - // do not validate checksum if dll didn't require it (usually during debugging) - // TODO: force checksum validation in final release - last_checksum_validated = true; + codegen::CodeGenSettings& setting = codegen::CodeGenSettings::Instance(); + // When checksum is expected by dll/so, user must set environment variable + // NUPHAR_CACHE_MODEL_CHECKSUM from md5 digest of running model. + // User may choose to run with base model or simplified mode and any match + // would be regarded as validated. + // Note that checksum validation here is not designed as a security measurement, + // so checksum compute is not done inside ORT. + if (setting.OptionMatches( + kNupharCacheModelChecksum, + std::string(model_checksum, model_checksum_len))) { + return true; + } else { + static std::mutex warn_mutex; + std::lock_guard warn_lock(warn_mutex); + static std::string last_warned_so_path; + if (last_warned_so_path != so_path) { + // warning only once for each so_path + LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cache checksum validation failed, using JIT..."; } + return false; } + } else { + // do not validate checksum if dll didn't require it (usually during debugging) + // TODO: force checksum validation in final release + return true; } - return last_checksum_validated; } -tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name) { +CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func) { std::string so_path; - if (!GetCacheSoFilePath(so_path)) - return nullptr; + if (!GetCacheSoFilePath(so_path)) { + if (codegen::CodeGenSettings::Instance().HasOption(kNupharCachePath)) { + return CacheStatus::Missing; + } else { + return CacheStatus::NotInUse; + } + } - VerifyCacheVersion(so_path); + if (!VerifyCacheVersion(so_path)) { + return CacheStatus::Mismatch; + } if (!VerifyTVMModuleChecksum(so_path)) - return nullptr; + return CacheStatus::Mismatch; tvm::runtime::Module module = tvm::runtime::Module::LoadFromFile(so_path); - tvm::runtime::PackedFunc func = module.GetFunction(func_name); + func = module.GetFunction(func_name); if (func == nullptr) { LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Cannot find " << func_name << " in cache, using JIT..."; + return CacheStatus::Missing; } - return func; + return CacheStatus::Found; } void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module) { fs::path path; - if (disable_caching_due_to_checksum_failure) - return; - static std::mutex save_cache_mutex; - static std::unordered_set existing_files; std::lock_guard lock(save_cache_mutex); - if (existing_files.count(filename) == 0 && - GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) { - existing_files.insert(filename); + if (GetOrCreateTVMModuleCacheDirectory(path, /*create*/ true)) { path.append(filename + ".o"); if (fs::exists(path)) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving..."; + //LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "Object file " << path << " already exists, skip saving..."; return; } module->SaveToFile(path.string(), "o"); diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h index fe9162911c..a9f75bdf17 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h @@ -19,8 +19,15 @@ struct NupharSubgraphUnit; //forward // Helper functions to create or load from offline cached dll // note after saving to obj file, we need to use tvm Python to create dll // using script at onnxruntime/core/codegen/mti/scripts/create_shared.py -tvm::runtime::PackedFunc -LoadTVMPackedFuncFromCache(const std::string& func_name); + +enum class CacheStatus { + NotInUse, + Mismatch, + Missing, + Found, +}; + +CacheStatus LoadTVMPackedFuncFromCache(const std::string& func_name, tvm::runtime::PackedFunc& func); void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& module); std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target, int64_t parallel_min_workloads); diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc index 08031691c5..f96913f2f9 100644 --- a/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.cc @@ -56,13 +56,16 @@ static tvm::runtime::PackedFunc LowerLayoutFunc(const tvm_codegen::WeightLayout* std::string func_name = layout->Name() + "_marshall"; - tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name); - - if (cached_func == nullptr) { + tvm::runtime::PackedFunc cached_func; + auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func); + if (cache_status != nuphar::CacheStatus::Found) { + ORT_ENFORCE(cached_func == nullptr); auto lowered = tvm::lower(S, {inputs[0], outputs[0]}, func_name, {}, config); auto module = tvm::build(lowered, tvm::target::llvm(), tvm::Target(), config); tvm_codegen::DumpTVMModuleToFile(func_name, module); - nuphar::SaveTVMModuleToCache(func_name, module); + if (cache_status == nuphar::CacheStatus::Missing) { + nuphar::SaveTVMModuleToCache(func_name, module); + } cached_func = module.GetFunction(func_name); } return cached_func; diff --git a/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc index 9b41ac88cf..4b648e0e38 100644 --- a/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc +++ b/onnxruntime/core/providers/nuphar/compiler/nuphar_compiler.cc @@ -151,8 +151,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc( // JIT-caching and AOT are mutual exclusive. // Change it by not always saving a compiled func unless it is in JIT-Caching model. // In AOT, there should be another member func explicitly loading - tvm::runtime::PackedFunc cached_func = nuphar::LoadTVMPackedFuncFromCache(func_name); - if (cached_func == nullptr) { + tvm::runtime::PackedFunc cached_func; + auto cache_status = nuphar::LoadTVMPackedFuncFromCache(func_name, cached_func); + if (cache_status != nuphar::CacheStatus::Found) { codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); if (settings.HasOption(kNupharCacheForceNoJIT)) { @@ -180,7 +181,9 @@ tvm::runtime::PackedFunc NupharCompiler::GetLoweredPackedFunc( tvm::runtime::Module module = tvm::build(lowered, tvm_target, tvm_host_target, config); tvm_codegen::DumpTVMModuleToFile(func_name, module); - nuphar::SaveTVMModuleToCache(func_name, module); + if (cache_status == nuphar::CacheStatus::Missing) { + nuphar::SaveTVMModuleToCache(func_name, module); + } cached_func = module.GetFunction(func_name); } diff --git a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py index b496a2b93e..f9387ffb73 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py +++ b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py @@ -41,11 +41,13 @@ class TestNuphar(unittest.TestCase): # run onnx_test_runner to verify results # use -M to disable memory pattern - # use -j 1 -c 1 to run one model/session at a time when running multiple models onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') - subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-c', '1', '-j', '1', '-n', 'bidaf', cwd], check=True, cwd=cwd) + subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-n', 'bidaf', cwd], check=True, cwd=cwd) # test AOT on the quantized model + if os.name not in ['nt', 'posix']: + return # don't run the rest of test if AOT is not supported + cache_dir = os.path.join(cwd, 'nuphar_cache') if os.path.exists(cache_dir): shutil.rmtree(cache_dir) @@ -57,33 +59,28 @@ class TestNuphar(unittest.TestCase): tp = onnx.load_tensor(os.path.join(bidaf_dir, 'test_data_set_0', 'input_{}.pb'.format(i))) feed[tp.name] = numpy_helper.to_array(tp) - # force codegen_target to be avx - nuphar_settings = 'nuphar_codegen_target:avx' - onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings) - sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) - assert 'NupharExecutionProvider' in sess.get_providers() - output = sess.run([], feed) + for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]: + nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir) + for isa in ['avx', 'avx2', 'avx512']: + onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings + ', nuphar_codegen_target:' + isa) + sess = onnxrt.InferenceSession(model) # JIT cache happens when initializing session - nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir) - onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings) - sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session - assert 'NupharExecutionProvider' in sess.get_providers() - output = sess.run([], feed) - - cache_dir_content = os.listdir(cache_dir) - assert len(cache_dir_content) == 1 - cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0]) - so_name = 'bidaf.so' - if os.name in ['nt', 'posix'] : # Windows or Linux + cache_dir_content = os.listdir(cache_dir) + assert len(cache_dir_content) == 1 + cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0]) + so_name = os.path.basename(model) + '.so' subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.create_shared', '--input_dir', cache_versioned_dir, '--output_name', so_name], check=True) - else: - return # don't run the rest of test if AOT is not supported - nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on') - onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings) - sess = onnxrt.InferenceSession(bidaf_int8_scan_only_model) # JIT cache happens when initializing session - assert 'NupharExecutionProvider' in sess.get_providers() - sess.run([], feed) + nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format(cache_dir, so_name, 'on') + onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings) + sess = onnxrt.InferenceSession(model) + sess.run([], feed) + + # test avx + nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format(cache_dir, so_name, 'on', 'avx') + onnxrt.capi._pybind_state.set_nuphar_settings(nuphar_settings) + sess = onnxrt.InferenceSession(model) + sess.run([], feed) def test_bert_squad(self):