mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
resolve the provider options before create training session in orttrainer (#9199)
* resolve the provider options before create training session in orttrainer * Update orttraining/orttraining/python/orttraining_pybind_common.h Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com> * support clear the training ep instance pool * fix status error Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
parent
52c021d1f3
commit
48737091c0
9 changed files with 288 additions and 87 deletions
|
|
@ -1 +1,2 @@
|
|||
_GetProvider
|
||||
_ProviderHashFunc
|
||||
|
|
|
|||
|
|
@ -69,4 +69,25 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
|
|||
return &onnxruntime::g_provider;
|
||||
}
|
||||
|
||||
ORT_API(size_t, ProviderHashFunc, const void* provider_options){
|
||||
ProviderOptions* options = (ProviderOptions*)(provider_options);
|
||||
MyProviderInfo info;
|
||||
ORT_IGNORE_RETURN_VALUE(ProviderOptionsParser{}
|
||||
.AddValueParser(
|
||||
"device_id",
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
|
||||
return Status::OK();
|
||||
})
|
||||
.AddValueParser(
|
||||
"some_config",
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.some_config));
|
||||
return Status::OK();
|
||||
})
|
||||
.Parse(*options));
|
||||
// use device id as hash key
|
||||
return info.device_id;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -12,6 +12,8 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOpt
|
|||
|
||||
ORT_API(onnxruntime::Provider*, GetProvider);
|
||||
|
||||
ORT_API(size_t, ProviderHashFunc, const void* options);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
EXPORTS
|
||||
GetProvider
|
||||
ProviderHashFunc
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
#_init and _fini should be local
|
||||
VERS_1.0 {
|
||||
global:
|
||||
GetProvider;
|
||||
GetProvider;
|
||||
ProviderHashFunc;
|
||||
|
||||
# Hide everything else.
|
||||
local:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,84 @@ import unittest
|
|||
import torch
|
||||
import onnxruntime_pybind11_state as torch_ort
|
||||
import os
|
||||
from io import StringIO
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class OutputGrabber(object):
|
||||
"""
|
||||
Class used to grab standard output or another stream.
|
||||
"""
|
||||
escape_char = "\b"
|
||||
|
||||
def __init__(self, stream=None, threaded=False):
|
||||
self.origstream = stream
|
||||
self.threaded = threaded
|
||||
if self.origstream is None:
|
||||
self.origstream = sys.stdout
|
||||
self.origstreamfd = self.origstream.fileno()
|
||||
self.capturedtext = ""
|
||||
# Create a pipe so the stream can be captured:
|
||||
self.pipe_out, self.pipe_in = os.pipe()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start capturing the stream data.
|
||||
"""
|
||||
self.capturedtext = ""
|
||||
# Save a copy of the stream:
|
||||
self.streamfd = os.dup(self.origstreamfd)
|
||||
# Replace the original stream with our write pipe:
|
||||
os.dup2(self.pipe_in, self.origstreamfd)
|
||||
if self.threaded:
|
||||
# Start thread that will read the stream:
|
||||
self.workerThread = threading.Thread(target=self.readOutput)
|
||||
self.workerThread.start()
|
||||
# Make sure that the thread is running and os.read() has executed:
|
||||
time.sleep(0.01)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop capturing the stream data and save the text in `capturedtext`.
|
||||
"""
|
||||
# Print the escape character to make the readOutput method stop:
|
||||
self.origstream.write(self.escape_char)
|
||||
# Flush the stream to make sure all our data goes in before
|
||||
# the escape character:
|
||||
self.origstream.flush()
|
||||
if self.threaded:
|
||||
# wait until the thread finishes so we are sure that
|
||||
# we have until the last character:
|
||||
self.workerThread.join()
|
||||
else:
|
||||
self.readOutput()
|
||||
# Close the pipe:
|
||||
os.close(self.pipe_in)
|
||||
os.close(self.pipe_out)
|
||||
# Restore the original stream:
|
||||
os.dup2(self.streamfd, self.origstreamfd)
|
||||
# Close the duplicate stream:
|
||||
os.close(self.streamfd)
|
||||
|
||||
def readOutput(self):
|
||||
"""
|
||||
Read the stream data (one byte at a time)
|
||||
and save the text in `capturedtext`.
|
||||
"""
|
||||
while True:
|
||||
char = os.read(self.pipe_out,1).decode(self.origstream.encoding)
|
||||
if not char or self.escape_char in char:
|
||||
break
|
||||
self.capturedtext += char
|
||||
|
||||
class OrtEPTests(unittest.TestCase):
|
||||
def get_test_execution_provider_path(self):
|
||||
|
|
@ -14,8 +92,27 @@ class OrtEPTests(unittest.TestCase):
|
|||
torch_ort.set_device(0, 'CPUExecutionProvider', {})
|
||||
|
||||
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {})
|
||||
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
# capture std out
|
||||
with OutputGrabber() as out:
|
||||
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
|
||||
with OutputGrabber() as out:
|
||||
torch_ort.set_device(2, 'TestExecutionProvider', {'device_id':'1', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
assert 'My EP provider created, with device id: 1, some_option: val' in out.capturedtext
|
||||
# test the reusing EP instance
|
||||
with OutputGrabber() as out:
|
||||
torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
assert 'My EP provider created, with device id: 0, some_option: val' not in out.capturedtext
|
||||
# test clear training ep instance pool
|
||||
torch_ort.clear_training_ep_instances()
|
||||
with OutputGrabber() as out:
|
||||
torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
55
orttraining/orttraining/python/orttraining_pybind_common.h
Normal file
55
orttraining/orttraining/python/orttraining_pybind_common.h
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/onnxruntime_pybind_exceptions.h"
|
||||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
#include "python/onnxruntime_pybind_state_common.h"
|
||||
|
||||
#include "core/platform/env.h"
|
||||
#include <unordered_map>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
|
||||
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
|
||||
|
||||
|
||||
class ORTTrainingPythonEnv{
|
||||
public:
|
||||
ORTTrainingPythonEnv();
|
||||
|
||||
Environment& GetORTEnv();
|
||||
|
||||
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
|
||||
size_t hash);
|
||||
|
||||
void AddExecutionProvider(const std::string& provider_type,
|
||||
size_t hash,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider);
|
||||
|
||||
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
const std::string& provider_lib_path,
|
||||
const ProviderOptions& default_options);
|
||||
|
||||
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes();
|
||||
|
||||
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
|
||||
|
||||
void ClearExecutionProviderInstances();
|
||||
|
||||
private:
|
||||
std::string GetExecutionProviderMapKey(const std::string& provider_type,
|
||||
size_t hash);
|
||||
|
||||
std::unique_ptr<Environment> ort_env_;
|
||||
ExecutionProviderMap execution_provider_instances_map_;
|
||||
std::vector<std::string> available_training_eps_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h"
|
||||
|
||||
#include "orttraining/python/orttraining_pybind_common.h"
|
||||
|
||||
#ifdef ENABLE_TRAINING_TORCH_INTEROP
|
||||
#include "orttraining/core/framework/torch/custom_function_register.h"
|
||||
#endif
|
||||
|
|
@ -35,6 +37,33 @@ using namespace onnxruntime::logging;
|
|||
using namespace onnxruntime::training;
|
||||
|
||||
Environment& GetTrainingORTEnv();
|
||||
ORTTrainingPythonEnv& GetTrainingEnv();
|
||||
|
||||
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
|
||||
const ProviderOptionsVector& original_provider_options_vector,
|
||||
ProviderOptionsVector& merged_options){
|
||||
auto& training_env = GetTrainingEnv();
|
||||
std::size_t j = 0; // index for provider_options_vector
|
||||
for (const std::string& type : provider_types) {
|
||||
auto it = training_env.ext_execution_provider_info_map_.find(type);
|
||||
if (it == training_env.ext_execution_provider_info_map_.end()){
|
||||
if (j < original_provider_options_vector.size() && !original_provider_options_vector[j].empty()) {
|
||||
merged_options.push_back(original_provider_options_vector[j]);
|
||||
}
|
||||
}else{
|
||||
ProviderOptions options = it->second.second;
|
||||
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
|
||||
if (j < original_provider_options_vector.size() && !original_provider_options_vector[j].empty()) {
|
||||
for (auto [k, v] : original_provider_options_vector[j]){
|
||||
options.insert({k, v});
|
||||
}
|
||||
}
|
||||
merged_options.push_back(options);
|
||||
}
|
||||
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
|
||||
struct TrainingParameters {
|
||||
std::string loss_output_name;
|
||||
|
|
@ -521,7 +550,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
|
||||
|
||||
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
|
||||
ProviderOptionsVector merged_options;
|
||||
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
|
||||
|
||||
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
|
||||
|
||||
return config_result;
|
||||
})
|
||||
|
|
@ -535,8 +567,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
|
||||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
|
||||
ProviderOptionsVector merged_options;
|
||||
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
|
||||
|
||||
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
|
||||
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
|
||||
|
||||
return config_result;
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,18 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/onnxruntime_pybind_exceptions.h"
|
||||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
#include "python/onnxruntime_pybind_state_common.h"
|
||||
#include "orttraining/python/orttraining_pybind_common.h"
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
|
||||
#include "core/platform/env.h"
|
||||
#include "core/session/provider_bridge_ort.h"
|
||||
#include <unordered_map>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
|
@ -41,8 +35,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
void addObjectMethodsForEager(py::module& m);
|
||||
void InitArray();
|
||||
|
||||
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
|
||||
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
|
||||
|
||||
bool GetDyanmicExecutionProviderHash(
|
||||
const std::string& ep_shared_lib_path,
|
||||
|
|
@ -131,61 +123,55 @@ bool GetProviderInstanceHash(const std::string& type,
|
|||
return false;
|
||||
}
|
||||
|
||||
class ORTTrainingPythonEnv{
|
||||
public:
|
||||
ORTTrainingPythonEnv(){
|
||||
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
|
||||
std::unique_ptr<ISink>{new CLogSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&SessionObjectInitializer::default_logger_id),
|
||||
ort_env_));
|
||||
auto& builtinEPs = GetAvailableExecutionProviderNames();
|
||||
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
|
||||
}
|
||||
ORTTrainingPythonEnv::ORTTrainingPythonEnv(){
|
||||
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
|
||||
std::unique_ptr<ISink>{new CLogSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&SessionObjectInitializer::default_logger_id),
|
||||
ort_env_));
|
||||
auto& builtinEPs = GetAvailableExecutionProviderNames();
|
||||
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
|
||||
}
|
||||
|
||||
Environment& GetORTEnv(){
|
||||
return *ort_env_;
|
||||
}
|
||||
Environment& ORTTrainingPythonEnv::GetORTEnv(){
|
||||
return *ort_env_;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
|
||||
size_t hash){
|
||||
auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
|
||||
return it == execution_provider_instances_map_.end() ? nullptr : it->second;
|
||||
}
|
||||
std::shared_ptr<IExecutionProvider> ORTTrainingPythonEnv::GetExecutionProviderInstance(const std::string& provider_type,
|
||||
size_t hash){
|
||||
auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
|
||||
return it == execution_provider_instances_map_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
void AddExecutionProvider(const std::string& provider_type,
|
||||
size_t hash,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider){
|
||||
execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
|
||||
std::move(execution_provider)});
|
||||
}
|
||||
void ORTTrainingPythonEnv::AddExecutionProvider(const std::string& provider_type,
|
||||
size_t hash,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider){
|
||||
execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
|
||||
std::move(execution_provider)});
|
||||
}
|
||||
|
||||
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
const std::string& provider_lib_path,
|
||||
const ProviderOptions& default_options){
|
||||
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
|
||||
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
|
||||
available_training_eps_.push_back(provider_type);
|
||||
}
|
||||
void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
const std::string& provider_lib_path,
|
||||
const ProviderOptions& default_options){
|
||||
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
|
||||
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
|
||||
available_training_eps_.push_back(provider_type);
|
||||
}
|
||||
|
||||
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes(){
|
||||
return available_training_eps_;
|
||||
}
|
||||
const std::vector<std::string>& ORTTrainingPythonEnv::GetAvailableTrainingExecutionProviderTypes(){
|
||||
return available_training_eps_;
|
||||
}
|
||||
|
||||
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
|
||||
std::string ORTTrainingPythonEnv::GetExecutionProviderMapKey(const std::string& provider_type,
|
||||
size_t hash){
|
||||
std::string key(provider_type);
|
||||
key.append(std::to_string(hash));
|
||||
return key;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string GetExecutionProviderMapKey(const std::string& provider_type,
|
||||
size_t hash){
|
||||
std::string key(provider_type);
|
||||
key.append(std::to_string(hash));
|
||||
return key;
|
||||
}
|
||||
|
||||
std::unique_ptr<Environment> ort_env_;
|
||||
ExecutionProviderMap execution_provider_instances_map_;
|
||||
std::vector<std::string> available_training_eps_;
|
||||
};
|
||||
void ORTTrainingPythonEnv::ClearExecutionProviderInstances(){
|
||||
execution_provider_instances_map_.clear();
|
||||
}
|
||||
|
||||
static std::unique_ptr<ORTTrainingPythonEnv> ort_training_env;
|
||||
|
||||
|
|
@ -218,24 +204,27 @@ Environment& GetTrainingORTEnv() {
|
|||
return ort_training_env->GetORTEnv();
|
||||
}
|
||||
|
||||
void ResolveExtraProviderOptions(const std::string& provider_type,
|
||||
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
|
||||
const ProviderOptionsMap& original_provider_options_map,
|
||||
ProviderOptionsMap& merged_options){
|
||||
auto& training_env = GetTrainingEnv();
|
||||
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
|
||||
if (it == training_env.ext_execution_provider_info_map_.end()){
|
||||
//nothing changed.
|
||||
merged_options = original_provider_options_map;
|
||||
}else{
|
||||
ProviderOptions options = it->second.second;
|
||||
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
|
||||
auto original_map_it = original_provider_options_map.find(provider_type);
|
||||
if (original_map_it != original_provider_options_map.end()){
|
||||
for (auto [k, v] : original_map_it->second){
|
||||
options.insert({k, v});
|
||||
for (auto& provider_type : provider_types){
|
||||
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
|
||||
if (it == training_env.ext_execution_provider_info_map_.end()){
|
||||
//nothing changed.
|
||||
if (original_provider_options_map.find(provider_type) != original_provider_options_map.end())
|
||||
merged_options.insert({provider_type, original_provider_options_map.at(provider_type)});
|
||||
}else{
|
||||
ProviderOptions options = it->second.second;
|
||||
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
|
||||
auto original_map_it = original_provider_options_map.find(provider_type);
|
||||
if (original_map_it != original_provider_options_map.end()){
|
||||
for (auto [k, v] : original_map_it->second){
|
||||
options.insert({k, v});
|
||||
}
|
||||
}
|
||||
merged_options[provider_type] = options;
|
||||
}
|
||||
merged_options[provider_type] = options;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,27 +232,22 @@ std::unique_ptr<IExecutionProvider> CreateTrainingEP(
|
|||
const SessionOptions& session_options,
|
||||
const std::string& provider_type,
|
||||
const ProviderOptionsMap& provider_options_map){
|
||||
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
|
||||
if (training_env.ext_execution_provider_info_map_.find(provider_type) !=
|
||||
training_env.ext_execution_provider_info_map_.end()){
|
||||
ProviderOptionsMap merged_options;
|
||||
ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options);
|
||||
return CreateExecutionProviderInstance(session_options, provider_type, merged_options);
|
||||
}else{
|
||||
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
|
||||
}
|
||||
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
|
||||
const ProviderOptionsMap& provider_options_map,
|
||||
const SessionOptions& session_options){
|
||||
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
|
||||
// resolve provider options, because the hash key of ep depends on provider options.
|
||||
ProviderOptionsMap merged_options;
|
||||
ResolveExtraProviderOptions({provider_type}, provider_options_map, merged_options);
|
||||
// search in environment
|
||||
size_t hash;
|
||||
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
|
||||
if (GetProviderInstanceHash(provider_type, merged_options, hash)){
|
||||
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
if (!cached_provider_instance){
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
|
||||
if (ep){
|
||||
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
|
||||
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
|
|
@ -273,7 +257,7 @@ std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::stri
|
|||
}
|
||||
else{
|
||||
// the EP doesn't support cache, register the instance to session
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
|
||||
return ep;
|
||||
}
|
||||
}
|
||||
|
|
@ -324,6 +308,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
"Return list of available Execution Providers in this installed version of Onnxruntime. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
|
||||
m.def("clear_training_ep_instances", []() -> void {
|
||||
ort_training_env->ClearExecutionProviderInstances();
|
||||
},
|
||||
"Clean the execution provider instances used in ort training module.");
|
||||
|
||||
// Clean up the ORT training environment when the Python interpreter exits,
|
||||
// otherwise the global variable will be destructed after the user's main.
|
||||
|
|
|
|||
Loading…
Reference in a new issue