From f7412899a107535d18fc63e464f14bb6987843eb Mon Sep 17 00:00:00 2001 From: Yang Chen <40417152+yangchen-MS@users.noreply.github.com> Date: Sat, 14 Dec 2019 22:46:30 -0800 Subject: [PATCH] added cache version for nuphar JIT binaries (#2646) * added cache version for nuphar JIT binaries Previously, when the user wrongfully loaded a JIT binary generated from a Nuphar version different from the one currently in use, she would get mysterious runtime failures, because we didn't perform any version check on JIT binaries. This change added cache versions to the Nuphar runtime and JIT binaries. The Nuphar runtime will issue a verbose message that informs the user of version-mismatch errors. * address CR feedback * include NUPHAR_CACHE_VERSION in python wheel --- cmake/onnxruntime_python.cmake | 2 +- .../nuphar/common/nuphar_settings.cc | 1 - .../providers/nuphar/common/nuphar_settings.h | 8 --- .../nuphar/common/nuphar_tvm_utils.cc | 70 ++++++++++++++++--- .../nuphar/scripts/NUPHAR_CACHE_VERSION | 9 +++ .../nuphar/scripts/create_shared.cmd | 13 +++- .../providers/nuphar/scripts/create_shared.py | 16 ++++- .../providers/nuphar/scripts/create_shared.sh | 25 ++++++- setup.py | 2 + 9 files changed, 123 insertions(+), 23 deletions(-) create mode 100644 onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 8525e7a1a2..d9e76697b8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -251,7 +251,7 @@ endif() if (onnxruntime_USE_NUPHAR) file(GLOB onnxruntime_python_nuphar_python_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*" + "${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*" ) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc index 271e12f50c..a6aef17e7f 100644 --- 
a/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc +++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc @@ -34,7 +34,6 @@ static const std::unordered_set valid_keys = { kNupharIMatMulForceMkl, kNupharMatmulExec, kNupharCachePath, - kNupharCacheVersion, kNupharCacheSoName, kNupharCacheModelChecksum, kNupharCacheForceNoJIT, diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.h b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h index 5d2c149186..303daa98e3 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_settings.h +++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h @@ -14,7 +14,6 @@ constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition"; constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes"; constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec"; constexpr static const char* kNupharCachePath = "nuphar_cache_path"; -constexpr static const char* kNupharCacheVersion = "nuphar_cache_version"; constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name"; constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum"; constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit"; @@ -48,13 +47,6 @@ constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target"; // Option to control nuphar code to run with parallel schedule constexpr static const char* kNupharParallelMinWorkloads = "nuphar_parallel_min_workloads"; -// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/ -// 1. MAJOR version when you make incompatible changes that old cache files no longer work, -// 2. MINOR version when you add functionality in a backwards - compatible manner, and -// 3. PATCH version when you make backwards - compatible bug fixes. 
-// NOTE this version needs to be updated when generated code may change -constexpr static const char* kNupharCacheVersion_Current = "1.0.0"; - constexpr static const char* kNupharCacheSoName_Default = "jit.so"; void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc index fac9f2c552..db5b9a3fc1 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc @@ -11,10 +11,12 @@ #include "core/common/logging/logging.h" #include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION" #include "gsl/gsl" #include #include #include +#include #include namespace fs = std::experimental::filesystem; @@ -27,13 +29,6 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) { if (!settings.HasOption(kNupharCachePath)) return false; - std::string version; - if (settings.HasOption(kNupharCacheVersion)) { - version = settings.GetOptionValue(kNupharCacheVersion); - } else { - version = kNupharCacheVersion_Current; - } - path = settings.GetOptionValue(kNupharCachePath); if (!create && !fs::is_directory(path)) return false; @@ -43,7 +38,7 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) { throw std::runtime_error("Failed to create directory " + path.string()); } - path.append(version); + path.append(__NUPHAR_CACHE_VERSION__); if (!create && !fs::is_directory(path)) return false; @@ -80,6 +75,63 @@ static void* GetFuncFromLibrary(const std::string& so_path, const std::string& f return func; } +static void ParseVersion(const char* version, int* major, int* minor, int* patch) { + std::stringstream ss(version); + std::string val; + + auto ver_num_fn = [](const std::string& val) { + ORT_ENFORCE(!val.empty(), "Empty version number"); + if 
(val.length() > 1 && val[0] == '0') { + ORT_THROW("Invalid version number: ", val); + } + ORT_ENFORCE(std::all_of(val.begin(), val.end(), [](char c) { return isdigit(c); }), + "Invalid version number: ", val); + return std::stoi(val); + }; + + std::getline(ss, val, '.'); + ORT_ENFORCE(ss.good(), "Invalid version format: ", version); + *major = ver_num_fn(val); + + std::getline(ss, val, '.'); + *minor = ver_num_fn(val); + + std::getline(ss, val); + *patch = ver_num_fn(val); +} + +static void VerifyCacheVersion(const std::string& so_path) { + static std::atomic cache_version_checked{false}; + static std::mutex cache_version_mutex; + + // make sure we only check cache version once + if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { + std::lock_guard lock(cache_version_mutex); + if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { + cache_version_checked.store(true, std::memory_order::memory_order_release); + // ensure we have _ORTInternal_GetCacheVersion_ function + void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true); + ORT_ENFORCE(f, "NULL library function pointer!"); + + typedef const char* (*GetVersionFunc)(); + GetVersionFunc func = reinterpret_cast(f); + const char* cache_version = func(); + ORT_ENFORCE(cache_version, "Null cache version string!"); + int cur_major, cur_minor, cur_patch; + ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch); + int cache_major, cache_minor, cache_patch; + ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch); + + // make version check strict until we have thorough design for compatibility + ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor), + "Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__, + ") doesn't match cached dll version (", cache_version, ")"); + + cache_version_checked = true; + } + } +} + static bool disable_caching_due_to_checksum_failure = false; static 
bool VerifyTVMModuleChecksum(const std::string& so_path) { @@ -131,6 +183,8 @@ tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name if (!GetCacheSoFilePath(so_path)) return nullptr; + VerifyCacheVersion(so_path); + if (!VerifyTVMModuleChecksum(so_path)) return nullptr; diff --git a/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION b/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION new file mode 100644 index 0000000000..6566ef23ba --- /dev/null +++ b/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION @@ -0,0 +1,9 @@ +// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/ +// 1. MAJOR version when you make incompatible changes that old cache files no longer work, +// 2. MINOR version when you add functionality in a backwards - compatible manner, and +// 3. PATCH version when you make backwards - compatible bug fixes. +// NOTE this version needs to be updated when generated code may change + +#ifndef __NUPHAR_CACHE_VERSION__ +#define __NUPHAR_CACHE_VERSION__ "2.3.0" +#endif diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd b/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd index dac840a226..f446242d23 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd @@ -6,6 +6,7 @@ setlocal EnableDelayedExpansion if "%1"=="" goto Usage +set SCRIPT_DIR=%~dp0 set CACHE_DIR=%~f1 set MODEL_FILE=%~f2 @@ -46,6 +47,16 @@ echo __declspec(dllexport) >>%CHECKSUM_CC% echo void _ORTInternal_GetCheckSum(const char*^& cs, size_t^& len) { >> %CHECKSUM_CC% echo cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;} >>%CHECKSUM_CC% +REM generate cache version +set CACHE_VERSION_CC=%CACHE_DIR%\cache_version.cc +set VERSION_FILE=%SCRIPT_DIR%NUPHAR_CACHE_VERSION +echo Generating %CACHE_VERSION_CC%... 
+echo #include "%VERSION_FILE%" >%CACHE_VERSION_CC% +echo extern "C" >>%CACHE_VERSION_CC% +echo __declspec(dllexport) >>%CACHE_VERSION_CC% +echo const char* _ORTInternal_GetCacheVersion() { >> %CACHE_VERSION_CC% +echo return __NUPHAR_CACHE_VERSION__;} >>%CACHE_VERSION_CC% + :Compile cd /d %CACHE_DIR% for /f %%i in ('dir /b *.cc') do ( @@ -61,4 +72,4 @@ exit /b :Usage echo Usage: %0 cache_dir [model_file] [output_dll] echo The generated file would be cache_dir\output_dll -exit /b \ No newline at end of file +exit /b diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.py b/onnxruntime/core/providers/nuphar/scripts/create_shared.py index 8b384dafad..a68e1fdc75 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.py +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.py @@ -38,6 +38,18 @@ def gen_checksum(file_checksum, input_dir): print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc) print('}', file=checksum_cc) +def gen_cache_version(input_dir): + name = 'ORTInternal_cache_version' + with open(os.path.join(input_dir, name + '.cc'), 'w') as cache_version_cc: + header_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NUPHAR_CACHE_VERSION') + print('#include "{}"'.format(header_file), file=cache_version_cc) + print('extern "C"', file=cache_version_cc) + if is_windows(): + print('__declspec(dllexport)', file=cache_version_cc) + print('const char* _ORTInternal_GetCacheVersion() {', file=cache_version_cc) + print(' return __NUPHAR_CACHE_VERSION__;', file=cache_version_cc) + print('}', file=cache_version_cc) + def compile_all_cc(path): for f in os.listdir(path): name, ext = os.path.splitext(f) @@ -65,6 +77,8 @@ if __name__ == '__main__': input_checksum = gen_md5(args.input_model) gen_checksum(input_checksum, args.input_dir) + gen_cache_version(args.input_dir) + if is_windows(): # create dllmain name = 'ORTInternal_dllmain' @@ -85,4 +99,4 @@ if __name__ == 
'__main__': if not args.keep_input: for f in objs: - os.remove(os.path.join(args.input_dir, f)) \ No newline at end of file + os.remove(os.path.join(args.input_dir, f)) diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.sh b/onnxruntime/core/providers/nuphar/scripts/create_shared.sh index 05f0cff917..4428180636 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.sh +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.sh @@ -4,6 +4,8 @@ set -x -e -o pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + function usage { echo Usage: create_shared.sh -c cache_dir -m input_model_file -o output_so_file echo The generated file would be cache_dir/output_so_file @@ -34,6 +36,8 @@ if ! [ -x "$(command -v g++)" ]; then exit 1 fi +declare -a all_cc_files + cd $CACHE_DIR if [ -x "$MODEL_FILE" ]; then # generate checksum.cc @@ -46,10 +50,25 @@ void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) { cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1; } __EOF__ - g++ -std=c++14 -fPIC -o checksum.o -c checksum.cc - rm checksum.cc + all_cc_files+=(checksum) fi +# generate cache_version.cc +VERSION_FILE="${SCRIPT_DIR}/NUPHAR_CACHE_VERSION" +cat > $CACHE_DIR/cache_version.cc <<__EOF__ +#include "$VERSION_FILE" +extern "C" +const char* _ORTInternal_GetCacheVersion() { + return __NUPHAR_CACHE_VERSION__; +} +__EOF__ +all_cc_files+=(cache_version) + +for cc_file in "${all_cc_files[@]}"; do + g++ -std=c++14 -fPIC -o "$cc_file".o -c "$cc_file".cc + rm "$cc_file".cc +done + # link if ls *.o 1> /dev/null 2>&1; then OBJS="" @@ -61,4 +80,4 @@ if ls *.o 1> /dev/null 2>&1; then g++ -shared -fPIC -o $CACHE_DIR/$OUTPUT_SO_FILE $OBJS fi rm *.o -fi \ No newline at end of file +fi diff --git a/setup.py b/setup.py index 70659c815b..b53a6519ac 100644 --- a/setup.py +++ b/setup.py @@ -165,6 +165,8 @@ examples = [path.join('datasets', x) for x in examples_names] # Extra files such as EULA and 
ThirdPartyNotices extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"] +if package_name == 'onnxruntime-nuphar': + extra.extend([path.join('nuphar', 'NUPHAR_CACHE_VERSION')]) # Description README = path.join(getcwd(), "docs/python/README.rst")