added cache version for nuphar JIT binaries (#2646)

* added cache version for nuphar JIT binaries

Previously, when a user mistakenly loaded a JIT binary generated by a
Nuphar version different from the one currently in use, they would
get mysterious runtime failures, because we did not perform any
version check on JIT binaries.

This change adds cache versions to the Nuphar runtime and
JIT binaries. The Nuphar runtime will issue a verbose message that
informs the user of version-mismatch errors.

* address CR feedback

* include NUPHAR_CACHE_VERSION in python wheel
This commit is contained in:
Yang Chen 2019-12-14 22:46:30 -08:00 committed by GitHub
parent 7c87070b24
commit f7412899a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 123 additions and 23 deletions

View file

@ -251,7 +251,7 @@ endif()
if (onnxruntime_USE_NUPHAR)
file(GLOB onnxruntime_python_nuphar_python_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*.*"
"${ONNXRUNTIME_ROOT}/core/providers/nuphar/scripts/*"
)
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD

View file

@ -34,7 +34,6 @@ static const std::unordered_set<std::string> valid_keys = {
kNupharIMatMulForceMkl,
kNupharMatmulExec,
kNupharCachePath,
kNupharCacheVersion,
kNupharCacheSoName,
kNupharCacheModelChecksum,
kNupharCacheForceNoJIT,

View file

@ -14,7 +14,6 @@ constexpr static const char* kNupharDumpPartition = "nuphar_dump_partition";
constexpr static const char* kNupharDumpFusedNodes = "nuphar_dump_fused_nodes";
constexpr static const char* kNupharMatmulExec = "nuphar_matmul_exec";
constexpr static const char* kNupharCachePath = "nuphar_cache_path";
constexpr static const char* kNupharCacheVersion = "nuphar_cache_version";
constexpr static const char* kNupharCacheSoName = "nuphar_cache_so_name";
constexpr static const char* kNupharCacheModelChecksum = "nuphar_cache_model_checksum";
constexpr static const char* kNupharCacheForceNoJIT = "nuphar_cache_force_no_jit";
@ -48,13 +47,6 @@ constexpr static const char* kNupharCodeGenTarget = "nuphar_codegen_target";
// Option to control nuphar code to run with parallel schedule
constexpr static const char* kNupharParallelMinWorkloads = "nuphar_parallel_min_workloads";
// cache version number (MAJOR.MINOR.PATCH) following https://semver.org/
// 1. MAJOR version when you make incompatible changes that old cache files no longer work,
// 2. MINOR version when you add functionality in a backwards - compatible manner, and
// 3. PATCH version when you make backwards - compatible bug fixes.
// NOTE this version needs to be updated when generated code may change
constexpr static const char* kNupharCacheVersion_Current = "1.0.0";
constexpr static const char* kNupharCacheSoName_Default = "jit.so";
void CreateNupharCodeGenSettings(const NupharExecutionProviderInfo& info);

View file

@ -11,10 +11,12 @@
#include "core/common/logging/logging.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
#include "core/providers/nuphar/scripts/NUPHAR_CACHE_VERSION"
#include "gsl/gsl"
#include <topi/detail/extern.h>
#include <tvm/ir_pass.h>
#include <experimental/filesystem>
#include <atomic>
#include <fstream>
namespace fs = std::experimental::filesystem;
@ -27,13 +29,6 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
if (!settings.HasOption(kNupharCachePath))
return false;
std::string version;
if (settings.HasOption(kNupharCacheVersion)) {
version = settings.GetOptionValue(kNupharCacheVersion);
} else {
version = kNupharCacheVersion_Current;
}
path = settings.GetOptionValue(kNupharCachePath);
if (!create && !fs::is_directory(path))
return false;
@ -43,7 +38,7 @@ static bool GetOrCreateTVMModuleCacheDirectory(fs::path& path, bool create) {
throw std::runtime_error("Failed to create directory " + path.string());
}
path.append(version);
path.append(__NUPHAR_CACHE_VERSION__);
if (!create && !fs::is_directory(path))
return false;
@ -80,6 +75,63 @@ static void* GetFuncFromLibrary(const std::string& so_path, const std::string& f
return func;
}
static void ParseVersion(const char* version, int* major, int* minor, int* patch) {
std::stringstream ss(version);
std::string val;
auto ver_num_fn = [](const std::string& val) {
ORT_ENFORCE(!val.empty(), "Empty version number");
if (val.length() > 1 && val[0] == '0') {
ORT_THROW("Invalid version number: ", val);
}
ORT_ENFORCE(std::all_of(val.begin(), val.end(), [](char c) { return isdigit(c); }),
"Invalid version number: ", val);
return std::stoi(val);
};
std::getline(ss, val, '.');
ORT_ENFORCE(ss.good(), "Invalid version format: ", version);
*major = ver_num_fn(val);
std::getline(ss, val, '.');
*minor = ver_num_fn(val);
std::getline(ss, val);
*patch = ver_num_fn(val);
}
// Verifies that the JIT cache library at so_path was generated for a cache
// version compatible with this runtime (__NUPHAR_CACHE_VERSION__).
//
// so_path: path of the cached shared library to inspect.
//
// The library must export `_ORTInternal_GetCacheVersion`, returning a
// "MAJOR.MINOR.PATCH" string. MAJOR and MINOR must equal the runtime's values;
// PATCH is parsed but not compared. Any failure is reported through
// ORT_ENFORCE (or by GetFuncFromLibrary when the symbol is missing).
//
// The check is performed until it succeeds once per process; after that the
// function returns immediately. The "checked" flag is published only after
// verification has passed, so concurrent callers cannot skip a check that is
// still in progress, and a failed check is not remembered as done.
static void VerifyCacheVersion(const std::string& so_path) {
  static std::atomic<bool> cache_version_checked{false};
  static std::mutex cache_version_mutex;

  // Fast path: a previous call already verified the version.
  if (cache_version_checked.load(std::memory_order_acquire))
    return;

  std::lock_guard<std::mutex> lock(cache_version_mutex);
  // Re-test under the lock: another thread may have finished the check
  // while this one was waiting for the mutex.
  if (cache_version_checked.load(std::memory_order_acquire))
    return;

  // ensure we have the _ORTInternal_GetCacheVersion function
  void* f = GetFuncFromLibrary(so_path, "_ORTInternal_GetCacheVersion", /*throw_if_not_found*/ true);
  ORT_ENFORCE(f, "NULL library function pointer!");

  using GetVersionFunc = const char* (*)();
  GetVersionFunc func = reinterpret_cast<GetVersionFunc>(f);
  const char* cache_version = func();
  ORT_ENFORCE(cache_version, "Null cache version string!");

  int cur_major, cur_minor, cur_patch;
  ParseVersion(__NUPHAR_CACHE_VERSION__, &cur_major, &cur_minor, &cur_patch);

  int cache_major, cache_minor, cache_patch;
  ParseVersion(cache_version, &cache_major, &cache_minor, &cache_patch);

  // make version check strict until we have thorough design for compatibility
  ORT_ENFORCE((cur_major == cache_major) && (cur_minor == cache_minor),
              "Current nuphar runtime version (", __NUPHAR_CACHE_VERSION__,
              ") doesn't match cached dll version (", cache_version, ")");

  // Publish success only now that every check above has passed.
  cache_version_checked.store(true, std::memory_order_release);
}
static bool disable_caching_due_to_checksum_failure = false;
static bool VerifyTVMModuleChecksum(const std::string& so_path) {
@ -131,6 +183,8 @@ tvm::runtime::PackedFunc LoadTVMPackedFuncFromCache(const std::string& func_name
if (!GetCacheSoFilePath(so_path))
return nullptr;
VerifyCacheVersion(so_path);
if (!VerifyTVMModuleChecksum(so_path))
return nullptr;

View file

@ -0,0 +1,9 @@
// Nuphar JIT cache version number (MAJOR.MINOR.PATCH), following https://semver.org/
//   1. MAJOR version when you make incompatible changes such that old cache files no longer work,
//   2. MINOR version when you add functionality in a backwards-compatible manner, and
//   3. PATCH version when you make backwards-compatible bug fixes.
// NOTE: this version needs to be updated whenever the generated code may change.
//
// This file is #include'd as C/C++ source, so it must contain only comments and
// preprocessor directives. The #ifndef guard keeps a definition that is already
// in effect (for example from an earlier inclusion) instead of redefining it.
#ifndef __NUPHAR_CACHE_VERSION__
#define __NUPHAR_CACHE_VERSION__ "2.3.0"
#endif

View file

@ -6,6 +6,7 @@ setlocal EnableDelayedExpansion
if "%1"=="" goto Usage
set SCRIPT_DIR=%~dp0
set CACHE_DIR=%~f1
set MODEL_FILE=%~f2
@ -46,6 +47,16 @@ echo __declspec(dllexport) >>%CHECKSUM_CC%
echo void _ORTInternal_GetCheckSum(const char*^& cs, size_t^& len) { >> %CHECKSUM_CC%
echo cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;} >>%CHECKSUM_CC%
REM generate cache version
set CACHE_VERSION_CC=%CACHE_DIR%\cache_version.cc
set VERSION_FILE=%SCRIPT_DIR%NUPHAR_CACHE_VERSION
echo Generating %CACHE_VERSION_CC%...
echo #include "%VERSION_FILE%" >%CACHE_VERSION_CC%
echo extern "C" >>%CACHE_VERSION_CC%
echo __declspec(dllexport) >>%CACHE_VERSION_CC%
echo const char* _ORTInternal_GetCacheVersion() { >> %CACHE_VERSION_CC%
echo return __NUPHAR_CACHE_VERSION__;} >>%CACHE_VERSION_CC%
:Compile
cd /d %CACHE_DIR%
for /f %%i in ('dir /b *.cc') do (
@ -61,4 +72,4 @@ exit /b
:Usage
echo Usage: %0 cache_dir [model_file] [output_dll]
echo The generated file would be cache_dir\output_dll
exit /b
exit /b

View file

@ -38,6 +38,18 @@ def gen_checksum(file_checksum, input_dir):
print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc)
print('}', file=checksum_cc)
def gen_cache_version(input_dir):
    """Write ORTInternal_cache_version.cc into input_dir.

    The generated source includes the NUPHAR_CACHE_VERSION file that sits next
    to this script and defines an extern "C" function
    _ORTInternal_GetCacheVersion() returning __NUPHAR_CACHE_VERSION__.
    On Windows the function is additionally marked __declspec(dllexport).
    """
    source_path = os.path.join(input_dir, 'ORTInternal_cache_version' + '.cc')
    with open(source_path, 'w') as out:
        script_dir = os.path.dirname(os.path.realpath(__file__))
        version_header = os.path.join(script_dir, 'NUPHAR_CACHE_VERSION')

        # Assemble the translation unit line by line, then emit it in order.
        lines = [
            '#include "{}"'.format(version_header),
            'extern "C"',
        ]
        if is_windows():
            lines.append('__declspec(dllexport)')
        lines.extend([
            'const char* _ORTInternal_GetCacheVersion() {',
            '  return __NUPHAR_CACHE_VERSION__;',
            '}',
        ])

        for line in lines:
            print(line, file=out)
def compile_all_cc(path):
for f in os.listdir(path):
name, ext = os.path.splitext(f)
@ -65,6 +77,8 @@ if __name__ == '__main__':
input_checksum = gen_md5(args.input_model)
gen_checksum(input_checksum, args.input_dir)
gen_cache_version(args.input_dir)
if is_windows():
# create dllmain
name = 'ORTInternal_dllmain'
@ -85,4 +99,4 @@ if __name__ == '__main__':
if not args.keep_input:
for f in objs:
os.remove(os.path.join(args.input_dir, f))
os.remove(os.path.join(args.input_dir, f))

View file

@ -4,6 +4,8 @@
set -x -e -o pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
function usage {
echo Usage: create_shared.sh -c cache_dir -m input_model_file -o output_so_file
echo The generated file would be cache_dir/output_so_file
@ -34,6 +36,8 @@ if ! [ -x "$(command -v g++)" ]; then
exit 1
fi
declare -a all_cc_files
cd $CACHE_DIR
if [ -x "$MODEL_FILE" ]; then
# generate checksum.cc
@ -46,10 +50,25 @@ void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) {
cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;
}
__EOF__
g++ -std=c++14 -fPIC -o checksum.o -c checksum.cc
rm checksum.cc
all_cc_files+=(checksum)
fi
# generate cache_version.cc
VERSION_FILE="${SCRIPT_DIR}/NUPHAR_CACHE_VERSION"
cat > $CACHE_DIR/cache_version.cc <<__EOF__
#include "$VERSION_FILE"
extern "C"
const char* _ORTInternal_GetCacheVersion() {
return __NUPHAR_CACHE_VERSION__;
}
__EOF__
all_cc_files+=(cache_version)
for cc_file in "${all_cc_files[@]}"; do
g++ -std=c++14 -fPIC -o "$cc_file".o -c "$cc_file".cc
rm "$cc_file".cc
done
# link
if ls *.o 1> /dev/null 2>&1; then
OBJS=""
@ -61,4 +80,4 @@ if ls *.o 1> /dev/null 2>&1; then
g++ -shared -fPIC -o $CACHE_DIR/$OUTPUT_SO_FILE $OBJS
fi
rm *.o
fi
fi

View file

@ -165,6 +165,8 @@ examples = [path.join('datasets', x) for x in examples_names]
# Extra files such as EULA and ThirdPartyNotices
extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"]
if package_name == 'onnxruntime-nuphar':
extra.extend([path.join('nuphar', 'NUPHAR_CACHE_VERSION')])
# Description
README = path.join(getcwd(), "docs/python/README.rst")