diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 8525e7a1a2..d9e76697b8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -251,7 +251,7 @@ endif() if (onnxruntime_USE_NUPHAR) file(GLOB onnxruntime_python_nuphar_python_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*" + "${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*" ) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc index 271e12f50c..a6aef17e7f 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc +++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.cc @@ -34,7 +34,6 @@ static const std::unordered_set valid_keys = { kNupharIMatMulForceMkl, kNupharMatmulExec, kNupharCachePath, - kNupharCacheVersion, kNupharCacheSoName, kNupharCacheModelChecksum, kNupharCacheForceNoJIT, diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_settings.h b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h index 5d2c149186..303daa98e3 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_settings.h +++ b/onnxruntime/core/providers/nuphar/common/nuphar_settings.h @@ -14,7 +14,6 @@ constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition"; constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes"; constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec"; constexpr static const char* kNupharCachePath = "nuphar_cache_path"; -constexpr static const char* kNupharCacheVersion = "nuphar_cache_version"; constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name"; constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum"; constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit"; @@ -48,13 +47,6 @@ constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target"; // Option to control nuphar code to run with parallel schedule constexpr static const char* kNupharParallelMinWorkloads = "nuphar_parallel_min_workloads"; -// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/ -// 1. MAJOR version when you make incompatible changes that old cache files no longer work, -// 2. MINOR version when you add functionality in a backwards - compatible manner, and -// 3. PATCH version when you make backwards - compatible bug fixes. -// NOTE this version needs to be updated when generated code may change -constexpr static const char* kNupharCacheVersion_Current = "1.0.0"; - constexpr static const char* kNupharCacheSoName_Default = "jit.so"; void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc index fac9f2c552..db5b9a3fc1 100644 --- a/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc +++ b/onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc @@ -11,10 +11,12 @@ #include "core/common/logging/logging.h" #include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION" #include "gsl/gsl" #include #include #include +#include #include namespace fs = std::experimental::filesystem; @@ -27,13 +29,6 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) { if (!settings.HasOption(kNupharCachePath)) return false; - std::string version; - if (settings.HasOption(kNupharCacheVersion)) { - version = settings.GetOptionValue(kNupharCacheVersion); - } else { - version = kNupharCacheVersion_Current; - } - path = settings.GetOptionValue(kNupharCachePath); if (!create && !fs::is_directory(path)) return false; @@ -43,7 +38,7 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) { throw std::runtime_error("Failed to create directory " + path.string()); } - path.append(version); + path.append(__NUPHAR_CACHE_VERSION__); if (!create && !fs::is_directory(path)) return false; @@ -80,6 +75,63 @@ static void* GetFuncFromLibrary(const std::string& so_path, const std::string& f return func; } +static void ParseVersion(const char* version, int* major, int* minor, int* patch) { + std::stringstream ss(version); + std::string val; + + auto ver_num_fn = [](const std::string& val) { + ORT_ENFORCE(!val.empty(), "Empty version number"); + if (val.length() > 1 && val[0] == '0') { + ORT_THROW("Invalid version number: ", val); + } + ORT_ENFORCE(std::all_of(val.begin(), val.end(), [](char c) { return isdigit(c); }), + "Invalid version number: ", val); + return std::stoi(val); + }; + + std::getline(ss, val, '.'); + ORT_ENFORCE(ss.good(), "Invalid version format: ", version); + *major = ver_num_fn(val); + + std::getline(ss, val, '.'); + *minor = ver_num_fn(val); + + std::getline(ss, val); + *patch = ver_num_fn(val); +} + +static void VerifyCacheVersion(const std::string& so_path) { + static std::atomic cache_version_checked{false}; + static std::mutex cache_version_mutex; + + // make sure we only check cache version once + if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { + std::lock_guard lock(cache_version_mutex); + if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) { + cache_version_checked.store(true, std::memory_order::memory_order_release); + // ensure we have _ORTInternal_GetCacheVersion_ function + void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true); + ORT_ENFORCE(f, "NULL library function pointer!"); + + typedef const char* (*GetVersionFunc)(); + GetVersionFunc func = reinterpret_cast(f); + const char* cache_version = func(); + ORT_ENFORCE(cache_version, "Null cache version string!"); + int cur_major, cur_minor, cur_patch; + ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch); + int cache_major, cache_minor, cache_patch; + ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch); + + // make version check strict until we have thorough design for compatibility + ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor), + "Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__, + ") doesn't match cached dll version (", cache_version, ")"); + + cache_version_checked = true; + } + } +} + static bool disable_caching_due_to_checksum_failure = false; static bool VerifyTVMModuleChecksum(const std::string& so_path) { @@ -131,6 +183,8 @@ tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name if (!GetCacheSoFilePath(so_path)) return nullptr; + VerifyCacheVersion(so_path); + if (!VerifyTVMModuleChecksum(so_path)) return nullptr; diff --git a/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION b/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION new file mode 100644 index 0000000000..6566ef23ba --- /dev/null +++ b/onnxruntime/core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION @@ -0,0 +1,9 @@ +// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/ +// 1. MAJOR version when you make incompatible changes that old cache files no longer work, +// 2. MINOR version when you add functionality in a backwards - compatible manner, and +// 3. PATCH version when you make backwards - compatible bug fixes. +// NOTE this version needs to be updated when generated code may change + +#ifndef __NUPHAR_CACHE_VERSION__ +#define __NUPHAR_CACHE_VERSION__ "2.3.0" +#endif diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd b/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd index dac840a226..f446242d23 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.cmd @@ -6,6 +6,7 @@ setlocal EnableDelayedExpansion if "%1"=="" goto Usage +set SCRIPT_DIR=%~dp0 set CACHE_DIR=%~f1 set MODEL_FILE=%~f2 @@ -46,6 +47,16 @@ echo __declspec(dllexport) >>%CHECKSUM_CC% echo void _ORTInternal_GetCheckSum(const char*^& cs, size_t^& len) { >> %CHECKSUM_CC% echo cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;} >>%CHECKSUM_CC% +REM generate cache version +set CACHE_VERSION_CC=%CACHE_DIR%\cache_version.cc +set VERSION_FILE=%SCRIPT_DIR%NUPHAR_CACHE_VERSION +echo Generating %CACHE_VERSION_CC%... +echo #include "%VERSION_FILE%" >%CACHE_VERSION_CC% +echo extern "C" >>%CACHE_VERSION_CC% +echo __declspec(dllexport) >>%CACHE_VERSION_CC% +echo const char* _ORTInternal_GetCacheVersion() { >> %CACHE_VERSION_CC% +echo return __NUPHAR_CACHE_VERSION__;} >>%CACHE_VERSION_CC% + :Compile cd /d %CACHE_DIR% for /f %%i in ('dir /b *.cc') do ( @@ -61,4 +72,4 @@ exit /b :Usage echo Usage: %0 cache_dir [model_file] [output_dll] echo The generated file would be cache_dir\output_dll -exit /b \ No newline at end of file +exit /b diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.py b/onnxruntime/core/providers/nuphar/scripts/create_shared.py index 8b384dafad..a68e1fdc75 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.py +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.py @@ -38,6 +38,18 @@ def gen_checksum(file_checksum, input_dir): print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc) print('}', file=checksum_cc) +def gen_cache_version(input_dir): + name = 'ORTInternal_cache_version' + with open(os.path.join(input_dir, name + '.cc'), 'w') as cache_version_cc: + header_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NUPHAR_CACHE_VERSION') + print('#include "{}"'.format(header_file), file=cache_version_cc) + print('extern "C"', file=cache_version_cc) + if is_windows(): + print('__declspec(dllexport)', file=cache_version_cc) + print('const char* _ORTInternal_GetCacheVersion() {', file=cache_version_cc) + print(' return __NUPHAR_CACHE_VERSION__;', file=cache_version_cc) + print('}', file=cache_version_cc) + def compile_all_cc(path): for f in os.listdir(path): name, ext = os.path.splitext(f) @@ -65,6 +77,8 @@ if __name__ == '__main__': input_checksum = gen_md5(args.input_model) gen_checksum(input_checksum, args.input_dir) + gen_cache_version(args.input_dir) + if is_windows(): # create dllmain name = 'ORTInternal_dllmain' @@ -85,4 +99,4 @@ if __name__ == '__main__': if not args.keep_input: for f in objs: - os.remove(os.path.join(args.input_dir, f)) \ No newline at end of file + os.remove(os.path.join(args.input_dir, f)) diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.sh b/onnxruntime/core/providers/nuphar/scripts/create_shared.sh index 05f0cff917..4428180636 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.sh +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.sh @@ -4,6 +4,8 @@ set -x -e -o pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + function usage { echo Usage: create_shared.sh -c cache_dir -m input_model_file -o output_so_file echo The generated file would be cache_dir/output_so_file @@ -34,6 +36,8 @@ if ! [ -x "$(command -v g++)" ]; then exit 1 fi +declare -a all_cc_files + cd $CACHE_DIR if [ -x "$MODEL_FILE" ]; then # generate checksum.cc @@ -46,10 +50,25 @@ void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) { cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1; } __EOF__ - g++ -std=c++14 -fPIC -o checksum.o -c checksum.cc - rm checksum.cc + all_cc_files+=(checksum) fi +# generate cache_version.cc +VERSION_FILE="${SCRIPT_DIR}/NUPHAR_CACHE_VERSION" +cat > $CACHE_DIR/cache_version.cc <<__EOF__ +#include "$VERSION_FILE" +extern "C" +const char* _ORTInternal_GetCacheVersion() { + return __NUPHAR_CACHE_VERSION__; +} +__EOF__ +all_cc_files+=(cache_version) + +for cc_file in "${all_cc_files[@]}"; do + g++ -std=c++14 -fPIC -o "$cc_file".o -c "$cc_file".cc + rm "$cc_file".cc +done + # link if ls *.o 1> /dev/null 2>&1; then OBJS="" @@ -61,4 +80,4 @@ if ls *.o 1> /dev/null 2>&1; then g++ -shared -fPIC -o $CACHE_DIR/$OUTPUT_SO_FILE $OBJS fi rm *.o -fi \ No newline at end of file +fi diff --git a/setup.py b/setup.py index 70659c815b..b53a6519ac 100644 --- a/setup.py +++ b/setup.py @@ -165,6 +165,8 @@ examples = [path.join('datasets', x) for x in examples_names] # Extra files such as EULA and ThirdPartyNotices extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"] +if package_name == 'onnxruntime-nuphar': + extra.extend([path.join('nuphar', 'NUPHAR_CACHE_VERSION')]) # Description README = path.join(getcwd(), "docs/python/README.rst")