mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
shared allocator for on device training (#16432)
### Description <!-- Describe your changes. --> New logic to share allocators among the module, optimizer, and eval sessions for the training scenario. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Previously, on-device training shared allocators by sharing the EP; now, with the new mechanism for sharing allocators, we need to explicitly register the allocator in the environment. --------- Co-authored-by: Lei Cao <leca@microsoft.com>
This commit is contained in:
parent
1001ec93a7
commit
e5270e3b4f
4 changed files with 6 additions and 8 deletions
|
|
@ -95,7 +95,7 @@ SessionState::SessionState(Graph& graph,
|
|||
// The allocator registration rule:
|
||||
// Each location (OrtDevice) will only have 1 allocator used for whole session.
|
||||
// The EP which is registered first will have higher priority
|
||||
for (auto ep : execution_providers_) {
|
||||
for (auto& ep : execution_providers_) {
|
||||
auto allocators = ep->CreatePreferredAllocators();
|
||||
for (auto& alloc : allocators) {
|
||||
allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/random_seed.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "orttraining/training_api/checkpoint.h"
|
||||
|
|
@ -32,6 +33,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* e
|
|||
_In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path,
|
||||
_In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) {
|
||||
API_IMPL_BEGIN
|
||||
if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") {
|
||||
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training.");
|
||||
}
|
||||
std::unique_ptr<onnxruntime::training::api::TrainingSession> train_sess;
|
||||
auto chkpt_state = reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
OrtStatus* status = nullptr;
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
|
|||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers)
|
||||
: optim_sess_(std::make_unique<InferenceSession>(session_options, env)), state_(state) {
|
||||
Initialize(optim_path_or_bytes, session_options, env, providers);
|
||||
Initialize(optim_path_or_bytes, providers);
|
||||
|
||||
ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null.");
|
||||
auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME);
|
||||
|
|
@ -215,11 +215,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
|
|||
}
|
||||
|
||||
void Optimizer::Initialize(const std::string& optim_path_or_bytes,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
|
||||
optim_sess_ = std::make_unique<InferenceSession>(session_options, env);
|
||||
|
||||
for (const auto& execution_provider : providers) {
|
||||
ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -121,8 +121,6 @@ struct Optimizer {
|
|||
|
||||
private:
|
||||
void Initialize(const std::string& optim_path_or_bytes,
|
||||
const onnxruntime::SessionOptions& session_options,
|
||||
const Environment& env,
|
||||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers);
|
||||
|
||||
int64_t GetStep() const {
|
||||
|
|
|
|||
Loading…
Reference in a new issue