mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
added cache version for nuphar JIT binaries (#2646)
* added cache version for nuphar JIT binaries Previously, when the user wrongfully loaded a JIT binary generated from a Nuphar version different from the current used one, she would get mysterious runtime failures, because we didn't perform any version check on JIT binaries. This change added cache versions to the Nuphar runtime and JIT binaries. The Nuphar runtime will issue verbose message that informs the user version-mismatch errors. * address CR feedback * include NUPHAR_CACHE_VERSION in python wheel
This commit is contained in:
parent
7c87070b24
commit
f7412899a1
9 changed files with 123 additions and 23 deletions
|
|
@ -251,7 +251,7 @@ endif()
|
|||
|
||||
if (onnxruntime_USE_NUPHAR)
|
||||
file(GLOB onnxruntime_python_nuphar_python_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*"
|
||||
)
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_pybind11_state POST_BUILD
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ static const std::unordered_set<std::string> valid_keys = {
|
|||
kNupharIMatMulForceMkl,
|
||||
kNupharMatmulExec,
|
||||
kNupharCachePath,
|
||||
kNupharCacheVersion,
|
||||
kNupharCacheSoName,
|
||||
kNupharCacheModelChecksum,
|
||||
kNupharCacheForceNoJIT,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition";
|
|||
constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes";
|
||||
constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec";
|
||||
constexpr static const char* kNupharCachePath = "nuphar_cache_path";
|
||||
constexpr static const char* kNupharCacheVersion = "nuphar_cache_version";
|
||||
constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name";
|
||||
constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum";
|
||||
constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit";
|
||||
|
|
@ -48,13 +47,6 @@ constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target";
|
|||
// Option to control nuphar code to run with parallel schedule
|
||||
constexpr static const char* kNupharParallelMinWorkloads = "nuphar_parallel_min_workloads";
|
||||
|
||||
// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
|
||||
// 1. MAJOR version when you make incompatible changes that old cache files no longer work,
|
||||
// 2. MINOR version when you add functionality in a backwards - compatible manner, and
|
||||
// 3. PATCH version when you make backwards - compatible bug fixes.
|
||||
// NOTE this version needs to be updated when generated code may change
|
||||
constexpr static const char* kNupharCacheVersion_Current = "1.0.0";
|
||||
|
||||
constexpr static const char* kNupharCacheSoName_Default = "jit.so";
|
||||
|
||||
void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info);
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION"
|
||||
#include "gsl/gsl"
|
||||
#include <topi/detail/extern.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <experimental/filesystem>
|
||||
#include <atomic>
|
||||
#include <fstream>
|
||||
namespace fs = std::experimental::filesystem;
|
||||
|
||||
|
|
@ -27,13 +29,6 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
|
|||
if (!settings.HasOption(kNupharCachePath))
|
||||
return false;
|
||||
|
||||
std::string version;
|
||||
if (settings.HasOption(kNupharCacheVersion)) {
|
||||
version = settings.GetOptionValue(kNupharCacheVersion);
|
||||
} else {
|
||||
version = kNupharCacheVersion_Current;
|
||||
}
|
||||
|
||||
path = settings.GetOptionValue(kNupharCachePath);
|
||||
if (!create && !fs::is_directory(path))
|
||||
return false;
|
||||
|
|
@ -43,7 +38,7 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
|
|||
throw std::runtime_error("Failed to create directory " + path.string());
|
||||
}
|
||||
|
||||
path.append(version);
|
||||
path.append(__NUPHAR_CACHE_VERSION__);
|
||||
if (!create && !fs::is_directory(path))
|
||||
return false;
|
||||
|
||||
|
|
@ -80,6 +75,63 @@ static void* GetFuncFromLibrary(const std::string& so_path, const std::string& f
|
|||
return func;
|
||||
}
|
||||
|
||||
static void ParseVersion(const char* version, int* major, int* minor, int* patch) {
|
||||
std::stringstream ss(version);
|
||||
std::string val;
|
||||
|
||||
auto ver_num_fn = [](const std::string& val) {
|
||||
ORT_ENFORCE(!val.empty(), "Empty version number");
|
||||
if (val.length() > 1 && val[0] == '0') {
|
||||
ORT_THROW("Invalid version number: ", val);
|
||||
}
|
||||
ORT_ENFORCE(std::all_of(val.begin(), val.end(), [](char c) { return isdigit(c); }),
|
||||
"Invalid version number: ", val);
|
||||
return std::stoi(val);
|
||||
};
|
||||
|
||||
std::getline(ss, val, '.');
|
||||
ORT_ENFORCE(ss.good(), "Invalid version format: ", version);
|
||||
*major = ver_num_fn(val);
|
||||
|
||||
std::getline(ss, val, '.');
|
||||
*minor = ver_num_fn(val);
|
||||
|
||||
std::getline(ss, val);
|
||||
*patch = ver_num_fn(val);
|
||||
}
|
||||
|
||||
static void VerifyCacheVersion(const std::string& so_path) {
|
||||
static std::atomic<bool> cache_version_checked{false};
|
||||
static std::mutex cache_version_mutex;
|
||||
|
||||
// make sure we only check cache version once
|
||||
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
|
||||
std::lock_guard<std::mutex> lock(cache_version_mutex);
|
||||
if (!cache_version_checked.load(std::memory_order::memory_order_acquire)) {
|
||||
cache_version_checked.store(true, std::memory_order::memory_order_release);
|
||||
// ensure we have _ORTInternal_GetCacheVersion_ function
|
||||
void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
|
||||
ORT_ENFORCE(f, "NULL library function pointer!");
|
||||
|
||||
typedef const char* (*GetVersionFunc)();
|
||||
GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
|
||||
const char* cache_version = func();
|
||||
ORT_ENFORCE(cache_version, "Null cache version string!");
|
||||
int cur_major, cur_minor, cur_patch;
|
||||
ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);
|
||||
int cache_major, cache_minor, cache_patch;
|
||||
ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);
|
||||
|
||||
// make version check strict until we have thorough design for compatibility
|
||||
ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor),
|
||||
"Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__,
|
||||
") doesn't match cached dll version (", cache_version, ")");
|
||||
|
||||
cache_version_checked = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bool disable_caching_due_to_checksum_failure = false;
|
||||
|
||||
static bool VerifyTVMModuleChecksum(const std::string& so_path) {
|
||||
|
|
@ -131,6 +183,8 @@ tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name
|
|||
if (!GetCacheSoFilePath(so_path))
|
||||
return nullptr;
|
||||
|
||||
VerifyCacheVersion(so_path);
|
||||
|
||||
if (!VerifyTVMModuleChecksum(so_path))
|
||||
return nullptr;
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
|
||||
// 1. MAJOR version when you make incompatible changes that old cache files no longer work,
|
||||
// 2. MINOR version when you add functionality in a backwards - compatible manner, and
|
||||
// 3. PATCH version when you make backwards - compatible bug fixes.
|
||||
// NOTE this version needs to be updated when generated code may change
|
||||
|
||||
#ifndef __NUPHAR_CACHE_VERSION__
|
||||
#define __NUPHAR_CACHE_VERSION__ "2.3.0"
|
||||
#endif
|
||||
|
|
@ -6,6 +6,7 @@ setlocal EnableDelayedExpansion
|
|||
|
||||
if "%1"=="" goto Usage
|
||||
|
||||
set SCRIPT_DIR=%~dp0
|
||||
set CACHE_DIR=%~f1
|
||||
set MODEL_FILE=%~f2
|
||||
|
||||
|
|
@ -46,6 +47,16 @@ echo __declspec(dllexport) >>%CHECKSUM_CC%
|
|||
echo void _ORTInternal_GetCheckSum(const char*^& cs, size_t^& len) { >> %CHECKSUM_CC%
|
||||
echo cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;} >>%CHECKSUM_CC%
|
||||
|
||||
REM generate cache version
|
||||
set CACHE_VERSION_CC=%CACHE_DIR%\cache_version.cc
|
||||
set VERSION_FILE=%SCRIPT_DIR%NUPHAR_CACHE_VERSION
|
||||
echo Generating %CACHE_VERSION_CC%...
|
||||
echo #include "%VERSION_FILE%" >%CACHE_VERSION_CC%
|
||||
echo extern "C" >>%CACHE_VERSION_CC%
|
||||
echo __declspec(dllexport) >>%CACHE_VERSION_CC%
|
||||
echo const char* _ORTInternal_GetCacheVersion() { >> %CACHE_VERSION_CC%
|
||||
echo return __NUPHAR_CACHE_VERSION__;} >>%CACHE_VERSION_CC%
|
||||
|
||||
:Compile
|
||||
cd /d %CACHE_DIR%
|
||||
for /f %%i in ('dir /b *.cc') do (
|
||||
|
|
@ -61,4 +72,4 @@ exit /b
|
|||
:Usage
|
||||
echo Usage: %0 cache_dir [model_file] [output_dll]
|
||||
echo The generated file would be cache_dir\output_dll
|
||||
exit /b
|
||||
exit /b
|
||||
|
|
|
|||
|
|
@ -38,6 +38,18 @@ def gen_checksum(file_checksum, input_dir):
|
|||
print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc)
|
||||
print('}', file=checksum_cc)
|
||||
|
||||
def gen_cache_version(input_dir):
|
||||
name = 'ORTInternal_cache_version'
|
||||
with open(os.path.join(input_dir, name + '.cc'), 'w') as cache_version_cc:
|
||||
header_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NUPHAR_CACHE_VERSION')
|
||||
print('#include "{}"'.format(header_file), file=cache_version_cc)
|
||||
print('extern "C"', file=cache_version_cc)
|
||||
if is_windows():
|
||||
print('__declspec(dllexport)', file=cache_version_cc)
|
||||
print('const char* _ORTInternal_GetCacheVersion() {', file=cache_version_cc)
|
||||
print(' return __NUPHAR_CACHE_VERSION__;', file=cache_version_cc)
|
||||
print('}', file=cache_version_cc)
|
||||
|
||||
def compile_all_cc(path):
|
||||
for f in os.listdir(path):
|
||||
name, ext = os.path.splitext(f)
|
||||
|
|
@ -65,6 +77,8 @@ if __name__ == '__main__':
|
|||
input_checksum = gen_md5(args.input_model)
|
||||
gen_checksum(input_checksum, args.input_dir)
|
||||
|
||||
gen_cache_version(args.input_dir)
|
||||
|
||||
if is_windows():
|
||||
# create dllmain
|
||||
name = 'ORTInternal_dllmain'
|
||||
|
|
@ -85,4 +99,4 @@ if __name__ == '__main__':
|
|||
|
||||
if not args.keep_input:
|
||||
for f in objs:
|
||||
os.remove(os.path.join(args.input_dir, f))
|
||||
os.remove(os.path.join(args.input_dir, f))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
set -x -e -o pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
|
||||
|
||||
function usage {
|
||||
echo Usage: create_shared.sh -c cache_dir -m input_model_file -o output_so_file
|
||||
echo The generated file would be cache_dir/output_so_file
|
||||
|
|
@ -34,6 +36,8 @@ if ! [ -x "$(command -v g++)" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
declare -a all_cc_files
|
||||
|
||||
cd $CACHE_DIR
|
||||
if [ -x "$MODEL_FILE" ]; then
|
||||
# generate checksum.cc
|
||||
|
|
@ -46,10 +50,25 @@ void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) {
|
|||
cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;
|
||||
}
|
||||
__EOF__
|
||||
g++ -std=c++14 -fPIC -o checksum.o -c checksum.cc
|
||||
rm checksum.cc
|
||||
all_cc_files+=(checksum)
|
||||
fi
|
||||
|
||||
# generate cache_version.cc
|
||||
VERSION_FILE="${SCRIPT_DIR}/NUPHAR_CACHE_VERSION"
|
||||
cat > $CACHE_DIR/cache_version.cc <<__EOF__
|
||||
#include "$VERSION_FILE"
|
||||
extern "C"
|
||||
const char* _ORTInternal_GetCacheVersion() {
|
||||
return __NUPHAR_CACHE_VERSION__;
|
||||
}
|
||||
__EOF__
|
||||
all_cc_files+=(cache_version)
|
||||
|
||||
for cc_file in "${all_cc_files[@]}"; do
|
||||
g++ -std=c++14 -fPIC -o "$cc_file".o -c "$cc_file".cc
|
||||
rm "$cc_file".cc
|
||||
done
|
||||
|
||||
# link
|
||||
if ls *.o 1> /dev/null 2>&1; then
|
||||
OBJS=""
|
||||
|
|
@ -61,4 +80,4 @@ if ls *.o 1> /dev/null 2>&1; then
|
|||
g++ -shared -fPIC -o $CACHE_DIR/$OUTPUT_SO_FILE $OBJS
|
||||
fi
|
||||
rm *.o
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -165,6 +165,8 @@ examples = [path.join('datasets', x) for x in examples_names]
|
|||
|
||||
# Extra files such as EULA and ThirdPartyNotices
|
||||
extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"]
|
||||
if package_name == 'onnxruntime-nuphar':
|
||||
extra.extend([path.join('nuphar', 'NUPHAR_CACHE_VERSION')])
|
||||
|
||||
# Description
|
||||
README = path.join(getcwd(), "docs/python/README.rst")
|
||||
|
|
|
|||
Loading…
Reference in a new issue