diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst
index f4c4141259..d078b2e688 100644
--- a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst
+++ b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst
@@ -1 +1,2 @@
 _GetProvider
+_ProviderHashFunc
diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc
index ca041784d7..991675a125 100644
--- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc
+++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc
@@ -69,4 +69,31 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
   return &onnxruntime::g_provider;
 }
 
+// Computes the hash key of this test EP for a set of provider options.
+// `provider_options` must point to a valid ProviderOptions map. Only
+// "device_id" contributes to the key, so option sets that differ solely in
+// "some_config" produce the same key.
+ORT_API(size_t, ProviderHashFunc, const void* provider_options) {
+  const auto* options = static_cast<const ProviderOptions*>(provider_options);
+  // Value-initialize so the key is well defined when "device_id" is absent
+  // or fails to parse (the parse status is deliberately ignored below).
+  MyProviderInfo info{};
+  ORT_IGNORE_RETURN_VALUE(ProviderOptionsParser{}
+                              .AddValueParser(
+                                  "device_id",
+                                  [&info](const std::string& value_str) -> Status {
+                                    ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
+                                    return Status::OK();
+                                  })
+                              .AddValueParser(
+                                  "some_config",
+                                  [&info](const std::string& value_str) -> Status {
+                                    ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.some_config));
+                                    return Status::OK();
+                                  })
+                              .Parse(*options));
+  // Use the device id as the hash key.
+  return static_cast<size_t>(info.device_id);
+}
+
 }
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h
index b342960935..e614d01d32 100644
--- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h
+++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h
@@ -12,6 +12,8 @@ 
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOpt
 
 ORT_API(onnxruntime::Provider*, GetProvider);
 
+ORT_API(size_t, ProviderHashFunc, const void* provider_options);
+
 #ifdef __cplusplus
 }
 #endif
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def
index 4ec2f7914c..d26cac362a 100644
--- a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def
+++ b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def
@@ -1,2 +1,3 @@
 EXPORTS
     GetProvider
+    ProviderHashFunc
diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds
index b298a6d003..dd1d20cc14 100644
--- a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds
+++ b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds
@@ -1,7 +1,8 @@
 #_init and _fini should be local
 VERS_1.0 {
   global:
-    GetProvider;
+    GetProvider;
+    ProviderHashFunc;
 
   # Hide everything else.
   local:
diff --git a/orttraining/orttraining/eager/test/ort_eps_test.py b/orttraining/orttraining/eager/test/ort_eps_test.py
index f1cc141fef..1c935c766d 100644
--- a/orttraining/orttraining/eager/test/ort_eps_test.py
+++ b/orttraining/orttraining/eager/test/ort_eps_test.py
@@ -5,6 +5,97 @@ import unittest
 import torch
 import onnxruntime_pybind11_state as torch_ort
 import os
+from io import StringIO
+import sys
+import threading
+import time
+
+
+class OutputGrabber(object):
+    """
+    Context manager that captures what is written to the file descriptor of
+    a stream (sys.stdout by default) and stores it in `capturedtext`.
+    """
+    escape_char = "\b"
+
+    def __init__(self, stream=None, threaded=False):
+        self.origstream = stream
+        self.threaded = threaded
+        if self.origstream is None:
+            self.origstream = sys.stdout
+        self.origstreamfd = self.origstream.fileno()
+        self.capturedtext = ""
+        # The pipe is created in start() so that a grabber can be started
+        # again after stop() has closed the previous pipe.
+        self.pipe_out = None
+        self.pipe_in = None
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.stop()
+
+    def start(self):
+        """
+        Start capturing the stream data.
+        """
+        self.capturedtext = ""
+        # Create a pipe so the stream can be captured:
+        self.pipe_out, self.pipe_in = os.pipe()
+        # Save a copy of the stream:
+        self.streamfd = os.dup(self.origstreamfd)
+        # Replace the original stream with our write pipe:
+        os.dup2(self.pipe_in, self.origstreamfd)
+        if self.threaded:
+            # Start thread that will read the stream:
+            self.workerThread = threading.Thread(target=self.readOutput)
+            self.workerThread.start()
+            # Make sure that the thread is running and os.read() has executed:
+            time.sleep(0.01)
+
+    def stop(self):
+        """
+        Stop capturing the stream data and save the text in `capturedtext`.
+        """
+        # Print the escape character to make the readOutput method stop:
+        self.origstream.write(self.escape_char)
+        # Flush the stream to make sure all our data goes in before
+        # the escape character:
+        self.origstream.flush()
+        if self.threaded:
+            # Wait until the thread finishes so we are sure that
+            # we have read up to the last character:
+            self.workerThread.join()
+        else:
+            self.readOutput()
+        # Close the pipe:
+        os.close(self.pipe_in)
+        os.close(self.pipe_out)
+        self.pipe_in = None
+        self.pipe_out = None
+        # Restore the original stream:
+        os.dup2(self.streamfd, self.origstreamfd)
+        # Close the duplicate stream:
+        os.close(self.streamfd)
+
+    def readOutput(self):
+        """
+        Read the stream data (one byte at a time) until the escape character
+        or end of stream, and save the decoded text in `capturedtext`.
+        """
+        escape_byte = self.escape_char.encode()
+        captured = bytearray()
+        while True:
+            byte = os.read(self.pipe_out, 1)
+            if not byte or byte == escape_byte:
+                break
+            captured += byte
+        # Decode once at the end: decoding byte by byte raises
+        # UnicodeDecodeError on any multi-byte character.
+        encoding = self.origstream.encoding or "utf-8"
+        self.capturedtext += captured.decode(encoding, errors="replace")
 
 class OrtEPTests(unittest.TestCase):
   def get_test_execution_provider_path(self):
@@ -14,8 +105,27 @@ class OrtEPTests(unittest.TestCase):
     torch_ort.set_device(0, 'CPUExecutionProvider', {})
 
     torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {})
-    torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
-    ort_device = torch_ort.device(1)
+    # capture std out; capturedtext is only filled once the `with` block exits
+    with OutputGrabber() as out:
+      torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
+      ort_device = torch_ort.device(1)
+    self.assertIn('My EP provider created, with device id: 0, some_option: val', out.capturedtext)
+    with OutputGrabber() as out:
+      torch_ort.set_device(2, 'TestExecutionProvider', {'device_id':'1', 'some_config':'val'})
+      ort_device = torch_ort.device(2)
+    self.assertIn('My EP provider created, with device id: 1, some_option: val', out.capturedtext)
+    # test the reusing EP instance
+    with OutputGrabber() as out:
+      torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
+      ort_device = torch_ort.device(3)
+    self.assertNotIn('My EP provider created, with device id: 0, some_option: val', out.capturedtext)
+    # test clear training ep instance pool
+    torch_ort.clear_training_ep_instances()
+    with OutputGrabber() as out:
+      torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
+      ort_device = torch_ort.device(3)
+    self.assertIn('My EP provider created, with device id: 0, some_option: val', out.capturedtext)
+
 
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/orttraining/orttraining/python/orttraining_pybind_common.h b/orttraining/orttraining/python/orttraining_pybind_common.h
new file mode 100644
index 0000000000..6c208cd6f0
--- /dev/null
+++ 
b/orttraining/orttraining/python/orttraining_pybind_common.h
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "python/onnxruntime_pybind_exceptions.h"
+#include "python/onnxruntime_pybind_mlvalue.h"
+#include "python/onnxruntime_pybind_state_common.h"
+
+#include "core/platform/env.h"
+#include <unordered_map>
+#include <cstdlib>
+
+namespace onnxruntime {
+namespace python {
+namespace py = pybind11;
+
+using namespace onnxruntime::logging;
+
+// Cached execution provider instances; see GetExecutionProviderMapKey for the key.
+using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider>>;
+// Externally registered providers: provider type -> (library path, default options).
+using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions>>;
+
+
+// State shared by the ORT training python bindings: the ORT environment,
+// the cache of execution provider instances and the registry of execution
+// providers loaded from external libraries.
+class ORTTrainingPythonEnv{
+public:
+  ORTTrainingPythonEnv();
+
+  Environment& GetORTEnv();
+
+  // Returns the cached instance for (provider_type, hash), or nullptr when there is none.
+  std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
+                                                                   size_t hash);
+
+  void AddExecutionProvider(const std::string& provider_type,
+                            size_t hash,
+                            std::unique_ptr<IExecutionProvider> execution_provider);
+
+  void RegisterExtExecutionProviderInfo(const std::string& provider_type,
+                                        const std::string& provider_lib_path,
+                                        const ProviderOptions& default_options);
+
+  const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes();
+
+  ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
+
+  // Drops every cached execution provider instance.
+  void ClearExecutionProviderInstances();
+
+private:
+  std::string GetExecutionProviderMapKey(const std::string& provider_type,
+                                         size_t hash);
+
+  std::unique_ptr<Environment> ort_env_;
+  ExecutionProviderMap execution_provider_instances_map_;
+  std::vector<std::string> available_training_eps_;
+};
+
+}
+}
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 91a70f07b8..682c2eed20 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -20,6 +20,8 @@
 
 #include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h"
 
+#include "orttraining/python/orttraining_pybind_common.h"
+
 #ifdef ENABLE_TRAINING_TORCH_INTEROP
 #include "orttraining/core/framework/torch/custom_function_register.h"
 #endif
@@ -35,6 +37,38 @@ using namespace onnxruntime::logging;
 using namespace onnxruntime::training;
 
 Environment& GetTrainingORTEnv();
+ORTTrainingPythonEnv& GetTrainingEnv();
+
+// Builds `merged_options` for `provider_types`: entry j holds the options of
+// provider j. For providers registered from an external library, the options
+// start from the registered defaults plus the shared library path.
+void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
+                                 const ProviderOptionsVector& original_provider_options_vector,
+                                 ProviderOptionsVector& merged_options){
+  auto& training_env = GetTrainingEnv();
+  std::size_t j = 0; // index for provider_options_vector
+  for (const std::string& type : provider_types) {
+    ProviderOptions options;
+    auto it = training_env.ext_execution_provider_info_map_.find(type);
+    if (it != training_env.ext_execution_provider_info_map_.end()){
+      options = it->second.second;
+      options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
+    }
+    if (j < original_provider_options_vector.size()) {
+      // NOTE(review): insert() keeps keys that are already present, so the
+      // registration defaults win over the caller-supplied values here.
+      for (const auto& [k, v] : original_provider_options_vector[j]){
+        options.insert({k, v});
+      }
+    }
+    // Always append one entry per provider, even an empty one: the input
+    // vector is read by provider position above, and skipping an entry would
+    // shift the options of every following provider.
+    merged_options.push_back(std::move(options));
+
+    j += 1;
+  }
+}
 
 struct TrainingParameters {
   std::string loss_output_name;
@@ -521,7 +555,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
 #endif
         const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
 
-        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
+        ProviderOptionsVector merged_options;
+        ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
+
+        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
 
         return config_result;
       })
@@ -535,8 +572,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
           CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
 #endif
         const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
+        ProviderOptionsVector merged_options;
+        ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
 
-        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, provider_options);
+        InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
 
         return config_result;
       })
diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc
index f8aa62bbc8..024e5c0c5e 100644
--- a/orttraining/orttraining/python/orttraining_python_module.cc
+++ b/orttraining/orttraining/python/orttraining_python_module.cc
@@ -1,18 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "python/onnxruntime_pybind_exceptions.h"
-#include "python/onnxruntime_pybind_mlvalue.h"
-#include "python/onnxruntime_pybind_state_common.h"
+#include "orttraining/python/orttraining_pybind_common.h"
 
 #include "core/common/logging/logging.h"
 #include "core/common/logging/severity.h"
 #include "core/providers/get_execution_providers.h"
-
-#include "core/platform/env.h"
 #include "core/session/provider_bridge_ort.h"
-#include <unordered_map>
-#include <cstdlib>
 
 namespace onnxruntime {
 namespace python {
@@ -41,8 +35,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
 void addObjectMethodsForEager(py::module& m);
 void InitArray();
 
-using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
-using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
 
 bool GetDyanmicExecutionProviderHash(
     const std::string& ep_shared_lib_path,
@@ -131,61 +123,55 @@ bool GetProviderInstanceHash(const std::string& type,
   return false;
 }
 
-class ORTTrainingPythonEnv{
-public:
-  ORTTrainingPythonEnv(){
-    OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
-                                                  std::unique_ptr<ISink>{new CLogSink{}},
-                                                  Severity::kWARNING, false, LoggingManager::InstanceType::Default,
-                                                  &SessionObjectInitializer::default_logger_id),
-                                              ort_env_));
-    auto& builtinEPs = GetAvailableExecutionProviderNames();
-    available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
-  }
+ORTTrainingPythonEnv::ORTTrainingPythonEnv(){
+  OrtPybindThrowIfError(Environment::Create(std::make_unique<LoggingManager>(
+                                                std::unique_ptr<ISink>{new CLogSink{}},
+                                                Severity::kWARNING, false, LoggingManager::InstanceType::Default,
+                                                &SessionObjectInitializer::default_logger_id),
+                                            ort_env_));
+  auto& builtinEPs = GetAvailableExecutionProviderNames();
+  available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
+}
 
-  Environment& GetORTEnv(){
-    return *ort_env_;
-  }
+Environment& ORTTrainingPythonEnv::GetORTEnv(){
+  return *ort_env_;
+}
 
-  std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
-                                                                   size_t hash){
-    auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
-    return it == execution_provider_instances_map_.end() ? nullptr : it->second;
-  }
+std::shared_ptr<IExecutionProvider> ORTTrainingPythonEnv::GetExecutionProviderInstance(const std::string& provider_type,
+                                                                                       size_t hash){
+  auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
+  return it == execution_provider_instances_map_.end() ? nullptr : it->second;
+}
 
-  void AddExecutionProvider(const std::string& provider_type,
-                            size_t hash,
-                            std::unique_ptr<IExecutionProvider> execution_provider){
-    execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
-                                              std::move(execution_provider)});
-  }
+void ORTTrainingPythonEnv::AddExecutionProvider(const std::string& provider_type,
+                                                size_t hash,
+                                                std::unique_ptr<IExecutionProvider> execution_provider){
+  execution_provider_instances_map_.insert({GetExecutionProviderMapKey(provider_type, hash),
+                                            std::move(execution_provider)});
+}
 
-  void RegisterExtExecutionProviderInfo(const std::string& provider_type,
-                                        const std::string& provider_lib_path,
-                                        const ProviderOptions& default_options){
-    ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
-    if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
-      available_training_eps_.push_back(provider_type);
-  }
+void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type,
+                                                            const std::string& provider_lib_path,
+                                                            const ProviderOptions& default_options){
+  ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
+  if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
+    available_training_eps_.push_back(provider_type);
+}
 
-  const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes(){
-    return available_training_eps_;
-  }
+const std::vector<std::string>& ORTTrainingPythonEnv::GetAvailableTrainingExecutionProviderTypes(){
+  return available_training_eps_;
+}
 
-  ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
+std::string ORTTrainingPythonEnv::GetExecutionProviderMapKey(const std::string& provider_type,
+                                                             size_t hash){
+  std::string key(provider_type);
+  key.append(std::to_string(hash));
+  return key;
+}
 
-private:
-  std::string GetExecutionProviderMapKey(const std::string& provider_type,
-                                         size_t hash){
-    std::string key(provider_type);
-    key.append(std::to_string(hash));
-    return key;
-  }
-
-  std::unique_ptr<Environment> ort_env_;
-  ExecutionProviderMap execution_provider_instances_map_;
-  std::vector<std::string> available_training_eps_;
-};
+void ORTTrainingPythonEnv::ClearExecutionProviderInstances(){
+  execution_provider_instances_map_.clear();
+}
 
 static std::unique_ptr<ORTTrainingPythonEnv> ort_training_env;
 
@@ -218,24 +204,34 @@ Environment& GetTrainingORTEnv() {
   return ort_training_env->GetORTEnv();
 }
 
-void ResolveExtraProviderOptions(const std::string& provider_type,
+// Merges, for each type in `provider_types`, the options found in
+// `original_provider_options_map` with the defaults and shared library path
+// recorded for externally registered providers, writing into `merged_options`.
+void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
                                  const ProviderOptionsMap& original_provider_options_map,
                                  ProviderOptionsMap& merged_options){
   auto& training_env = GetTrainingEnv();
-  auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
-  if (it == training_env.ext_execution_provider_info_map_.end()){
-    //nothing changed.
-    merged_options = original_provider_options_map;
-  }else{
-    ProviderOptions options = it->second.second;
-    options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
-    auto original_map_it = original_provider_options_map.find(provider_type);
-    if (original_map_it != original_provider_options_map.end()){
-      for (auto [k, v] : original_map_it->second){
-        options.insert({k, v});
-      }
-    }
-    merged_options[provider_type] = options;
-  }
+  for (const auto& provider_type : provider_types){
+    // Look the caller-supplied options up once; both branches need them.
+    const auto original_map_it = original_provider_options_map.find(provider_type);
+    const bool has_original = original_map_it != original_provider_options_map.end();
+    const auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
+    if (it == training_env.ext_execution_provider_info_map_.end()){
+      // Not an externally registered provider: pass its options through unchanged.
+      if (has_original)
+        merged_options.insert({provider_type, original_map_it->second});
+    }else{
+      ProviderOptions options = it->second.second;
+      options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
+      if (has_original){
+        // NOTE(review): insert() keeps keys that are already present, so the
+        // registration defaults win over the caller-supplied values here.
+        for (const auto& [k, v] : original_map_it->second){
+          options.insert({k, v});
+        }
+      }
+      merged_options[provider_type] = std::move(options);
+    }
+  }
 }
 
@@ -243,27 +239,22 @@ std::unique_ptr<IExecutionProvider> CreateTrainingEP(
     const SessionOptions& session_options,
     const std::string& provider_type,
     const ProviderOptionsMap& provider_options_map){
-  ORTTrainingPythonEnv& training_env = GetTrainingEnv();
-  if (training_env.ext_execution_provider_info_map_.find(provider_type) !=
-      training_env.ext_execution_provider_info_map_.end()){
-    ProviderOptionsMap merged_options;
-    ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options);
-    return CreateExecutionProviderInstance(session_options, provider_type, merged_options);
-  }else{
-    return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
-  }
+  return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
 }
 
 std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
                                                                  const ProviderOptionsMap& provider_options_map,
                                                                  const SessionOptions& session_options){
   ORTTrainingPythonEnv& training_env = GetTrainingEnv();
+  // resolve provider options, because the hash key of ep depends on provider options.
+  ProviderOptionsMap merged_options;
+  ResolveExtraProviderOptions({provider_type}, provider_options_map, merged_options);
   // search in environment
   size_t hash;
-  if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
+  if (GetProviderInstanceHash(provider_type, merged_options, hash)){
     auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
     if (!cached_provider_instance){
-      auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
+      auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
       if (ep){
         training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
         cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
@@ -273,7 +264,7 @@ std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::stri
   }
   else{
     // the EP doesn't support cache, register the instance to session
-    auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
+    auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
     return ep;
   }
 }
@@ -324,6 +315,13 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
       "Return list of available Execution Providers in this installed version of Onnxruntime. "
       "The order of elements represents the default priority order of Execution Providers "
       "from highest to lowest.");
+
+  m.def("clear_training_ep_instances", []() -> void {
+    // The unique_ptr may be empty; clearing is then a no-op rather than a
+    // null dereference.
+    if (ort_training_env) ort_training_env->ClearExecutionProviderInstances();
+  },
+  "Clean the execution provider instances used in ort training module.");
 
   // clean the ort training environment when python interpreter exit
   // otherwise the global var will be de-constrcut after user main.