From e5270e3b4ff45f76a2038779ccfc52a0bac6ec4d Mon Sep 17 00:00:00 2001 From: cao lei Date: Tue, 27 Jun 2023 15:10:42 -0700 Subject: [PATCH] shared allocator for on device training (#16432) ### Description New logic to share allocators among module, optimizer and eval sessions for Training scenario ### Motivation and Context Previously on device training using shared allocator by sharing EP, now with new mechanism to share allocator, we need to explicitly register allocator in the environment. --------- Co-authored-by: Lei Cao --- onnxruntime/core/framework/session_state.cc | 2 +- .../orttraining/training_api/onnxruntime_training_c_api.cc | 4 ++++ orttraining/orttraining/training_api/optimizer.cc | 6 +----- orttraining/orttraining/training_api/optimizer.h | 2 -- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 581fbd0c7e..73976d9b0b 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -95,7 +95,7 @@ SessionState::SessionState(Graph& graph, // The allocator registration rule: // Each location (OrtDevice) will only have 1 allocator used for whole session. // The EP which is registered first will have higher priority - for (auto ep : execution_providers_) { + for (auto& ep : execution_providers_) { auto allocators = ep->CreatePreferredAllocators(); for (auto& alloc : allocators) { allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 773ca93648..f7086cd802 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -6,6 +6,7 @@ #include "core/framework/error_code_helper.h" #include "core/framework/random_seed.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "orttraining/training_api/checkpoint.h" @@ -32,6 +33,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* e _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { API_IMPL_BEGIN + if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training."); + } std::unique_ptr train_sess; auto chkpt_state = reinterpret_cast(checkpoint_state); OrtStatus* status = nullptr; diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index 3a295d4994..66ae991caa 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -198,7 +198,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, const Environment& env, const std::vector>& providers) : optim_sess_(std::make_unique(session_options, env)), state_(state) { - Initialize(optim_path_or_bytes, session_options, env, providers); + Initialize(optim_path_or_bytes, providers); ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null."); auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -215,11 +215,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, } void Optimizer::Initialize(const std::string& optim_path_or_bytes, - const onnxruntime::SessionOptions& session_options, - const Environment& env, const std::vector>& providers) { - optim_sess_ = std::make_unique(session_options, env); - for (const auto& execution_provider : providers) { ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); } diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index ffbe293c88..6a20106f11 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -121,8 +121,6 @@ struct Optimizer { private: void Initialize(const std::string& optim_path_or_bytes, - const onnxruntime::SessionOptions& session_options, - const Environment& env, const std::vector>& providers); int64_t GetStep() const {