diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 29b4a443e2..50442268e7 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -136,12 +136,20 @@ class TunableOp { ITuningContext* ctx = params->TuningContext(); if (ctx->IsTunableOpEnabled()) { auto& mgr = ctx->GetTuningResultsManager(); - id = mgr.Lookup(Signature(), params->Signature()); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + id = mgr.Lookup(op_sig, params_sig); + if (id > static_cast(ops_.size())) { + LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig + << ", id:" << id << ", registered op:" << ops_.size(); + mgr.Delete(op_sig, params_sig); + id = -1; + } if (id < 0) { auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); - mgr.Add(Signature(), params->Signature(), id); + mgr.Add(op_sig, params_sig, id); } } ORT_RETURN_IF_ERROR(ops_[id](params)); diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 77c15d65b5..6cd61931b8 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -17,6 +17,7 @@ class TuningResultsValidator; class ITuningContext { public: + explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {} virtual ~ITuningContext() = default; virtual void EnableTunableOp() = 0; @@ -25,6 +26,14 @@ class ITuningContext { virtual TuningResultsManager& GetTuningResultsManager() = 0; virtual const TuningResultsManager& GetTuningResultsManager() const = 0; + + virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; + + virtual TuningResults GetTuningResults() const; + virtual Status LoadTuningResults(const TuningResults& tr); + + protected: + IExecutionProvider* ep_; }; class TuningResultsManager { @@ -36,6 +45,7 @@ class TuningResultsManager { int Lookup(const std::string& op_signature, const std::string& params_signature) const; void Add(const std::string& op_signature, const std::string& params_signature, int best_id); + void Delete(const std::string& op_signature, const std::string& params_signature); void Load(const std::unordered_map& results_to_load); std::unordered_map Dump() const; @@ -50,4 +60,35 @@ class TuningResultsManager { std::unordered_map results_; }; +class TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + virtual ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + Status ValidateAll(const std::unordered_map& to_validate) const; + + protected: + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + virtual std::string GetOrtVersion() const; + virtual Status ValidateOrtVersion(const std::string& value) const; + + virtual std::string GetOrtGitCommit() const; + virtual Status ValidateOrtGitCommit(const std::string& value) const; + + virtual std::string GetOrtBuildConfig() const; + virtual Status ValidateOrtBuildConfig(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; + + private: + GetValidateFuncs validators_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 50aeee1c67..c8b0583e3e 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these +// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these // methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries // are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem, // the EP must has and only has one translation unit include this file. @@ -11,12 +11,32 @@ #pragma once +#include +#include +#include + #include "core/framework/tunable.h" #include "core/framework/tuning_context.h" #include "core/framework/tuning_results.h" namespace onnxruntime { +TuningResults ITuningContext::GetTuningResults() const { + TuningResults tr; + tr.ep = ep_->Type(); + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +Status ITuningContext::LoadTuningResults(const TuningResults& tr) { + ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch"); + LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep; + ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return Status::OK(); +} + KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const { std::scoped_lock l{lock_}; auto it = results_.find(op_signature); @@ -56,6 +76,7 @@ inline void AddImpl(const std::string& op_signature, return; } + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id; kernel_map[params_signature] = best_id; } @@ -70,6 +91,24 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin AddImpl(op_signature, params_signature, best_id, it->second); } +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")"; + it->second.erase(it2); +} + std::unordered_map TuningResultsManager::Dump() const { std::scoped_lock l{lock_}; return results_; @@ -81,6 +120,9 @@ void DisjointMergeImpl( /*out*/ std::unordered_map& results) { auto it = results.find(op_signature); if (it == results.end()) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id; + } results[op_signature] = kernel_map; return; } @@ -106,4 +148,138 @@ void TuningResultsManager::Clear() { results_ = {}; } +static Status CheckMandatoryKeys( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + bool passed = true; + std::ostringstream oss; + for (const auto& k : TuningResultsValidator::mandatory_keys) { + if (gv_funcs.find(k) == gv_funcs.end()) { + passed = false; + oss << "key=\"" << k << "\" is not registered for Get and Validate. "; + } + + if (to_check.find(k) == to_check.end()) { + passed = false; + oss << "key=\"" << k << "\" is not provided for validation. "; + } + } + ORT_RETURN_IF(!passed, oss.str()); + return Status::OK(); +} + +static Status CheckKeysMatching( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + auto get_keys = [](const auto& it) -> std::string { return it.first; }; + std::vector required_keys; + std::vector provided_keys; + std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); + std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); + std::sort(required_keys.begin(), required_keys.end()); + std::sort(provided_keys.begin(), provided_keys.end()); + + std::unordered_set intersection; + std::set_intersection(required_keys.cbegin(), required_keys.cend(), + provided_keys.cbegin(), provided_keys.cend(), + std::inserter(intersection, intersection.end())); + bool matched = true; + std::ostringstream oss; + if (intersection.size() != required_keys.size()) { + matched = false; + for (const auto& k : required_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. "; + } + } + } + if (intersection.size() != provided_keys.size()) { + matched = false; + for (const auto& k : provided_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. "; + } + } + } + ORT_RETURN_IF(!matched, oss.str()); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtVersion() const { + return ORT_VERSION; +} + +Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const { + ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch"); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtGitCommit() const { + // TODO: + return ""; +} + +Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const { + // TODO: + ORT_UNUSED_PARAMETER(value); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtBuildConfig() const { + return ""; +} + +Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const { + auto current = GetOrtBuildConfig(); + ORT_RETURN_IF(current != value, + "onnxruntime building configuration mismatch: tuning results produced with library \"", + value, "\", current library built with \"", current, "\""); + return Status::OK(); +} + +TuningResultsValidator::TuningResultsValidator() { + RegisterValidator( + "ORT_VERSION", + [this]() { return GetOrtVersion(); }, + [this](auto&& k) { return ValidateOrtVersion(std::forward(k)); }); + + RegisterValidator( + "ORT_GIT_COMMIT", + [this]() { return GetOrtGitCommit(); }, + [this](auto&& k) { return ValidateOrtGitCommit(std::forward(k)); }); + + RegisterValidator( + "ORT_BUILD_CONFIG", + [this]() { return GetOrtBuildConfig(); }, + [this](auto&& k) { return ValidateOrtBuildConfig(std::forward(k)); }); +} + +Status TuningResultsValidator::ValidateAll(const std::unordered_map& to_validate) const { + ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate)); + ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate)); + + for (const auto& [key, value] : to_validate) { + const auto& it = validators_.find(key); + ORT_ENFORCE(it != validators_.cend()); + const ValidateFunc& validator = it->second.second; + ORT_RETURN_IF_ERROR(validator(value)); + } + + return Status::OK(); +} + +std::unordered_map TuningResultsValidator::GetAllValidators() const { + std::unordered_map ret; + for (const auto& [key, get_validate_func_pair] : validators_) { + const GetFunc& getter = get_validate_func_pair.first; + ret[key] = getter(); + } + return ret; +} + +void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { + ORT_ENFORCE(validators_.find(key) == validators_.end()); + validators_[key] = std::make_pair(gf, vf); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index 55c79273b1..bea2889ca8 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -14,7 +14,40 @@ namespace onnxruntime { namespace cuda { namespace tunable { -CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {} +static std::string GetCudaVersion() { + int version; + CUDA_CALL_THROW(cudaRuntimeGetVersion(&version)); + return std::to_string(version); +} + +static Status ValidateCudaVersion(const std::string& value) { + auto current = GetCudaVersion(); + ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value, + ", onnxruntime currently run with CUDA ", current); + return Status::OK(); +} + +std::string CudaTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) { + RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; @@ -38,6 +71,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h index df47c53c7a..10d0782f5b 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h @@ -15,6 +15,18 @@ class CUDAExecutionProvider; namespace cuda { namespace tunable { +class CudaTuningResultsValidator : public TuningResultsValidator { + public: + CudaTuningResultsValidator(CUDAExecutionProvider* ep); + + protected: + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + CUDAExecutionProvider* ep_; // non-owning handle +}; + class CudaTuningContext : public ITuningContext { public: explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + CudaTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index b2a5eab94a..cbe9129c87 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -14,7 +14,7 @@ #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "core/providers/rocm/tunable/rocm_tuning_context.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index c2888d7664..bead003e83 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -14,7 +14,66 @@ namespace onnxruntime { namespace rocm { namespace tunable { -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider*, TunableOpInfo* info) : info_(info) {} +static std::string GetHipVersion() { + int version; + HIP_CALL_THROW(hipRuntimeGetVersion(&version)); + return std::to_string(version); +} + +static Status ValidateHipVersion(const std::string& value) { + auto current = GetHipVersion(); + ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value, + ", onnxruntime currently run with HIP ", current); + return Status::OK(); +} + +static std::string GetRocBlasVersion() { + char buf[64]; + ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256)); + buf[63] = '\0'; + return buf; +} + +static Status ValidateRocBlasVersion(const std::string& value) { + auto current = GetRocBlasVersion(); + ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value, + ", onnxruntime currently run with rocblas ", current); + return Status::OK(); +} + +std::string RocmTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { + RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); + RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { + std::ostringstream oss; + oss << "USE_CK=" << USE_COMPOSABLE_KERNEL << "|"; +#ifdef USE_ROCBLAS_EXTENSION_API + oss << "USE_ROCBLAS_EXTENSION_API=" << 1 << "|"; +#else + oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|"; +#endif + return oss.str(); +} + +RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; @@ -38,6 +97,10 @@ const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h index d6eb0886dd..cffc0e5614 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h @@ -15,6 +15,20 @@ class ROCMExecutionProvider; namespace rocm { namespace tunable { +class RocmTuningResultsValidator : public TuningResultsValidator { + public: + RocmTuningResultsValidator(ROCMExecutionProvider* ep); + + protected: + std::string GetOrtBuildConfig() const override; + + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + ROCMExecutionProvider* ep_; // non-owning handle +}; + class RocmTuningContext : public ITuningContext { public: explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +40,12 @@ class RocmTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + RocmTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 79068a0271..79766bb7b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1500,6 +1500,14 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); } } + + std::vector tuning_results; + bool found_tuning_results = false; + ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( + model_metadata_, tuning_results, found_tuning_results)); + if (found_tuning_results) { + ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results)); + } #endif // !defined(ORT_MINIMAL_BUILD) // Resolve memory pattern flags of the main graph and subgraph session states @@ -2182,6 +2190,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const { return session_profiler_; } +#if !defined(ORT_MINIMAL_BUILD) +std::vector InferenceSession::GetTuningResults() const { + std::vector ret; + for (const auto& provider : execution_providers_) { + const auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx != nullptr) { + ret.emplace_back(tuning_ctx->GetTuningResults()); + } + } + return ret; +} + +Status InferenceSession::SetTuningResults(const std::vector& trs, bool error_on_invalid) { + std::string msg; + + for (size_t i = 0; i < trs.size(); i++) { + const auto& tr = trs[i]; + auto* provider = execution_providers_.Get(tr.ep); + if (provider == nullptr) { + msg = MakeString("Cannot find execution provider ", tr.ep); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + continue; + } + + auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx == nullptr) { + msg = MakeString("Invalid TuningResults (index=", i, "). ", tr.ep, " does not support TunableOp."); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + continue; + } + + auto status = tuning_ctx->LoadTuningResults(tr); + if (!status.IsOK()) { + msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage()); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + } + } + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f01523c923..95b0dde281 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" +#include "core/framework/tuning_results.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -448,6 +449,23 @@ class InferenceSession { */ const profiling::Profiler& GetProfiling() const; +#if !defined(ORT_MINIMAL_BUILD) + /** + * Get the TuningResults of TunableOp for every execution providers. + * @return The TuningResults of each execution provider. + */ + std::vector GetTuningResults() const; + + /** + * Set the TuningResults back to each execution provider. Mainly for offline tuning. + * @param trs is the list of TuningResults to be loaded. + * @param error_on_invalid otherwise, validation faliure is not an error, only a warning log will be produced. + * @return OK if success. + */ + Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false); +#endif + + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) MemoryProfiler& GetMemoryProfiler() { return memory_profiler_; diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index d938228c3a..3e2d03a930 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options, return Status::OK(); } +// This function is called by nlohmann/json +void from_json(const json& j, TuningResults& trs) { + j.at("ep").get_to(trs.ep); + j.at("results").get_to(trs.results); + j.at("validators").get_to(trs.validators); +} + //--------------------------------------------------- //--- end of session options related helpers --- //--------------------------------------------------- @@ -227,6 +234,36 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options "Parsing RunOptions from ModelProto is not supported yet"); } +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results, + bool& key_found) { + results.clear(); + key_found = false; + auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); + if (it == metadata.custom_metadata_map.end()) { + return Status::OK(); + } + + key_found = true; + LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + + Status status; + ORT_TRY { + auto parsed_tuning_results_json = json::parse(it->second); + results = parsed_tuning_results_json.get>(); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, + "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring..."); + }); + ORT_RETURN_IF_ERROR(status); + } + + return Status::OK(); +} + } // namespace inference_session_utils } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 3b021bfd86..a0bcdb9013 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -12,6 +12,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "core/framework/session_options.h" +#include "core/framework/tuning_results.h" #include "core/common/common.h" #include "nlohmann/json.hpp" using json = nlohmann::json; @@ -30,6 +31,7 @@ static constexpr const char* kOrtLoadConfigFromModelEnvVar = "ORT_LOAD_CONFIG_FR // static constexpr const char* kOrtConfigKey = "ort_config"; static constexpr const char* kSessionOptionsKey = "session_options"; +static constexpr const char* kTuningResultsKeys = "tuning_results"; class JsonConfigParser { public: @@ -56,6 +58,10 @@ class JsonConfigParser { bool is_ort_config_json_available_ = false; }; +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + /*out*/ std::vector& results, + /*out*/ bool& key_found); + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace inference_session_utils diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f733c13b6d..0883c528c9 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -286,6 +286,12 @@ class Session: """ self._sess.run_with_iobinding(iobinding._iobinding, run_options) + def get_tuning_results(self): + return self._sess.get_tuning_results() + + def set_tuning_results(self, results, *, error_on_invalid=False): + return self._sess.set_tuning_results(results, error_on_invalid) + def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices): """ Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 490eb92afc..de14b9aff3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1664,6 +1664,57 @@ including arg name, arg type (contains both type and shape).)pbdoc") status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get()); if (!status.IsOK()) throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + }) + .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { +#if !defined(ORT_MINIMAL_BUILD) + py::list ret; + for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) { + py::dict py_trs; + py_trs["ep"] = trs.ep; + py_trs["results"] = trs.results; + py_trs["validators"] = trs.validators; + ret.append(std::move(py_trs)); + } + + return ret; +#else + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); +#endif + }) + .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { +#if !defined(ORT_MINIMAL_BUILD) + std::vector tuning_results; + for (auto handle: results) { + auto py_trs = handle.cast(); + TuningResults trs; + trs.ep = py_trs["ep"].cast(); + + for (const auto [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { + KernelMap kernel_map; + for (const auto [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { + kernel_map[py_params_sig.cast()] = py_kernel_id.cast(); + } + trs.results[py_op_sig.cast()] = kernel_map; + } + + for (const auto [k, v]: py_trs["validators"].cast()) { + trs.validators[k.cast()] = v.cast(); + } + + tuning_results.emplace_back(std::move(trs)); + } + + Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid); + if (!status.IsOK()) { + throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + } +#else + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); +#endif }); py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 8d3535cabc..d68882f0a0 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -20,8 +20,35 @@ namespace { // test on CPU and it does not use stream using StreamT = void*; +constexpr static const char* kTestKey = "THE_TEST_KEY"; +constexpr static const char* kValidTestValue = "THE_VALID_TEST_VALUE"; + +static std::string GetTestValue() { + return kValidTestValue; +} + +static Status ValidateTestValue(const std::string& value) { + auto current = GetTestValue(); + ORT_RETURN_IF(current != value, "Only ", kValidTestValue, " is valid for key ", kTestKey); + return Status::OK(); +} + +class TestTuningResultsValidator : public TuningResultsValidator { + public: + TestTuningResultsValidator() { + RegisterValidator(kTestKey, GetTestValue, ValidateTestValue); + }; + + protected: + std::string GetOrtBuildConfig() const override { + return "TEST_BUILD"; + } +}; + class TestTuningContext : public ITuningContext { public: + using ITuningContext::ITuningContext; + void EnableTunableOp() override { tuning_enabled_ = true; } void DisableTunableOp() override { tuning_enabled_ = false; } bool IsTunableOpEnabled() const override { return tuning_enabled_; } @@ -29,14 +56,19 @@ class TestTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override { return manager_; } const TuningResultsManager& GetTuningResultsManager() const override { return manager_; } + const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; } + + void ClearCache() { manager_.Clear(); } + private: bool tuning_enabled_{false}; TuningResultsManager manager_{}; + TestTuningResultsValidator validator_{}; }; class TestEP : public IExecutionProvider { static constexpr const char* kEPType = "TestEP"; - TestTuningContext tuning_ctx_{}; + TestTuningContext tuning_ctx_{this}; public: TestEP() : IExecutionProvider{kEPType, true} {} @@ -45,6 +77,7 @@ class TestEP : public IExecutionProvider { return const_cast(&tuning_ctx_); } + void ClearCache() { tuning_ctx_.ClearCache(); } }; class TestTimer : public ITimer { @@ -584,12 +617,59 @@ TEST(TuningContext, TunableOpRespectTuningContext) { { ASSERT_EQ(mgr.Lookup(op.Signature()).size(), 0u); - // TunableOp(...), respect the existing entry + // TunableOp(...), respect the existing entry (manually loaded) if id in bound mgr.Add(op.Signature(), params.Signature(), tuning::TunableVecAddSelectFast::kSlowFullId); auto status = op(¶ms); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(last_run, "SlowFull"); } + + last_run.clear(); + mgr.Clear(); + { + // TunableOp(...), must not respect the existing entry if id not in bound + // manually create an out of bound id + mgr.Add(op.Signature(), params.Signature(), 1000000); + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()) << "TunableOp should recover from an out of bound id"; + ASSERT_EQ(last_run, "FastFull"); + ASSERT_EQ(mgr.Lookup(op.Signature(), params.Signature()), tuning::TunableVecAddSelectFast::kFastFullId); + } +#endif +} + +TEST(TuningContext, GetAndLoadTuningResults) { +#ifdef ORT_NO_RTTI + GTEST_SKIP() << "TunableOp needs RTTI to work correctly"; +#else + constexpr const int a = 7500000; + constexpr const int b = 42; + int c{}; + tuning::VecAddParamsRecordLastRun params(&a, &b, &c, 1, 0); + std::string last_run; + params.last_run = &last_run; + + tuning::TunableVecAddSelectFast op{}; + auto* ctx = params.TuningContext(); + ctx->EnableTunableOp(); + + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(last_run, "FastFull"); + + auto trs = ctx->GetTuningResults(); + ASSERT_EQ(trs.ep, "TestEP"); + + ASSERT_EQ(trs.validators.size(), TestTuningResultsValidator::mandatory_keys.size() + 1); + for (const auto& key : TestTuningResultsValidator::mandatory_keys) { + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(key))); + } + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(kTestKey))); + + ASSERT_EQ(trs.results.size(), 1u); + ASSERT_THAT(trs.results, ::testing::Contains(::testing::Key(op.Signature()))); + ASSERT_THAT(trs.results[op.Signature()], ::testing::Contains(::testing::Key(params.Signature()))); + ASSERT_EQ(trs.results[op.Signature()][params.Signature()], tuning::TunableVecAddSelectFast::kFastFullId); #endif } diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 89fd90ad3a..8232044a29 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. # pylint: disable=C0116,W0212,R1720,C0114 -# -*- coding: UTF-8 -*- +import copy import gc import os import platform @@ -387,6 +387,89 @@ class TestInferenceSession(unittest.TestCase): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) self.assertEqual(["CPUExecutionProvider"], sess.get_providers()) + def testGetAndSetTuningResults(self): + def getTuningResultsForEp(sess, ep): # without the outer list + tuning_results = sess.get_tuning_results() + self.assertGreaterEqual(len(tuning_results), 1) + tuning_results_for_this_ep = [t for t in tuning_results if t.get("ep") == ep] + self.assertEqual(len(tuning_results_for_this_ep), 1) + return tuning_results_for_this_ep[0] + + probe_op_sig = "probe_but_not_an_op_signature" + probe_params_sig = "probe_but_not_an_params_signature" + probe_value = 10000000 + + def copyTuningResultsWithProbe(tr): + tr = copy.deepcopy(tr) + tr["results"][probe_op_sig] = {probe_params_sig: probe_value} + return tr + + def assertTuningResultsLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertIn(probe_op_sig, tr["results"]) + self.assertEqual(tr["results"][probe_op_sig], {probe_params_sig: probe_value}) + + def assertTuningResultsNotLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertNotIn(probe_op_sig, tr["results"]) + + def doTestGetAndSetTuningResults(ep): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=[ep]) + tuning_results = getTuningResultsForEp(sess, ep) + + self.assertIn("ep", tuning_results) + self.assertIn("results", tuning_results) + self.assertIn("validators", tuning_results) + self.assertIn("ORT_VERSION", tuning_results["validators"]) + self.assertNotIn("NOT_A_VALIDATOR_KEY", tuning_results["validators"]) + + # invalid EP will be rejected + invalid_unknown_ep = copyTuningResultsWithProbe(tuning_results) + invalid_unknown_ep["ep"] = "UnknownEP" + sess.set_tuning_results([invalid_unknown_ep]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([invalid_unknown_ep], error_on_invalid=True) + self.assertIn("Cannot find execution provider UnknownEP", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + # missing validator key will be rejected + mismatched_validator_key_missing = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_missing["validators"].pop("ORT_VERSION") + sess.set_tuning_results([mismatched_validator_key_missing]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_missing], error_on_invalid=True) + self.assertIn("ORT_VERSION", str(context.exception)) + self.assertIn("is not provided for validation", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + mismatched_validator_key_extra = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_extra["validators"]["NOT_A_VALIDATOR_KEY"] = "NOT_USED" + sess.set_tuning_results([mismatched_validator_key_extra]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_extra], error_on_invalid=True) + self.assertIn("NOT_A_VALIDATOR_KEY", str(context.exception)) + self.assertIn("is unable to consume it", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + validation_failure = copyTuningResultsWithProbe(tuning_results) + validation_failure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!" + sess.set_tuning_results([validation_failure]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([validation_failure], error_on_invalid=True) + self.assertIn("Failed to load TuningResults", str(context.exception)) + self.assertIn("version mismatch", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + loadable = copyTuningResultsWithProbe(tuning_results) + sess.set_tuning_results([loadable], error_on_invalid=True) + assertTuningResultsLoaded(sess, ep) + + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("CUDAExecutionProvider") + + if "ROCMExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("ROCMExecutionProvider") + def testRunModel(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) diff --git a/tools/python/offline_tuning.py b/tools/python/offline_tuning.py new file mode 100644 index 0000000000..8dbae5efe8 --- /dev/null +++ b/tools/python/offline_tuning.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import copy +import json +import sys +from collections import OrderedDict +from pprint import pprint +from typing import Any, Dict, List + +import onnx + +TuningResults = Dict[str, Any] + +_TUNING_RESULTS_KEY = "tuning_results" + + +def _find_tuning_results_in_props(metadata_props): + for idx, prop in enumerate(metadata_props): + if prop.key == _TUNING_RESULTS_KEY: + return idx + return -1 + + +def extract(model: onnx.ModelProto): + idx = _find_tuning_results_in_props(model.metadata_props) + if idx < 0: + return None + + tuning_results_prop = model.metadata_props[idx] + return json.loads(tuning_results_prop.value) + + +def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False): + idx = _find_tuning_results_in_props(model.metadata_props) + assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!" + + if idx >= 0: + model.metadata_props.pop(idx) + + entry = model.metadata_props.add() + entry.key = _TUNING_RESULTS_KEY + entry.value = json.dumps(tuning_results) + return model + + +class Merger: + class EpAndValidators: + def __init__(self, ep: str, validators: Dict[str, str]): + self.ep = ep + self.validators = copy.deepcopy(validators) + self.key = (ep, tuple(sorted(validators.items()))) + + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.ep == other.ep and self.key == other.key + + def __init__(self): + self.ev_to_results = OrderedDict() + + def merge(self, tuning_results: List[TuningResults]): + for trs in tuning_results: + self._merge_one(trs) + + def get_merged(self): + tuning_results = [] + for ev, flat_results in self.ev_to_results.items(): + results = {} + trs = { + "ep": ev.ep, + "validators": ev.validators, + "results": results, + } + for (op_sig, params_sig), kernel_id in flat_results.items(): + kernel_map = results.setdefault(op_sig, {}) + kernel_map[params_sig] = kernel_id + tuning_results.append(trs) + return tuning_results + + def _merge_one(self, trs: TuningResults): + ev = Merger.EpAndValidators(trs["ep"], trs["validators"]) + flat_results = self.ev_to_results.setdefault(ev, {}) + for op_sig, kernel_map in trs["results"].items(): + for params_sig, kernel_id in kernel_map.items(): + if (op_sig, params_sig) not in flat_results: + flat_results[(op_sig, params_sig)] = kernel_id + + +def parse_args(): + parser = argparse.ArgumentParser() + sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd") + + extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.") + extract_parser.add_argument("input_onnx") + extract_parser.add_argument("output_json") + + embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.") + embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.") + embed_parser.add_argument("output_onnx", help="Path of the output onnx file.") + embed_parser.add_argument("input_onnx", help="Path of the input onnx file.") + embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.") + + merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.") + merge_parser.add_argument("output_json", help="Path of the output tuning results file.") + merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.") + + pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.") + pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.") + + args = parser.parse_args() + if len(vars(args)) == 0: + parser.print_help() + exit(-1) + return args + + +def main(): + args = parse_args() + if args.cmd == "extract": + tuning_results = extract(onnx.load_model(args.input_onnx)) + if tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + sys.exit(-1) + json.dump(tuning_results, open(args.output_json, "w")) + elif args.cmd == "embed": + model = onnx.load_model(args.input_onnx) + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + model = embed(model, merger.get_merged(), args.force) + onnx.save_model(model, args.output_onnx) + elif args.cmd == "merge": + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + json.dump(merger.get_merged(), open(args.output_json, "w")) + elif args.cmd == "pprint": + tuning_results = None + try: + tuning_results = json.load(open(args.json_or_onnx, "r")) + except Exception: + # it might be an onnx file otherwise, try it latter + pass + + if tuning_results is None: + try: + model = onnx.load_model(args.json_or_onnx) + tuning_results = extract(model) + if tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + sys.exit(-1) + except Exception: + pass + + if tuning_results is None: + sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!") + sys.exit(-1) + + pprint(tuning_results) + else: + # invalid choice will be handled by the parser + pass + + +if __name__ == "__main__": + main()