resolve the provider options before create training session in orttrainer (#9199)

* resolve the provider options before create training session in orttrainer

* Update orttraining/orttraining/python/orttraining_pybind_common.h

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>

* support clear the training ep instance pool

* fix status error

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
This commit is contained in:
Tang, Cheng 2021-10-12 09:30:45 -07:00 committed by GitHub
parent 52c021d1f3
commit 48737091c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 288 additions and 87 deletions

View file

@ -1 +1,2 @@
_GetProvider
_ProviderHashFunc

View file

@ -69,4 +69,25 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
return &onnxruntime::g_provider;
}
ORT_API(size_t, ProviderHashFunc, const void* provider_options){
ProviderOptions* options = (ProviderOptions*)(provider_options);
MyProviderInfo info;
ORT_IGNORE_RETURN_VALUE(ProviderOptionsParser{}
.AddValueParser(
"device_id",
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
return Status::OK();
})
.AddValueParser(
"some_config",
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.some_config));
return Status::OK();
})
.Parse(*options));
// use device id as hash key
return info.device_id;
}
}

View file

@ -12,6 +12,8 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOpt
ORT_API(onnxruntime::Provider*, GetProvider);
ORT_API(size_t, ProviderHashFunc, const void* options);
#ifdef __cplusplus
}
#endif

View file

@ -1,2 +1,3 @@
EXPORTS
GetProvider
ProviderHashFunc

View file

@ -1,7 +1,8 @@
#_init and _fini should be local
VERS_1.0 {
global:
GetProvider;
GetProvider;
ProviderHashFunc;
# Hide everything else.
local:

View file

@ -5,6 +5,84 @@ import unittest
import torch
import onnxruntime_pybind11_state as torch_ort
import os
from io import StringIO
import sys
import threading
import time
class OutputGrabber(object):
"""
Class used to grab standard output or another stream.
"""
escape_char = "\b"
def __init__(self, stream=None, threaded=False):
self.origstream = stream
self.threaded = threaded
if self.origstream is None:
self.origstream = sys.stdout
self.origstreamfd = self.origstream.fileno()
self.capturedtext = ""
# Create a pipe so the stream can be captured:
self.pipe_out, self.pipe_in = os.pipe()
def __enter__(self):
self.start()
return self
def __exit__(self, type, value, traceback):
self.stop()
def start(self):
"""
Start capturing the stream data.
"""
self.capturedtext = ""
# Save a copy of the stream:
self.streamfd = os.dup(self.origstreamfd)
# Replace the original stream with our write pipe:
os.dup2(self.pipe_in, self.origstreamfd)
if self.threaded:
# Start thread that will read the stream:
self.workerThread = threading.Thread(target=self.readOutput)
self.workerThread.start()
# Make sure that the thread is running and os.read() has executed:
time.sleep(0.01)
def stop(self):
"""
Stop capturing the stream data and save the text in `capturedtext`.
"""
# Print the escape character to make the readOutput method stop:
self.origstream.write(self.escape_char)
# Flush the stream to make sure all our data goes in before
# the escape character:
self.origstream.flush()
if self.threaded:
# wait until the thread finishes so we are sure that
# we have until the last character:
self.workerThread.join()
else:
self.readOutput()
# Close the pipe:
os.close(self.pipe_in)
os.close(self.pipe_out)
# Restore the original stream:
os.dup2(self.streamfd, self.origstreamfd)
# Close the duplicate stream:
os.close(self.streamfd)
def readOutput(self):
"""
Read the stream data (one byte at a time)
and save the text in `capturedtext`.
"""
while True:
char = os.read(self.pipe_out,1).decode(self.origstream.encoding)
if not char or self.escape_char in char:
break
self.capturedtext += char
class OrtEPTests(unittest.TestCase):
def get_test_execution_provider_path(self):
@ -14,8 +92,27 @@ class OrtEPTests(unittest.TestCase):
torch_ort.set_device(0, 'CPUExecutionProvider', {})
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {})
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)
# capture std out
with OutputGrabber() as out:
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
with OutputGrabber() as out:
torch_ort.set_device(2, 'TestExecutionProvider', {'device_id':'1', 'some_config':'val'})
ort_device = torch_ort.device(1)
assert 'My EP provider created, with device id: 1, some_option: val' in out.capturedtext
# test the reusing EP instance
with OutputGrabber() as out:
torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)
assert 'My EP provider created, with device id: 0, some_option: val' not in out.capturedtext
# test clear training ep instance pool
torch_ort.clear_training_ep_instances()
with OutputGrabber() as out:
torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"
#include "core/platform/env.h"
#include <unordered_map>
#include <cstdlib>
namespace onnxruntime {
namespace python {
namespace py = pybind11;
using namespace onnxruntime::logging;
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
class ORTTrainingPythonEnv{
public:
ORTTrainingPythonEnv();
Environment& GetORTEnv();
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
size_t hash);
void AddExecutionProvider(const std::string& provider_type,
size_t hash,
std::unique_ptr<IExecutionProvider> execution_provider);
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options);
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes();
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
void ClearExecutionProviderInstances();
private:
std::string GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash);
std::unique_ptr<Environment> ort_env_;
ExecutionProviderMap execution_provider_instances_map_;
std::vector<std::string> available_training_eps_;
};
}
}

View file

@ -20,6 +20,8 @@
#include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h"
#include "orttraining/python/orttraining_pybind_common.h"
#ifdef ENABLE_TRAINING_TORCH_INTEROP
#include "orttraining/core/framework/torch/custom_function_register.h"
#endif
@ -35,6 +37,33 @@ using namespace onnxruntime::logging;
using namespace onnxruntime::training;
Environment& GetTrainingORTEnv();
ORTTrainingPythonEnv& GetTrainingEnv();
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
const ProviderOptionsVector& original_provider_options_vector,
ProviderOptionsVector& merged_options){
auto& training_env = GetTrainingEnv();
std::size_t j = 0; // index for provider_options_vector
for (const std::string& type : provider_types) {
auto it = training_env.ext_execution_provider_info_map_.find(type);
if (it == training_env.ext_execution_provider_info_map_.end()){
if (j < original_provider_options_vector.size() && !original_provider_options_vector[j].empty()) {
merged_options.push_back(original_provider_options_vector[j]);
}
}else{
ProviderOptions options = it->second.second;
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
if (j < original_provider_options_vector.size() && !original_provider_options_vector[j].empty()) {
for (auto [k, v] : original_provider_options_vector[j]){
options.insert({k, v});
}
}
merged_options.push_back(options);
}
j += 1;
}
}
struct TrainingParameters {
std::string loss_output_name;
@ -521,7 +550,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
@ -535,8 +567,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})

View file

@ -1,18 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"
#include "orttraining/python/orttraining_pybind_common.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/providers/get_execution_providers.h"
#include "core/platform/env.h"
#include "core/session/provider_bridge_ort.h"
#include <unordered_map>
#include <cstdlib>
namespace onnxruntime {
namespace python {
@ -41,8 +35,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
void addObjectMethodsForEager(py::module& m);
void InitArray();
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
bool GetDyanmicExecutionProviderHash(
const std::string& ep_shared_lib_path,
@ -131,61 +123,55 @@ bool GetProviderInstanceHash(const std::string& type,
return false;
}
class ORTTrainingPythonEnv{
public:
ORTTrainingPythonEnv(){
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::unique_ptr<ISink>{new CLogSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
ort_env_));
auto& builtinEPs = GetAvailableExecutionProviderNames();
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
}
ORTTrainingPythonEnv::ORTTrainingPythonEnv(){
OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
std::unique_ptr<ISink>{new CLogSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
ort_env_));
auto& builtinEPs = GetAvailableExecutionProviderNames();
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
}
Environment& GetORTEnv(){
return *ort_env_;
}
Environment& ORTTrainingPythonEnv::GetORTEnv(){
return *ort_env_;
}
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
size_t hash){
auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
return it == execution_provider_instances_map_.end() ? nullptr : it->second;
}
std::shared_ptr<IExecutionProvider> ORTTrainingPythonEnv::GetExecutionProviderInstance(const std::string& provider_type,
size_t hash){
auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
return it == execution_provider_instances_map_.end() ? nullptr : it->second;
}
void AddExecutionProvider(const std::string& provider_type,
size_t hash,
std::unique_ptr<IExecutionProvider> execution_provider){
execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
std::move(execution_provider)});
}
void ORTTrainingPythonEnv::AddExecutionProvider(const std::string& provider_type,
size_t hash,
std::unique_ptr<IExecutionProvider> execution_provider){
execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
std::move(execution_provider)});
}
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options){
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
available_training_eps_.push_back(provider_type);
}
void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options){
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
available_training_eps_.push_back(provider_type);
}
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes(){
return available_training_eps_;
}
const std::vector<std::string>& ORTTrainingPythonEnv::GetAvailableTrainingExecutionProviderTypes(){
return available_training_eps_;
}
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
std::string ORTTrainingPythonEnv::GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash){
std::string key(provider_type);
key.append(std::to_string(hash));
return key;
}
private:
std::string GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash){
std::string key(provider_type);
key.append(std::to_string(hash));
return key;
}
std::unique_ptr<Environment> ort_env_;
ExecutionProviderMap execution_provider_instances_map_;
std::vector<std::string> available_training_eps_;
};
void ORTTrainingPythonEnv::ClearExecutionProviderInstances(){
execution_provider_instances_map_.clear();
}
static std::unique_ptr<ORTTrainingPythonEnv> ort_training_env;
@ -218,24 +204,27 @@ Environment& GetTrainingORTEnv() {
return ort_training_env->GetORTEnv();
}
void ResolveExtraProviderOptions(const std::string& provider_type,
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
const ProviderOptionsMap& original_provider_options_map,
ProviderOptionsMap& merged_options){
auto& training_env = GetTrainingEnv();
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
if (it == training_env.ext_execution_provider_info_map_.end()){
//nothing changed.
merged_options = original_provider_options_map;
}else{
ProviderOptions options = it->second.second;
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
auto original_map_it = original_provider_options_map.find(provider_type);
if (original_map_it != original_provider_options_map.end()){
for (auto [k, v] : original_map_it->second){
options.insert({k, v});
for (auto& provider_type : provider_types){
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
if (it == training_env.ext_execution_provider_info_map_.end()){
//nothing changed.
if (original_provider_options_map.find(provider_type) != original_provider_options_map.end())
merged_options.insert({provider_type, original_provider_options_map.at(provider_type)});
}else{
ProviderOptions options = it->second.second;
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
auto original_map_it = original_provider_options_map.find(provider_type);
if (original_map_it != original_provider_options_map.end()){
for (auto [k, v] : original_map_it->second){
options.insert({k, v});
}
}
merged_options[provider_type] = options;
}
merged_options[provider_type] = options;
}
}
@ -243,27 +232,22 @@ std::unique_ptr<IExecutionProvider> CreateTrainingEP(
const SessionOptions& session_options,
const std::string& provider_type,
const ProviderOptionsMap& provider_options_map){
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
if (training_env.ext_execution_provider_info_map_.find(provider_type) !=
training_env.ext_execution_provider_info_map_.end()){
ProviderOptionsMap merged_options;
ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options);
return CreateExecutionProviderInstance(session_options, provider_type, merged_options);
}else{
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
}
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
}
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
const ProviderOptionsMap& provider_options_map,
const SessionOptions& session_options){
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
// resolve provider options, because the hash key of ep depends on provider options.
ProviderOptionsMap merged_options;
ResolveExtraProviderOptions({provider_type}, provider_options_map, merged_options);
// search in environment
size_t hash;
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
if (GetProviderInstanceHash(provider_type, merged_options, hash)){
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
if (!cached_provider_instance){
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
if (ep){
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
@ -273,7 +257,7 @@ std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::stri
}
else{
// the EP doesn't support cache, register the instance to session
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
return ep;
}
}
@ -324,6 +308,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
"Return list of available Execution Providers in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def("clear_training_ep_instances", []() -> void {
ort_training_env->ClearExecutionProviderInstances();
},
"Clean the execution provider instances used in ort training module.");
// clean the ort training environment when python interpreter exit
// otherwise the global var will be de-constrcut after user main.