Offline tuning (#14558)

Add the ability to get and set tuning results of an inference session.
Also add tool to manipulate onnx file to embed the results into the
model file and automatically load it on session initialization.
This commit is contained in:
cloudhan 2023-02-15 14:17:34 +08:00 committed by GitHub
parent f638c5a2ae
commit a216c9a3fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 868 additions and 9 deletions

View file

@ -136,12 +136,20 @@ class TunableOp {
ITuningContext* ctx = params->TuningContext();
if (ctx->IsTunableOpEnabled()) {
auto& mgr = ctx->GetTuningResultsManager();
id = mgr.Lookup(Signature(), params->Signature());
auto op_sig = Signature();
auto params_sig = params->Signature();
id = mgr.Lookup(op_sig, params_sig);
if (id > static_cast<int>(ops_.size())) {
LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig
<< ", id:" << id << ", registered op:" << ops_.size();
mgr.Delete(op_sig, params_sig);
id = -1;
}
if (id < 0) {
auto maybe_proxy_params = PreTuning(params);
id = FindFastest(maybe_proxy_params);
PostTuning(maybe_proxy_params);
mgr.Add(Signature(), params->Signature(), id);
mgr.Add(op_sig, params_sig, id);
}
}
ORT_RETURN_IF_ERROR(ops_[id](params));

View file

@ -17,6 +17,7 @@ class TuningResultsValidator;
class ITuningContext {
public:
explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {}
virtual ~ITuningContext() = default;
virtual void EnableTunableOp() = 0;
@ -25,6 +26,14 @@ class ITuningContext {
virtual TuningResultsManager& GetTuningResultsManager() = 0;
virtual const TuningResultsManager& GetTuningResultsManager() const = 0;
virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0;
virtual TuningResults GetTuningResults() const;
virtual Status LoadTuningResults(const TuningResults& tr);
protected:
IExecutionProvider* ep_;
};
class TuningResultsManager {
@ -36,6 +45,7 @@ class TuningResultsManager {
int Lookup(const std::string& op_signature, const std::string& params_signature) const;
void Add(const std::string& op_signature, const std::string& params_signature, int best_id);
void Delete(const std::string& op_signature, const std::string& params_signature);
void Load(const std::unordered_map<std::string, KernelMap>& results_to_load);
std::unordered_map<std::string, KernelMap> Dump() const;
@ -50,4 +60,35 @@ class TuningResultsManager {
std::unordered_map<std::string, KernelMap> results_;
};
class TuningResultsValidator {
public:
using GetFunc = std::function<std::string()>;
using ValidateFunc = std::function<Status(const std::string&)>;
using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
TuningResultsValidator();
virtual ~TuningResultsValidator() = default;
std::unordered_map<std::string, std::string> GetAllValidators() const;
Status ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
protected:
void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
virtual std::string GetOrtVersion() const;
virtual Status ValidateOrtVersion(const std::string& value) const;
virtual std::string GetOrtGitCommit() const;
virtual Status ValidateOrtGitCommit(const std::string& value) const;
virtual std::string GetOrtBuildConfig() const;
virtual Status ValidateOrtBuildConfig(const std::string& value) const;
public:
static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"};
private:
GetValidateFuncs validators_;
};
} // namespace onnxruntime

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these
// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these
// methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries
// are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem,
// the EP must has and only has one translation unit include this file.
@ -11,12 +11,32 @@
#pragma once
#include <functional>
#include <unordered_set>
#include <utility>
#include "core/framework/tunable.h"
#include "core/framework/tuning_context.h"
#include "core/framework/tuning_results.h"
namespace onnxruntime {
TuningResults ITuningContext::GetTuningResults() const {
TuningResults tr;
tr.ep = ep_->Type();
tr.validators = GetTuningResultsValidator().GetAllValidators();
tr.results = GetTuningResultsManager().Dump();
return tr;
}
Status ITuningContext::LoadTuningResults(const TuningResults& tr) {
ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch");
LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep;
ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators));
GetTuningResultsManager().Load(tr.results);
return Status::OK();
}
KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const {
std::scoped_lock l{lock_};
auto it = results_.find(op_signature);
@ -56,6 +76,7 @@ inline void AddImpl(const std::string& op_signature,
return;
}
LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id;
kernel_map[params_signature] = best_id;
}
@ -70,6 +91,24 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin
AddImpl(op_signature, params_signature, best_id, it->second);
}
// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
std::scoped_lock l{lock_};
auto it = results_.find(op_signature);
if (it == results_.end()) {
return;
}
auto it2 = it->second.find(params_signature);
if (it2 == it->second.end()) {
return;
}
LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")";
it->second.erase(it2);
}
std::unordered_map<std::string, KernelMap> TuningResultsManager::Dump() const {
std::scoped_lock l{lock_};
return results_;
@ -81,6 +120,9 @@ void DisjointMergeImpl(
/*out*/ std::unordered_map<std::string, KernelMap>& results) {
auto it = results.find(op_signature);
if (it == results.end()) {
for(const auto& [param_sig, kernel_id] : kernel_map) {
LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id;
}
results[op_signature] = kernel_map;
return;
}
@ -106,4 +148,138 @@ void TuningResultsManager::Clear() {
results_ = {};
}
static Status CheckMandatoryKeys(
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
const std::unordered_map<std::string, std::string>& to_check) {
bool passed = true;
std::ostringstream oss;
for (const auto& k : TuningResultsValidator::mandatory_keys) {
if (gv_funcs.find(k) == gv_funcs.end()) {
passed = false;
oss << "key=\"" << k << "\" is not registered for Get and Validate. ";
}
if (to_check.find(k) == to_check.end()) {
passed = false;
oss << "key=\"" << k << "\" is not provided for validation. ";
}
}
ORT_RETURN_IF(!passed, oss.str());
return Status::OK();
}
static Status CheckKeysMatching(
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
const std::unordered_map<std::string, std::string>& to_check) {
auto get_keys = [](const auto& it) -> std::string { return it.first; };
std::vector<std::string> required_keys;
std::vector<std::string> provided_keys;
std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys);
std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys);
std::sort(required_keys.begin(), required_keys.end());
std::sort(provided_keys.begin(), provided_keys.end());
std::unordered_set<std::string> intersection;
std::set_intersection(required_keys.cbegin(), required_keys.cend(),
provided_keys.cbegin(), provided_keys.cend(),
std::inserter(intersection, intersection.end()));
bool matched = true;
std::ostringstream oss;
if (intersection.size() != required_keys.size()) {
matched = false;
for (const auto& k : required_keys) {
if (intersection.find(k) == intersection.end()) {
oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. ";
}
}
}
if (intersection.size() != provided_keys.size()) {
matched = false;
for (const auto& k : provided_keys) {
if (intersection.find(k) == intersection.end()) {
oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. ";
}
}
}
ORT_RETURN_IF(!matched, oss.str());
return Status::OK();
}
std::string TuningResultsValidator::GetOrtVersion() const {
return ORT_VERSION;
}
Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const {
ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch");
return Status::OK();
}
std::string TuningResultsValidator::GetOrtGitCommit() const {
// TODO:
return "";
}
Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const {
// TODO:
ORT_UNUSED_PARAMETER(value);
return Status::OK();
}
std::string TuningResultsValidator::GetOrtBuildConfig() const {
return "";
}
Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const {
auto current = GetOrtBuildConfig();
ORT_RETURN_IF(current != value,
"onnxruntime building configuration mismatch: tuning results produced with library \"",
value, "\", current library built with \"", current, "\"");
return Status::OK();
}
TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"ORT_VERSION",
[this]() { return GetOrtVersion(); },
[this](auto&& k) { return ValidateOrtVersion(std::forward<decltype(k)>(k)); });
RegisterValidator(
"ORT_GIT_COMMIT",
[this]() { return GetOrtGitCommit(); },
[this](auto&& k) { return ValidateOrtGitCommit(std::forward<decltype(k)>(k)); });
RegisterValidator(
"ORT_BUILD_CONFIG",
[this]() { return GetOrtBuildConfig(); },
[this](auto&& k) { return ValidateOrtBuildConfig(std::forward<decltype(k)>(k)); });
}
Status TuningResultsValidator::ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const {
ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate));
ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate));
for (const auto& [key, value] : to_validate) {
const auto& it = validators_.find(key);
ORT_ENFORCE(it != validators_.cend());
const ValidateFunc& validator = it->second.second;
ORT_RETURN_IF_ERROR(validator(value));
}
return Status::OK();
}
std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
std::unordered_map<std::string, std::string> ret;
for (const auto& [key, get_validate_func_pair] : validators_) {
const GetFunc& getter = get_validate_func_pair.first;
ret[key] = getter();
}
return ret;
}
void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) {
ORT_ENFORCE(validators_.find(key) == validators_.end());
validators_[key] = std::make_pair(gf, vf);
}
} // namespace onnxruntime

View file

@ -14,7 +14,40 @@ namespace onnxruntime {
namespace cuda {
namespace tunable {
CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {}
static std::string GetCudaVersion() {
int version;
CUDA_CALL_THROW(cudaRuntimeGetVersion(&version));
return std::to_string(version);
}
static Status ValidateCudaVersion(const std::string& value) {
auto current = GetCudaVersion();
ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value,
", onnxruntime currently run with CUDA ", current);
return Status::OK();
}
std::string CudaTuningResultsValidator::GetDeviceModel() const {
return ep_->GetDeviceProp().name;
}
Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
auto current = GetDeviceModel();
ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value,
", onnxruntime currently run with device ", current);
return Status::OK();
}
CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) {
RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}
CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info)
: ITuningContext(ep), info_(info), validator_(ep) {}
void CudaTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider";
@ -38,6 +71,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const {
return manager_;
}
const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const {
return validator_;
}
} // namespace tunable
} // namespace cuda
} // namespace onnxruntime

View file

@ -15,6 +15,18 @@ class CUDAExecutionProvider;
namespace cuda {
namespace tunable {
class CudaTuningResultsValidator : public TuningResultsValidator {
public:
CudaTuningResultsValidator(CUDAExecutionProvider* ep);
protected:
std::string GetDeviceModel() const;
Status ValidateDeviceModel(const std::string& value) const;
private:
CUDAExecutionProvider* ep_; // non-owning handle
};
class CudaTuningContext : public ITuningContext {
public:
explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info);
@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;
const TuningResultsValidator& GetTuningResultsValidator() const override;
private:
TunableOpInfo* info_; // non-owning handle
TuningResultsManager manager_;
CudaTuningResultsValidator validator_;
};
} // namespace tunable

View file

@ -14,7 +14,7 @@
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
#include "core/providers/rocm/tunable/rocm_tuning_context.h"
namespace onnxruntime {

View file

@ -14,7 +14,66 @@ namespace onnxruntime {
namespace rocm {
namespace tunable {
RocmTuningContext::RocmTuningContext(ROCMExecutionProvider*, TunableOpInfo* info) : info_(info) {}
static std::string GetHipVersion() {
int version;
HIP_CALL_THROW(hipRuntimeGetVersion(&version));
return std::to_string(version);
}
static Status ValidateHipVersion(const std::string& value) {
auto current = GetHipVersion();
ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value,
", onnxruntime currently run with HIP ", current);
return Status::OK();
}
static std::string GetRocBlasVersion() {
char buf[64];
ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256));
buf[63] = '\0';
return buf;
}
static Status ValidateRocBlasVersion(const std::string& value) {
auto current = GetRocBlasVersion();
ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value,
", onnxruntime currently run with rocblas ", current);
return Status::OK();
}
std::string RocmTuningResultsValidator::GetDeviceModel() const {
return ep_->GetDeviceProp().name;
}
Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
auto current = GetDeviceModel();
ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value,
", onnxruntime currently run with device ", current);
return Status::OK();
}
RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} {
RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion);
RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}
std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
std::ostringstream oss;
oss << "USE_CK=" << USE_COMPOSABLE_KERNEL << "|";
#ifdef USE_ROCBLAS_EXTENSION_API
oss << "USE_ROCBLAS_EXTENSION_API=" << 1 << "|";
#else
oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|";
#endif
return oss.str();
}
RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info)
: ITuningContext(ep), info_(info), validator_(ep) {}
void RocmTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider";
@ -38,6 +97,10 @@ const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const {
return manager_;
}
const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const {
return validator_;
}
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -15,6 +15,20 @@ class ROCMExecutionProvider;
namespace rocm {
namespace tunable {
class RocmTuningResultsValidator : public TuningResultsValidator {
public:
RocmTuningResultsValidator(ROCMExecutionProvider* ep);
protected:
std::string GetOrtBuildConfig() const override;
std::string GetDeviceModel() const;
Status ValidateDeviceModel(const std::string& value) const;
private:
ROCMExecutionProvider* ep_; // non-owning handle
};
class RocmTuningContext : public ITuningContext {
public:
explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info);
@ -26,9 +40,12 @@ class RocmTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;
const TuningResultsValidator& GetTuningResultsValidator() const override;
private:
TunableOpInfo* info_; // non-owning handle
TuningResultsManager manager_;
RocmTuningResultsValidator validator_;
};
} // namespace tunable

View file

@ -1500,6 +1500,14 @@ common::Status InferenceSession::Initialize() {
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
}
}
std::vector<TuningResults> tuning_results;
bool found_tuning_results = false;
ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata(
model_metadata_, tuning_results, found_tuning_results));
if (found_tuning_results) {
ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results));
}
#endif // !defined(ORT_MINIMAL_BUILD)
// Resolve memory pattern flags of the main graph and subgraph session states
@ -2182,6 +2190,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const {
return session_profiler_;
}
#if !defined(ORT_MINIMAL_BUILD)
std::vector<TuningResults> InferenceSession::GetTuningResults() const {
std::vector<TuningResults> ret;
for (const auto& provider : execution_providers_) {
const auto* tuning_ctx = provider->GetTuningContext();
if (tuning_ctx != nullptr) {
ret.emplace_back(tuning_ctx->GetTuningResults());
}
}
return ret;
}
Status InferenceSession::SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid) {
std::string msg;
for (size_t i = 0; i < trs.size(); i++) {
const auto& tr = trs[i];
auto* provider = execution_providers_.Get(tr.ep);
if (provider == nullptr) {
msg = MakeString("Cannot find execution provider ", tr.ep);
ORT_RETURN_IF(error_on_invalid, msg);
LOGS(*session_logger_, WARNING) << msg;
continue;
}
auto* tuning_ctx = provider->GetTuningContext();
if (tuning_ctx == nullptr) {
msg = MakeString("Invalid TuningResults (index=", i, "). ", tr.ep, " does not support TunableOp.");
ORT_RETURN_IF(error_on_invalid, msg);
LOGS(*session_logger_, WARNING) << msg;
continue;
}
auto status = tuning_ctx->LoadTuningResults(tr);
if (!status.IsOK()) {
msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage());
ORT_RETURN_IF(error_on_invalid, msg);
LOGS(*session_logger_, WARNING) << msg;
}
}
return Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD)
AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const {
return session_state_->GetAllocator(mem_info);
}

View file

@ -18,6 +18,7 @@
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/prepacked_weights_container.h"
#include "core/framework/session_state.h"
#include "core/framework/tuning_results.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/graph_transformer_mgr.h"
@ -448,6 +449,23 @@ class InferenceSession {
*/
const profiling::Profiler& GetProfiling() const;
#if !defined(ORT_MINIMAL_BUILD)
/**
* Get the TuningResults of TunableOp for every execution providers.
* @return The TuningResults of each execution provider.
*/
std::vector<TuningResults> GetTuningResults() const;
/**
* Set the TuningResults back to each execution provider. Mainly for offline tuning.
* @param trs is the list of TuningResults to be loaded.
* @param error_on_invalid otherwise, validation faliure is not an error, only a warning log will be produced.
* @return OK if success.
*/
Status SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid = false);
#endif
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
MemoryProfiler& GetMemoryProfiler() {
return memory_profiler_;

View file

@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options,
return Status::OK();
}
// This function is called by nlohmann/json
void from_json(const json& j, TuningResults& trs) {
j.at("ep").get_to(trs.ep);
j.at("results").get_to(trs.results);
j.at("validators").get_to(trs.validators);
}
//---------------------------------------------------
//--- end of session options related helpers ---
//---------------------------------------------------
@ -227,6 +234,36 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options
"Parsing RunOptions from ModelProto is not supported yet");
}
Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
std::vector<TuningResults>& results,
bool& key_found) {
results.clear();
key_found = false;
auto it = metadata.custom_metadata_map.find(kTuningResultsKeys);
if (it == metadata.custom_metadata_map.end()) {
return Status::OK();
}
key_found = true;
LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model";
Status status;
ORT_TRY {
auto parsed_tuning_results_json = json::parse(it->second);
results = parsed_tuning_results_json.get<std::vector<TuningResults>>();
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
status = ORT_MAKE_STATUS(
ONNXRUNTIME, FAIL,
"Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring...");
});
ORT_RETURN_IF_ERROR(status);
}
return Status::OK();
}
} // namespace inference_session_utils
} // namespace onnxruntime

View file

@ -12,6 +12,7 @@
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "core/framework/session_options.h"
#include "core/framework/tuning_results.h"
#include "core/common/common.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;
@ -30,6 +31,7 @@ static constexpr const char* kOrtLoadConfigFromModelEnvVar = "ORT_LOAD_CONFIG_FR
//
static constexpr const char* kOrtConfigKey = "ort_config";
static constexpr const char* kSessionOptionsKey = "session_options";
static constexpr const char* kTuningResultsKeys = "tuning_results";
class JsonConfigParser {
public:
@ -56,6 +58,10 @@ class JsonConfigParser {
bool is_ort_config_json_available_ = false;
};
Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
/*out*/ std::vector<TuningResults>& results,
/*out*/ bool& key_found);
#endif // !defined(ORT_MINIMAL_BUILD)
} // namespace inference_session_utils

View file

@ -286,6 +286,12 @@ class Session:
"""
self._sess.run_with_iobinding(iobinding._iobinding, run_options)
def get_tuning_results(self):
return self._sess.get_tuning_results()
def set_tuning_results(self, results, *, error_on_invalid=False):
return self._sess.set_tuning_results(results, error_on_invalid)
def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices):
"""
Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.

View file

@ -1664,6 +1664,57 @@ including arg name, arg type (contains both type and shape).)pbdoc")
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
.def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
#if !defined(ORT_MINIMAL_BUILD)
py::list ret;
for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) {
py::dict py_trs;
py_trs["ep"] = trs.ep;
py_trs["results"] = trs.results;
py_trs["validators"] = trs.validators;
ret.append(std::move(py_trs));
}
return ret;
#else
ORT_UNUSED_PARAMETER(sess);
ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
#endif
})
.def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
#if !defined(ORT_MINIMAL_BUILD)
std::vector<TuningResults> tuning_results;
for (auto handle: results) {
auto py_trs = handle.cast<py::dict>();
TuningResults trs;
trs.ep = py_trs["ep"].cast<py::str>();
for (const auto [py_op_sig, py_kernel_map]: py_trs["results"].cast<py::dict>()) {
KernelMap kernel_map;
for (const auto [py_params_sig, py_kernel_id]: py_kernel_map.cast<py::dict>()) {
kernel_map[py_params_sig.cast<py::str>()] = py_kernel_id.cast<py::int_>();
}
trs.results[py_op_sig.cast<py::str>()] = kernel_map;
}
for (const auto [k, v]: py_trs["validators"].cast<py::dict>()) {
trs.validators[k.cast<py::str>()] = v.cast<py::str>();
}
tuning_results.emplace_back(std::move(trs));
}
Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
#else
ORT_UNUSED_PARAMETER(sess);
ORT_UNUSED_PARAMETER(results);
ORT_UNUSED_PARAMETER(error_on_invalid);
ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
#endif
});
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())

View file

@ -20,8 +20,35 @@ namespace {
// test on CPU and it does not use stream
using StreamT = void*;
constexpr static const char* kTestKey = "THE_TEST_KEY";
constexpr static const char* kValidTestValue = "THE_VALID_TEST_VALUE";
static std::string GetTestValue() {
return kValidTestValue;
}
static Status ValidateTestValue(const std::string& value) {
auto current = GetTestValue();
ORT_RETURN_IF(current != value, "Only ", kValidTestValue, " is valid for key ", kTestKey);
return Status::OK();
}
class TestTuningResultsValidator : public TuningResultsValidator {
public:
TestTuningResultsValidator() {
RegisterValidator(kTestKey, GetTestValue, ValidateTestValue);
};
protected:
std::string GetOrtBuildConfig() const override {
return "TEST_BUILD";
}
};
class TestTuningContext : public ITuningContext {
public:
using ITuningContext::ITuningContext;
void EnableTunableOp() override { tuning_enabled_ = true; }
void DisableTunableOp() override { tuning_enabled_ = false; }
bool IsTunableOpEnabled() const override { return tuning_enabled_; }
@ -29,14 +56,19 @@ class TestTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override { return manager_; }
const TuningResultsManager& GetTuningResultsManager() const override { return manager_; }
const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; }
void ClearCache() { manager_.Clear(); }
private:
bool tuning_enabled_{false};
TuningResultsManager manager_{};
TestTuningResultsValidator validator_{};
};
class TestEP : public IExecutionProvider {
static constexpr const char* kEPType = "TestEP";
TestTuningContext tuning_ctx_{};
TestTuningContext tuning_ctx_{this};
public:
TestEP() : IExecutionProvider{kEPType, true} {}
@ -45,6 +77,7 @@ class TestEP : public IExecutionProvider {
return const_cast<TestTuningContext*>(&tuning_ctx_);
}
void ClearCache() { tuning_ctx_.ClearCache(); }
};
class TestTimer : public ITimer<StreamT> {
@ -584,12 +617,59 @@ TEST(TuningContext, TunableOpRespectTuningContext) {
{
ASSERT_EQ(mgr.Lookup(op.Signature()).size(), 0u);
// TunableOp(...), respect the existing entry
// TunableOp(...), respect the existing entry (manually loaded) if id in bound
mgr.Add(op.Signature(), params.Signature(), tuning::TunableVecAddSelectFast::kSlowFullId);
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(last_run, "SlowFull");
}
last_run.clear();
mgr.Clear();
{
// TunableOp(...), must not respect the existing entry if id not in bound
// manually create an out of bound id
mgr.Add(op.Signature(), params.Signature(), 1000000);
auto status = op(&params);
ASSERT_TRUE(status.IsOK()) << "TunableOp should recover from an out of bound id";
ASSERT_EQ(last_run, "FastFull");
ASSERT_EQ(mgr.Lookup(op.Signature(), params.Signature()), tuning::TunableVecAddSelectFast::kFastFullId);
}
#endif
}
TEST(TuningContext, GetAndLoadTuningResults) {
#ifdef ORT_NO_RTTI
GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
#else
constexpr const int a = 7500000;
constexpr const int b = 42;
int c{};
tuning::VecAddParamsRecordLastRun params(&a, &b, &c, 1, 0);
std::string last_run;
params.last_run = &last_run;
tuning::TunableVecAddSelectFast op{};
auto* ctx = params.TuningContext();
ctx->EnableTunableOp();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(last_run, "FastFull");
auto trs = ctx->GetTuningResults();
ASSERT_EQ(trs.ep, "TestEP");
ASSERT_EQ(trs.validators.size(), TestTuningResultsValidator::mandatory_keys.size() + 1);
for (const auto& key : TestTuningResultsValidator::mandatory_keys) {
ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(key)));
}
ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(kTestKey)));
ASSERT_EQ(trs.results.size(), 1u);
ASSERT_THAT(trs.results, ::testing::Contains(::testing::Key(op.Signature())));
ASSERT_THAT(trs.results[op.Signature()], ::testing::Contains(::testing::Key(params.Signature())));
ASSERT_EQ(trs.results[op.Signature()][params.Signature()], tuning::TunableVecAddSelectFast::kFastFullId);
#endif
}

View file

@ -2,7 +2,7 @@
# Licensed under the MIT License.
# pylint: disable=C0116,W0212,R1720,C0114
# -*- coding: UTF-8 -*-
import copy
import gc
import os
import platform
@ -387,6 +387,89 @@ class TestInferenceSession(unittest.TestCase):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"])
self.assertEqual(["CPUExecutionProvider"], sess.get_providers())
def testGetAndSetTuningResults(self):
def getTuningResultsForEp(sess, ep): # without the outer list
tuning_results = sess.get_tuning_results()
self.assertGreaterEqual(len(tuning_results), 1)
tuning_results_for_this_ep = [t for t in tuning_results if t.get("ep") == ep]
self.assertEqual(len(tuning_results_for_this_ep), 1)
return tuning_results_for_this_ep[0]
probe_op_sig = "probe_but_not_an_op_signature"
probe_params_sig = "probe_but_not_an_params_signature"
probe_value = 10000000
def copyTuningResultsWithProbe(tr):
tr = copy.deepcopy(tr)
tr["results"][probe_op_sig] = {probe_params_sig: probe_value}
return tr
def assertTuningResultsLoaded(sess, ep):
tr = getTuningResultsForEp(sess, ep)
self.assertIn(probe_op_sig, tr["results"])
self.assertEqual(tr["results"][probe_op_sig], {probe_params_sig: probe_value})
def assertTuningResultsNotLoaded(sess, ep):
tr = getTuningResultsForEp(sess, ep)
self.assertNotIn(probe_op_sig, tr["results"])
def doTestGetAndSetTuningResults(ep):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=[ep])
tuning_results = getTuningResultsForEp(sess, ep)
self.assertIn("ep", tuning_results)
self.assertIn("results", tuning_results)
self.assertIn("validators", tuning_results)
self.assertIn("ORT_VERSION", tuning_results["validators"])
self.assertNotIn("NOT_A_VALIDATOR_KEY", tuning_results["validators"])
# invalid EP will be rejected
invalid_unknown_ep = copyTuningResultsWithProbe(tuning_results)
invalid_unknown_ep["ep"] = "UnknownEP"
sess.set_tuning_results([invalid_unknown_ep])
with self.assertRaises(RuntimeError) as context:
sess.set_tuning_results([invalid_unknown_ep], error_on_invalid=True)
self.assertIn("Cannot find execution provider UnknownEP", str(context.exception))
assertTuningResultsNotLoaded(sess, ep)
# missing validator key will be rejected
mismatched_validator_key_missing = copyTuningResultsWithProbe(tuning_results)
mismatched_validator_key_missing["validators"].pop("ORT_VERSION")
sess.set_tuning_results([mismatched_validator_key_missing])
with self.assertRaises(RuntimeError) as context:
sess.set_tuning_results([mismatched_validator_key_missing], error_on_invalid=True)
self.assertIn("ORT_VERSION", str(context.exception))
self.assertIn("is not provided for validation", str(context.exception))
assertTuningResultsNotLoaded(sess, ep)
mismatched_validator_key_extra = copyTuningResultsWithProbe(tuning_results)
mismatched_validator_key_extra["validators"]["NOT_A_VALIDATOR_KEY"] = "NOT_USED"
sess.set_tuning_results([mismatched_validator_key_extra])
with self.assertRaises(RuntimeError) as context:
sess.set_tuning_results([mismatched_validator_key_extra], error_on_invalid=True)
self.assertIn("NOT_A_VALIDATOR_KEY", str(context.exception))
self.assertIn("is unable to consume it", str(context.exception))
assertTuningResultsNotLoaded(sess, ep)
validation_failure = copyTuningResultsWithProbe(tuning_results)
validation_failure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!"
sess.set_tuning_results([validation_failure])
with self.assertRaises(RuntimeError) as context:
sess.set_tuning_results([validation_failure], error_on_invalid=True)
self.assertIn("Failed to load TuningResults", str(context.exception))
self.assertIn("version mismatch", str(context.exception))
assertTuningResultsNotLoaded(sess, ep)
loadable = copyTuningResultsWithProbe(tuning_results)
sess.set_tuning_results([loadable], error_on_invalid=True)
assertTuningResultsLoaded(sess, ep)
if "CUDAExecutionProvider" in onnxrt.get_available_providers():
doTestGetAndSetTuningResults("CUDAExecutionProvider")
if "ROCMExecutionProvider" in onnxrt.get_available_providers():
doTestGetAndSetTuningResults("ROCMExecutionProvider")
def testRunModel(self):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

View file

@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import copy
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any, Dict, List
import onnx
TuningResults = Dict[str, Any]
_TUNING_RESULTS_KEY = "tuning_results"
def _find_tuning_results_in_props(metadata_props):
for idx, prop in enumerate(metadata_props):
if prop.key == _TUNING_RESULTS_KEY:
return idx
return -1
def extract(model: onnx.ModelProto):
idx = _find_tuning_results_in_props(model.metadata_props)
if idx < 0:
return None
tuning_results_prop = model.metadata_props[idx]
return json.loads(tuning_results_prop.value)
def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False):
idx = _find_tuning_results_in_props(model.metadata_props)
assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!"
if idx >= 0:
model.metadata_props.pop(idx)
entry = model.metadata_props.add()
entry.key = _TUNING_RESULTS_KEY
entry.value = json.dumps(tuning_results)
return model
class Merger:
class EpAndValidators:
def __init__(self, ep: str, validators: Dict[str, str]):
self.ep = ep
self.validators = copy.deepcopy(validators)
self.key = (ep, tuple(sorted(validators.items())))
def __hash__(self):
return hash(self.key)
def __eq__(self, other):
return self.ep == other.ep and self.key == other.key
def __init__(self):
self.ev_to_results = OrderedDict()
def merge(self, tuning_results: List[TuningResults]):
for trs in tuning_results:
self._merge_one(trs)
def get_merged(self):
tuning_results = []
for ev, flat_results in self.ev_to_results.items():
results = {}
trs = {
"ep": ev.ep,
"validators": ev.validators,
"results": results,
}
for (op_sig, params_sig), kernel_id in flat_results.items():
kernel_map = results.setdefault(op_sig, {})
kernel_map[params_sig] = kernel_id
tuning_results.append(trs)
return tuning_results
def _merge_one(self, trs: TuningResults):
ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
flat_results = self.ev_to_results.setdefault(ev, {})
for op_sig, kernel_map in trs["results"].items():
for params_sig, kernel_id in kernel_map.items():
if (op_sig, params_sig) not in flat_results:
flat_results[(op_sig, params_sig)] = kernel_id
def parse_args():
parser = argparse.ArgumentParser()
sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")
extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
extract_parser.add_argument("input_onnx")
extract_parser.add_argument("output_json")
embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")
merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")
pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")
args = parser.parse_args()
if len(vars(args)) == 0:
parser.print_help()
exit(-1)
return args
def main():
args = parse_args()
if args.cmd == "extract":
tuning_results = extract(onnx.load_model(args.input_onnx))
if tuning_results is None:
sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
sys.exit(-1)
json.dump(tuning_results, open(args.output_json, "w"))
elif args.cmd == "embed":
model = onnx.load_model(args.input_onnx)
merger = Merger()
for tuning_results in [json.load(open(f)) for f in args.input_json]:
merger.merge(tuning_results)
model = embed(model, merger.get_merged(), args.force)
onnx.save_model(model, args.output_onnx)
elif args.cmd == "merge":
merger = Merger()
for tuning_results in [json.load(open(f)) for f in args.input_json]:
merger.merge(tuning_results)
json.dump(merger.get_merged(), open(args.output_json, "w"))
elif args.cmd == "pprint":
tuning_results = None
try:
tuning_results = json.load(open(args.json_or_onnx, "r"))
except Exception:
# it might be an onnx file otherwise, try it latter
pass
if tuning_results is None:
try:
model = onnx.load_model(args.json_or_onnx)
tuning_results = extract(model)
if tuning_results is None:
sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
sys.exit(-1)
except Exception:
pass
if tuning_results is None:
sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!")
sys.exit(-1)
pprint(tuning_results)
else:
# invalid choice will be handled by the parser
pass
if __name__ == "__main__":
main()