shared allocator for on device training (#16432)

### Description
<!-- Describe your changes. -->
New logic to share allocators among the module, optimizer, and eval sessions
in the on-device training scenario.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Previously, on-device training shared allocators by sharing the EP. Now,
with the new mechanism for sharing allocators, we need to explicitly register
the allocator in the environment.

---------

Co-authored-by: Lei Cao <leca@microsoft.com>
This commit is contained in:
cao lei 2023-06-27 15:10:42 -07:00 committed by GitHub
parent 1001ec93a7
commit e5270e3b4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 8 deletions

View file

@ -95,7 +95,7 @@ SessionState::SessionState(Graph& graph,
// The allocator registration rule:
// Each location (OrtDevice) will only have 1 allocator used for whole session.
// The EP which is registered first will have higher priority
for (auto ep : execution_providers_) {
for (auto& ep : execution_providers_) {
auto allocators = ep->CreatePreferredAllocators();
for (auto& alloc : allocators) {
allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key

View file

@ -6,6 +6,7 @@
#include "core/framework/error_code_helper.h"
#include "core/framework/random_seed.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/ort_apis.h"
#include "core/session/ort_env.h"
#include "orttraining/training_api/checkpoint.h"
@ -32,6 +33,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* e
_In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path,
_In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) {
API_IMPL_BEGIN
if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") {
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training.");
}
std::unique_ptr<onnxruntime::training::api::TrainingSession> train_sess;
auto chkpt_state = reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
OrtStatus* status = nullptr;

View file

@ -198,7 +198,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers)
: optim_sess_(std::make_unique<InferenceSession>(session_options, env)), state_(state) {
Initialize(optim_path_or_bytes, session_options, env, providers);
Initialize(optim_path_or_bytes, providers);
ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null.");
auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@ -215,11 +215,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
}
void Optimizer::Initialize(const std::string& optim_path_or_bytes,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
optim_sess_ = std::make_unique<InferenceSession>(session_options, env);
for (const auto& execution_provider : providers) {
ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider));
}

View file

@ -121,8 +121,6 @@ struct Optimizer {
private:
void Initialize(const std::string& optim_path_or_bytes,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers);
int64_t GetStep() const {