From a216c9a3fa73ddd593f024b737dbf4aff0b06d68 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Wed, 15 Feb 2023 14:17:34 +0800 Subject: [PATCH] Offline tuning (#14558) Add the ability to get and set tuning results of an inference session. Also add tool to manipulate onnx file to embed the results into the model file and automatically load it on session initialization. --- onnxruntime/core/framework/tunable.h | 12 +- onnxruntime/core/framework/tuning_context.h | 41 ++++ .../core/framework/tuning_context_impl.h | 178 +++++++++++++++++- .../cuda/tunable/cuda_tuning_context.cc | 39 +++- .../cuda/tunable/cuda_tuning_context.h | 15 ++ .../providers/rocm/rocm_execution_provider.h | 2 +- .../rocm/tunable/rocm_tuning_context.cc | 65 ++++++- .../rocm/tunable/rocm_tuning_context.h | 17 ++ onnxruntime/core/session/inference_session.cc | 52 +++++ onnxruntime/core/session/inference_session.h | 18 ++ .../core/session/inference_session_utils.cc | 37 ++++ .../core/session/inference_session_utils.h | 6 + .../onnxruntime_inference_collection.py | 6 + .../python/onnxruntime_pybind_state.cc | 51 +++++ onnxruntime/test/framework/tunable_op_test.cc | 84 ++++++++- .../test/python/onnxruntime_test_python.py | 85 ++++++++- tools/python/offline_tuning.py | 169 +++++++++++++++++ 17 files changed, 868 insertions(+), 9 deletions(-) create mode 100644 tools/python/offline_tuning.py diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 29b4a443e2..50442268e7 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -136,12 +136,20 @@ class TunableOp { ITuningContext* ctx = params->TuningContext(); if (ctx->IsTunableOpEnabled()) { auto& mgr = ctx->GetTuningResultsManager(); - id = mgr.Lookup(Signature(), params->Signature()); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + id = mgr.Lookup(op_sig, params_sig); + if (id > static_cast(ops_.size())) { + LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig + << ", id:" << id << ", registered op:" << ops_.size(); + mgr.Delete(op_sig, params_sig); + id = -1; + } if (id < 0) { auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); - mgr.Add(Signature(), params->Signature(), id); + mgr.Add(op_sig, params_sig, id); } } ORT_RETURN_IF_ERROR(ops_[id](params)); diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 77c15d65b5..6cd61931b8 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -17,6 +17,7 @@ class TuningResultsValidator; class ITuningContext { public: + explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {} virtual ~ITuningContext() = default; virtual void EnableTunableOp() = 0; @@ -25,6 +26,14 @@ class ITuningContext { virtual TuningResultsManager& GetTuningResultsManager() = 0; virtual const TuningResultsManager& GetTuningResultsManager() const = 0; + + virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; + + virtual TuningResults GetTuningResults() const; + virtual Status LoadTuningResults(const TuningResults& tr); + + protected: + IExecutionProvider* ep_; }; class TuningResultsManager { @@ -36,6 +45,7 @@ class TuningResultsManager { int Lookup(const std::string& op_signature, const std::string& params_signature) const; void Add(const std::string& op_signature, const std::string& params_signature, int best_id); + void Delete(const std::string& op_signature, const std::string& params_signature); void Load(const std::unordered_map& results_to_load); std::unordered_map Dump() const; @@ -50,4 +60,35 @@ class TuningResultsManager { std::unordered_map results_; }; +class TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + virtual ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + Status ValidateAll(const std::unordered_map& to_validate) const; + + protected: + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + virtual std::string GetOrtVersion() const; + virtual Status ValidateOrtVersion(const std::string& value) const; + + virtual std::string GetOrtGitCommit() const; + virtual Status ValidateOrtGitCommit(const std::string& value) const; + + virtual std::string GetOrtBuildConfig() const; + virtual Status ValidateOrtBuildConfig(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; + + private: + GetValidateFuncs validators_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 50aeee1c67..c8b0583e3e 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these +// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these // methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries // are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem, // the EP must has and only has one translation unit include this file. @@ -11,12 +11,32 @@ #pragma once +#include +#include +#include + #include "core/framework/tunable.h" #include "core/framework/tuning_context.h" #include "core/framework/tuning_results.h" namespace onnxruntime { +TuningResults ITuningContext::GetTuningResults() const { + TuningResults tr; + tr.ep = ep_->Type(); + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +Status ITuningContext::LoadTuningResults(const TuningResults& tr) { + ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch"); + LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep; + ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return Status::OK(); +} + KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const { std::scoped_lock l{lock_}; auto it = results_.find(op_signature); @@ -56,6 +76,7 @@ inline void AddImpl(const std::string& op_signature, return; } + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id; kernel_map[params_signature] = best_id; } @@ -70,6 +91,24 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin AddImpl(op_signature, params_signature, best_id, it->second); } +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")"; + it->second.erase(it2); +} + std::unordered_map TuningResultsManager::Dump() const { std::scoped_lock l{lock_}; return results_; @@ -81,6 +120,9 @@ void DisjointMergeImpl( /*out*/ std::unordered_map& results) { auto it = results.find(op_signature); if (it == results.end()) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id; + } results[op_signature] = kernel_map; return; } @@ -106,4 +148,138 @@ void TuningResultsManager::Clear() { results_ = {}; } +static Status CheckMandatoryKeys( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + bool passed = true; + std::ostringstream oss; + for (const auto& k : TuningResultsValidator::mandatory_keys) { + if (gv_funcs.find(k) == gv_funcs.end()) { + passed = false; + oss << "key=\"" << k << "\" is not registered for Get and Validate. "; + } + + if (to_check.find(k) == to_check.end()) { + passed = false; + oss << "key=\"" << k << "\" is not provided for validation. "; + } + } + ORT_RETURN_IF(!passed, oss.str()); + return Status::OK(); +} + +static Status CheckKeysMatching( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + auto get_keys = [](const auto& it) -> std::string { return it.first; }; + std::vector required_keys; + std::vector provided_keys; + std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); + std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); + std::sort(required_keys.begin(), required_keys.end()); + std::sort(provided_keys.begin(), provided_keys.end()); + + std::unordered_set intersection; + std::set_intersection(required_keys.cbegin(), required_keys.cend(), + provided_keys.cbegin(), provided_keys.cend(), + std::inserter(intersection, intersection.end())); + bool matched = true; + std::ostringstream oss; + if (intersection.size() != required_keys.size()) { + matched = false; + for (const auto& k : required_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. "; + } + } + } + if (intersection.size() != provided_keys.size()) { + matched = false; + for (const auto& k : provided_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. "; + } + } + } + ORT_RETURN_IF(!matched, oss.str()); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtVersion() const { + return ORT_VERSION; +} + +Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const { + ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch"); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtGitCommit() const { + // TODO: + return ""; +} + +Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const { + // TODO: + ORT_UNUSED_PARAMETER(value); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtBuildConfig() const { + return ""; +} + +Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const { + auto current = GetOrtBuildConfig(); + ORT_RETURN_IF(current != value, + "onnxruntime building configuration mismatch: tuning results produced with library \"", + value, "\", current library built with \"", current, "\""); + return Status::OK(); +} + +TuningResultsValidator::TuningResultsValidator() { + RegisterValidator( + "ORT_VERSION", + [this]() { return GetOrtVersion(); }, + [this](auto&& k) { return ValidateOrtVersion(std::forward(k)); }); + + RegisterValidator( + "ORT_GIT_COMMIT", + [this]() { return GetOrtGitCommit(); }, + [this](auto&& k) { return ValidateOrtGitCommit(std::forward(k)); }); + + RegisterValidator( + "ORT_BUILD_CONFIG", + [this]() { return GetOrtBuildConfig(); }, + [this](auto&& k) { return ValidateOrtBuildConfig(std::forward(k)); }); +} + +Status TuningResultsValidator::ValidateAll(const std::unordered_map& to_validate) const { + ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate)); + ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate)); + + for (const auto& [key, value] : to_validate) { + const auto& it = validators_.find(key); + ORT_ENFORCE(it != validators_.cend()); + const ValidateFunc& validator = it->second.second; + ORT_RETURN_IF_ERROR(validator(value)); + } + + return Status::OK(); +} + +std::unordered_map TuningResultsValidator::GetAllValidators() const { + std::unordered_map ret; + for (const auto& [key, get_validate_func_pair] : validators_) { + const GetFunc& getter = get_validate_func_pair.first; + ret[key] = getter(); + } + return ret; +} + +void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { + ORT_ENFORCE(validators_.find(key) == validators_.end()); + validators_[key] = std::make_pair(gf, vf); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index 55c79273b1..bea2889ca8 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -14,7 +14,40 @@ namespace onnxruntime { namespace cuda { namespace tunable { -CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {} +static std::string GetCudaVersion() { + int version; + CUDA_CALL_THROW(cudaRuntimeGetVersion(&version)); + return std::to_string(version); +} + +static Status ValidateCudaVersion(const std::string& value) { + auto current = GetCudaVersion(); + ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value, + ", onnxruntime currently run with CUDA ", current); + return Status::OK(); +} + +std::string CudaTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) { + RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; @@ -38,6 +71,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h index df47c53c7a..10d0782f5b 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h @@ -15,6 +15,18 @@ class CUDAExecutionProvider; namespace cuda { namespace tunable { +class CudaTuningResultsValidator : public TuningResultsValidator { + public: + CudaTuningResultsValidator(CUDAExecutionProvider* ep); + + protected: + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + CUDAExecutionProvider* ep_; // non-owning handle +}; + class CudaTuningContext : public ITuningContext { public: explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + CudaTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index b2a5eab94a..cbe9129c87 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -14,7 +14,7 @@ #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "core/providers/rocm/tunable/rocm_tuning_context.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index c2888d7664..bead003e83 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -14,7 +14,66 @@ namespace onnxruntime { namespace rocm { namespace tunable { -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider*, TunableOpInfo* info) : info_(info) {} +static std::string GetHipVersion() { + int version; + HIP_CALL_THROW(hipRuntimeGetVersion(&version)); + return std::to_string(version); +} + +static Status ValidateHipVersion(const std::string& value) { + auto current = GetHipVersion(); + ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value, + ", onnxruntime currently run with HIP ", current); + return Status::OK(); +} + +static std::string GetRocBlasVersion() { + char buf[64]; + ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256)); + buf[63] = '\0'; + return buf; +} + +static Status ValidateRocBlasVersion(const std::string& value) { + auto current = GetRocBlasVersion(); + ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value, + ", onnxruntime currently run with rocblas ", current); + return Status::OK(); +} + +std::string RocmTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { + RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); + RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { + std::ostringstream oss; + oss << "USE_CK=" << USE_COMPOSABLE_KERNEL << "|"; +#ifdef USE_ROCBLAS_EXTENSION_API + oss << "USE_ROCBLAS_EXTENSION_API=" << 1 << "|"; +#else + oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|"; +#endif + return oss.str(); +} + +RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; @@ -38,6 +97,10 @@ const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h index d6eb0886dd..cffc0e5614 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h @@ -15,6 +15,20 @@ class ROCMExecutionProvider; namespace rocm { namespace tunable { +class RocmTuningResultsValidator : public TuningResultsValidator { + public: + RocmTuningResultsValidator(ROCMExecutionProvider* ep); + + protected: + std::string GetOrtBuildConfig() const override; + + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + ROCMExecutionProvider* ep_; // non-owning handle +}; + class RocmTuningContext : public ITuningContext { public: explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +40,12 @@ class RocmTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + RocmTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 79068a0271..79766bb7b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1500,6 +1500,14 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); } } + + std::vector tuning_results; + bool found_tuning_results = false; + ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( + model_metadata_, tuning_results, found_tuning_results)); + if (found_tuning_results) { + ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results)); + } #endif // !defined(ORT_MINIMAL_BUILD) // Resolve memory pattern flags of the main graph and subgraph session states @@ -2182,6 +2190,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const { return session_profiler_; } +#if !defined(ORT_MINIMAL_BUILD) +std::vector InferenceSession::GetTuningResults() const { + std::vector ret; + for (const auto& provider : execution_providers_) { + const auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx != nullptr) { + ret.emplace_back(tuning_ctx->GetTuningResults()); + } + } + return ret; +} + +Status InferenceSession::SetTuningResults(const std::vector& trs, bool error_on_invalid) { + std::string msg; + + for (size_t i = 0; i < trs.size(); i++) { + const auto& tr = trs[i]; + auto* provider = execution_providers_.Get(tr.ep); + if (provider == nullptr) { + msg = MakeString("Cannot find execution provider ", tr.ep); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + continue; + } + + auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx == nullptr) { + msg = MakeString("Invalid TuningResults (index=", i, "). ", tr.ep, " does not support TunableOp."); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + continue; + } + + auto status = tuning_ctx->LoadTuningResults(tr); + if (!status.IsOK()) { + msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage()); + ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; + } + } + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f01523c923..95b0dde281 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" +#include "core/framework/tuning_results.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -448,6 +449,23 @@ class InferenceSession { */ const profiling::Profiler& GetProfiling() const; +#if !defined(ORT_MINIMAL_BUILD) + /** + * Get the TuningResults of TunableOp for every execution providers. + * @return The TuningResults of each execution provider. + */ + std::vector GetTuningResults() const; + + /** + * Set the TuningResults back to each execution provider. Mainly for offline tuning. + * @param trs is the list of TuningResults to be loaded. + * @param error_on_invalid otherwise, validation faliure is not an error, only a warning log will be produced. + * @return OK if success. + */ + Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false); +#endif + + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) MemoryProfiler& GetMemoryProfiler() { return memory_profiler_; diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index d938228c3a..3e2d03a930 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options, return Status::OK(); } +// This function is called by nlohmann/json +void from_json(const json& j, TuningResults& trs) { + j.at("ep").get_to(trs.ep); + j.at("results").get_to(trs.results); + j.at("validators").get_to(trs.validators); +} + //--------------------------------------------------- //--- end of session options related helpers --- //--------------------------------------------------- @@ -227,6 +234,36 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options "Parsing RunOptions from ModelProto is not supported yet"); } +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results, + bool& key_found) { + results.clear(); + key_found = false; + auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); + if (it == metadata.custom_metadata_map.end()) { + return Status::OK(); + } + + key_found = true; + LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + + Status status; + ORT_TRY { + auto parsed_tuning_results_json = json::parse(it->second); + results = parsed_tuning_results_json.get>(); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, + "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring..."); + }); + ORT_RETURN_IF_ERROR(status); + } + + return Status::OK(); +} + } // namespace inference_session_utils } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 3b021bfd86..a0bcdb9013 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -12,6 +12,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "core/framework/session_options.h" +#include "core/framework/tuning_results.h" #include "core/common/common.h" #include "nlohmann/json.hpp" using json = nlohmann::json; @@ -30,6 +31,7 @@ static constexpr const char* kOrtLoadConfigFromModelEnvVar = "ORT_LOAD_CONFIG_FR // static constexpr const char* kOrtConfigKey = "ort_config"; static constexpr const char* kSessionOptionsKey = "session_options"; +static constexpr const char* kTuningResultsKeys = "tuning_results"; class JsonConfigParser { public: @@ -56,6 +58,10 @@ class JsonConfigParser { bool is_ort_config_json_available_ = false; }; +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + /*out*/ std::vector& results, + /*out*/ bool& key_found); + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace inference_session_utils diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f733c13b6d..0883c528c9 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -286,6 +286,12 @@ class Session: """ self._sess.run_with_iobinding(iobinding._iobinding, run_options) + def get_tuning_results(self): + return self._sess.get_tuning_results() + + def set_tuning_results(self, results, *, error_on_invalid=False): + return self._sess.set_tuning_results(results, error_on_invalid) + def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices): """ Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 490eb92afc..de14b9aff3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1664,6 +1664,57 @@ including arg name, arg type (contains both type and shape).)pbdoc") status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get()); if (!status.IsOK()) throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + }) + .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { +#if !defined(ORT_MINIMAL_BUILD) + py::list ret; + for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) { + py::dict py_trs; + py_trs["ep"] = trs.ep; + py_trs["results"] = trs.results; + py_trs["validators"] = trs.validators; + ret.append(std::move(py_trs)); + } + + return ret; +#else + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); +#endif + }) + .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { +#if !defined(ORT_MINIMAL_BUILD) + std::vector tuning_results; + for (auto handle: results) { + auto py_trs = handle.cast(); + TuningResults trs; + trs.ep = py_trs["ep"].cast(); + + for (const auto [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { + KernelMap kernel_map; + for (const auto [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { + kernel_map[py_params_sig.cast()] = py_kernel_id.cast(); + } + trs.results[py_op_sig.cast()] = kernel_map; + } + + for (const auto [k, v]: py_trs["validators"].cast()) { + trs.validators[k.cast()] = v.cast(); + } + + tuning_results.emplace_back(std::move(trs)); + } + + Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid); + if (!status.IsOK()) { + throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + } +#else + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); +#endif }); py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 8d3535cabc..d68882f0a0 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -20,8 +20,35 @@ namespace { // test on CPU and it does not use stream using StreamT = void*; +constexpr static const char* kTestKey = "THE_TEST_KEY"; +constexpr static const char* kValidTestValue = "THE_VALID_TEST_VALUE"; + +static std::string GetTestValue() { + return kValidTestValue; +} + +static Status ValidateTestValue(const std::string& value) { + auto current = GetTestValue(); + ORT_RETURN_IF(current != value, "Only ", kValidTestValue, " is valid for key ", kTestKey); + return Status::OK(); +} + +class TestTuningResultsValidator : public TuningResultsValidator { + public: + TestTuningResultsValidator() { + RegisterValidator(kTestKey, GetTestValue, ValidateTestValue); + }; + + protected: + std::string GetOrtBuildConfig() const override { + return "TEST_BUILD"; + } +}; + class TestTuningContext : public ITuningContext { public: + using ITuningContext::ITuningContext; + void EnableTunableOp() override { tuning_enabled_ = true; } void DisableTunableOp() override { tuning_enabled_ = false; } bool IsTunableOpEnabled() const override { return tuning_enabled_; } @@ -29,14 +56,19 @@ class TestTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override { return manager_; } const TuningResultsManager& GetTuningResultsManager() const override { return manager_; } + const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; } + + void ClearCache() { manager_.Clear(); } + private: bool tuning_enabled_{false}; TuningResultsManager manager_{}; + TestTuningResultsValidator validator_{}; }; class TestEP : public IExecutionProvider { static constexpr const char* kEPType = "TestEP"; - TestTuningContext tuning_ctx_{}; + TestTuningContext tuning_ctx_{this}; public: TestEP() : IExecutionProvider{kEPType, true} {} @@ -45,6 +77,7 @@ class TestEP : public IExecutionProvider { return const_cast(&tuning_ctx_); } + void ClearCache() { tuning_ctx_.ClearCache(); } }; class TestTimer : public ITimer { @@ -584,12 +617,59 @@ TEST(TuningContext, TunableOpRespectTuningContext) { { ASSERT_EQ(mgr.Lookup(op.Signature()).size(), 0u); - // TunableOp(...), respect the existing entry + // TunableOp(...), respect the existing entry (manually loaded) if id in bound mgr.Add(op.Signature(), params.Signature(), tuning::TunableVecAddSelectFast::kSlowFullId); auto status = op(¶ms); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(last_run, "SlowFull"); } + + last_run.clear(); + mgr.Clear(); + { + // TunableOp(...), must not respect the existing entry if id not in bound + // manually create an out of bound id + mgr.Add(op.Signature(), params.Signature(), 1000000); + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()) << "TunableOp should recover from an out of bound id"; + ASSERT_EQ(last_run, "FastFull"); + ASSERT_EQ(mgr.Lookup(op.Signature(), params.Signature()), tuning::TunableVecAddSelectFast::kFastFullId); + } +#endif +} + +TEST(TuningContext, GetAndLoadTuningResults) { +#ifdef ORT_NO_RTTI + GTEST_SKIP() << "TunableOp needs RTTI to work correctly"; +#else + constexpr const int a = 7500000; + constexpr const int b = 42; + int c{}; + tuning::VecAddParamsRecordLastRun params(&a, &b, &c, 1, 0); + std::string last_run; + params.last_run = &last_run; + + tuning::TunableVecAddSelectFast op{}; + auto* ctx = params.TuningContext(); + ctx->EnableTunableOp(); + + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(last_run, "FastFull"); + + auto trs = ctx->GetTuningResults(); + ASSERT_EQ(trs.ep, "TestEP"); + + ASSERT_EQ(trs.validators.size(), TestTuningResultsValidator::mandatory_keys.size() + 1); + for (const auto& key : TestTuningResultsValidator::mandatory_keys) { + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(key))); + } + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(kTestKey))); + + ASSERT_EQ(trs.results.size(), 1u); + ASSERT_THAT(trs.results, ::testing::Contains(::testing::Key(op.Signature()))); + ASSERT_THAT(trs.results[op.Signature()], ::testing::Contains(::testing::Key(params.Signature()))); + ASSERT_EQ(trs.results[op.Signature()][params.Signature()], tuning::TunableVecAddSelectFast::kFastFullId); #endif } diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 89fd90ad3a..8232044a29 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. # pylint: disable=C0116,W0212,R1720,C0114 -# -*- coding: UTF-8 -*- +import copy import gc import os import platform @@ -387,6 +387,89 @@ class TestInferenceSession(unittest.TestCase): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) self.assertEqual(["CPUExecutionProvider"], sess.get_providers()) + def testGetAndSetTuningResults(self): + def getTuningResultsForEp(sess, ep): # without the outer list + tuning_results = sess.get_tuning_results() + self.assertGreaterEqual(len(tuning_results), 1) + tuning_results_for_this_ep = [t for t in tuning_results if t.get("ep") == ep] + self.assertEqual(len(tuning_results_for_this_ep), 1) + return tuning_results_for_this_ep[0] + + probe_op_sig = "probe_but_not_an_op_signature" + probe_params_sig = "probe_but_not_an_params_signature" + probe_value = 10000000 + + def copyTuningResultsWithProbe(tr): + tr = copy.deepcopy(tr) + tr["results"][probe_op_sig] = {probe_params_sig: probe_value} + return tr + + def assertTuningResultsLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertIn(probe_op_sig, tr["results"]) + self.assertEqual(tr["results"][probe_op_sig], {probe_params_sig: probe_value}) + + def assertTuningResultsNotLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertNotIn(probe_op_sig, tr["results"]) + + def doTestGetAndSetTuningResults(ep): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=[ep]) + tuning_results = getTuningResultsForEp(sess, ep) + + self.assertIn("ep", tuning_results) + self.assertIn("results", tuning_results) + self.assertIn("validators", tuning_results) + self.assertIn("ORT_VERSION", tuning_results["validators"]) + self.assertNotIn("NOT_A_VALIDATOR_KEY", tuning_results["validators"]) + + # invalid EP will be rejected + invalid_unknown_ep = copyTuningResultsWithProbe(tuning_results) + invalid_unknown_ep["ep"] = "UnknownEP" + sess.set_tuning_results([invalid_unknown_ep]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([invalid_unknown_ep], error_on_invalid=True) + self.assertIn("Cannot find execution provider UnknownEP", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + # missing validator key will be rejected + mismatched_validator_key_missing = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_missing["validators"].pop("ORT_VERSION") + sess.set_tuning_results([mismatched_validator_key_missing]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_missing], error_on_invalid=True) + self.assertIn("ORT_VERSION", str(context.exception)) + self.assertIn("is not provided for validation", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + mismatched_validator_key_extra = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_extra["validators"]["NOT_A_VALIDATOR_KEY"] = "NOT_USED" + sess.set_tuning_results([mismatched_validator_key_extra]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_extra], error_on_invalid=True) + self.assertIn("NOT_A_VALIDATOR_KEY", str(context.exception)) + self.assertIn("is unable to consume it", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + validation_failure = copyTuningResultsWithProbe(tuning_results) + validation_failure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!" + sess.set_tuning_results([validation_failure]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([validation_failure], error_on_invalid=True) + self.assertIn("Failed to load TuningResults", str(context.exception)) + self.assertIn("version mismatch", str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + loadable = copyTuningResultsWithProbe(tuning_results) + sess.set_tuning_results([loadable], error_on_invalid=True) + assertTuningResultsLoaded(sess, ep) + + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("CUDAExecutionProvider") + + if "ROCMExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("ROCMExecutionProvider") + def testRunModel(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) diff --git a/tools/python/offline_tuning.py b/tools/python/offline_tuning.py new file mode 100644 index 0000000000..8dbae5efe8 --- /dev/null +++ b/tools/python/offline_tuning.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import copy +import json +import sys +from collections import OrderedDict +from pprint import pprint +from typing import Any, Dict, List + +import onnx + +TuningResults = Dict[str, Any] + +_TUNING_RESULTS_KEY = "tuning_results" + + +def _find_tuning_results_in_props(metadata_props): + for idx, prop in enumerate(metadata_props): + if prop.key == _TUNING_RESULTS_KEY: + return idx + return -1 + + +def extract(model: onnx.ModelProto): + idx = _find_tuning_results_in_props(model.metadata_props) + if idx < 0: + return None + + tuning_results_prop = model.metadata_props[idx] + return json.loads(tuning_results_prop.value) + + +def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False): + idx = _find_tuning_results_in_props(model.metadata_props) + assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!" + + if idx >= 0: + model.metadata_props.pop(idx) + + entry = model.metadata_props.add() + entry.key = _TUNING_RESULTS_KEY + entry.value = json.dumps(tuning_results) + return model + + +class Merger: + class EpAndValidators: + def __init__(self, ep: str, validators: Dict[str, str]): + self.ep = ep + self.validators = copy.deepcopy(validators) + self.key = (ep, tuple(sorted(validators.items()))) + + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.ep == other.ep and self.key == other.key + + def __init__(self): + self.ev_to_results = OrderedDict() + + def merge(self, tuning_results: List[TuningResults]): + for trs in tuning_results: + self._merge_one(trs) + + def get_merged(self): + tuning_results = [] + for ev, flat_results in self.ev_to_results.items(): + results = {} + trs = { + "ep": ev.ep, + "validators": ev.validators, + "results": results, + } + for (op_sig, params_sig), kernel_id in flat_results.items(): + kernel_map = results.setdefault(op_sig, {}) + kernel_map[params_sig] = kernel_id + tuning_results.append(trs) + return tuning_results + + def _merge_one(self, trs: TuningResults): + ev = Merger.EpAndValidators(trs["ep"], trs["validators"]) + flat_results = self.ev_to_results.setdefault(ev, {}) + for op_sig, kernel_map in trs["results"].items(): + for params_sig, kernel_id in kernel_map.items(): + if (op_sig, params_sig) not in flat_results: + flat_results[(op_sig, params_sig)] = kernel_id + + +def parse_args(): + parser = argparse.ArgumentParser() + sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd") + + extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.") + extract_parser.add_argument("input_onnx") + extract_parser.add_argument("output_json") + + embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.") + embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.") + embed_parser.add_argument("output_onnx", help="Path of the output onnx file.") + embed_parser.add_argument("input_onnx", help="Path of the input onnx file.") + embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.") + + merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.") + merge_parser.add_argument("output_json", help="Path of the output tuning results file.") + merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.") + + pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.") + pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.") + + args = parser.parse_args() + if len(vars(args)) == 0: + parser.print_help() + exit(-1) + return args + + +def main(): + args = parse_args() + if args.cmd == "extract": + tuning_results = extract(onnx.load_model(args.input_onnx)) + if tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + sys.exit(-1) + json.dump(tuning_results, open(args.output_json, "w")) + elif args.cmd == "embed": + model = onnx.load_model(args.input_onnx) + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + model = embed(model, merger.get_merged(), args.force) + onnx.save_model(model, args.output_onnx) + elif args.cmd == "merge": + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + json.dump(merger.get_merged(), open(args.output_json, "w")) + elif args.cmd == "pprint": + tuning_results = None + try: + tuning_results = json.load(open(args.json_or_onnx, "r")) + except Exception: + # it might be an onnx file otherwise, try it latter + pass + + if tuning_results is None: + try: + model = onnx.load_model(args.json_or_onnx) + tuning_results = extract(model) + if tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + sys.exit(-1) + except Exception: + pass + + if tuning_results is None: + sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!") + sys.exit(-1) + + pprint(tuning_results) + else: + # invalid choice will be handled by the parser + pass + + +if __name__ == "__main__": + main()