onnxruntime/orttraining/orttraining/python/orttraining_python_module.cc
Tianlei Wu 23996bbbbe
[CUDA][ROCm][Training] Fix cuda/rocm provider info hash (#19398)
When I tested a new provider option, the training pipeline failed. I found
that training uses the hash code of the provider info to try to get a provider
instance. If a provider option is not used in hashing, the provider
instance fetched from the cache might have a different configuration for that
option.

Here I fix the hashing to use all provider options (except the default
Arena config that cannot be set from python API since training is used
with PyTorch in most cases).

Fixed a few obvious typos in the touched files.

Add regression test cases.
2024-02-05 14:35:57 -08:00

369 lines
15 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/python/orttraining_pybind_common.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/common/path_string.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/provider_bridge_ort.h"
#include "onnxruntime_config.h"
namespace onnxruntime {
namespace python {
namespace py = pybind11;
// Compile-time flag that is true only when this build defines both USE_MPI and
// ORT_USE_NCCL. It is returned to Python by the `has_collective_ops` binding
// registered in the module definition of this file.
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
static constexpr bool HAS_COLLECTIVE_OPS = true;
#else
static constexpr bool HAS_COLLECTIVE_OPS = false;
#endif
using namespace onnxruntime::logging;
// Forward declarations of functions whose definitions are not part of this section.

// Creates an execution provider of the given `type` from the session options and the
// per-provider option map. Used by CreateTrainingEP() below.
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
const SessionOptions& session_options,
const std::string& type,
const ProviderOptionsMap& provider_options_map);
#ifdef USE_CUDA
// Builds the CUDA provider info from the option map; its std::hash is used as the cache key
// in GetProviderInstanceHash().
const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* cuda_provider_info,
const ProviderOptionsMap& provider_options_map);
#endif
#ifdef USE_ROCM
// Builds the ROCm provider info from the option map; its std::hash is used as the cache key
// in GetProviderInstanceHash().
const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* rocm_provider_info,
const ProviderOptionsMap& provider_options_map);
#endif
// Binding registration helpers, each taking the pybind11 module being populated.
void addGlobalMethods(py::module& m);
void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn);
void addObjectMethodsForTraining(py::module& m);
void addObjectMethodsForEager(py::module& m);
#ifdef ENABLE_LAZY_TENSOR
void addObjectMethodsForLazyTensor(py::module& m);
#endif
// Called (and required to return true) when the training environment singleton is constructed.
bool InitArray();
// Asks a dynamically loaded execution-provider library for the hash of a set of provider options.
//
// Loads the shared library at `ep_shared_lib_path`, resolves the symbol `entry_symbol_name`
// (default "ProviderHashFunc") and, when it resolves to a non-null function, calls it with a
// pointer to `provider_options` and stores the result in `hash`.
//
// Returns true when `hash` was written. Returns false when the symbol cannot be resolved (or
// anything inside the lookup/call throws), which callers treat as "this provider does not
// support instance caching".
// Throws std::runtime_error when the library itself cannot be loaded.
//
// NOTE(review): the library handle obtained here is never unloaded in this function — confirm
// that keeping the library loaded is intended.
bool GetDynamicExecutionProviderHash(
    const std::string& ep_shared_lib_path,
    const ProviderOptions& provider_options,
    size_t& hash,
    const std::string& entry_symbol_name = "ProviderHashFunc") {
  void* handle = nullptr;
  const auto path_str = ToPathString(ep_shared_lib_path);
  auto load_status = Env::Default().LoadDynamicLibrary(path_str, false, &handle);
  if (!load_status.IsOK()) {
    throw std::runtime_error(load_status.ErrorMessage());
  }

  using ProviderHashFn = size_t (*)(const void*);
  try {
    ProviderHashFn get_provider_hash = nullptr;
    OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(
        handle, entry_symbol_name, reinterpret_cast<void**>(&get_provider_hash)));
    if (get_provider_hash == nullptr) {
      return false;
    }
    hash = get_provider_hash(&provider_options);
    return true;
  } catch (...) {
    // The shared library does not provide the hash function, which means it does not
    // support caching of provider instances.
    return false;
  }
}
// Computes the cache key hash for an execution provider of the given `type`.
//
// - CPU: always hash 0 (a single shared instance).
// - CUDA / ROCm: std::hash of the provider info built from `provider_options_map`, available
//   only when the corresponding USE_CUDA / USE_ROCM build flag is set and the provider info
//   object can be obtained.
// - Any other type: if its options contain kExecutionProviderSharedLibraryPath, the provider is
//   treated as dynamically loaded and the hash is requested from that library, passing every
//   option except the shared-library path and shared-library entry keys.
//
// Returns true and writes `hash` on success; returns false (leaving `hash` untouched by this
// function) when the provider does not support hashing.
bool GetProviderInstanceHash(const std::string& type,
                             const ProviderOptionsMap& provider_options_map,
                             size_t& hash) {
  // for built-in execution provider, currently only cpu / cuda / rocm support hash.
  if (type == kCpuExecutionProvider) {
    // for CPU, only 1 instance
    hash = 0;
    return true;
  } else if (type == kCudaExecutionProvider) {
#ifdef USE_CUDA
    if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) {
      const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info,
                                                                          provider_options_map);
      hash = std::hash<CUDAExecutionProviderInfo>{}(info);
      return true;
    }
#endif
  } else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
    if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) {
      const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info,
                                                                          provider_options_map);
      hash = std::hash<ROCMExecutionProviderInfo>{}(info);
      return true;
    }
#endif
  } else {
    const auto it = provider_options_map.find(type);
    if (it != provider_options_map.end()) {
      const auto& ep_options = it->second;
      const auto shared_lib_path_it = ep_options.find(kExecutionProviderSharedLibraryPath);
      if (shared_lib_path_it != ep_options.end()) {
        // This is an EP with dynamic loading. Build the option set handed to its hash
        // function: the library path and library entry keys are loader settings and are
        // left out of the hashed options.
        ProviderOptions provider_options;
        for (const auto& option : ep_options) {
          if (option.first != kExecutionProviderSharedLibraryEntry &&
              option.first != kExecutionProviderSharedLibraryPath) {
            provider_options.insert(option);
          }
        }
        return GetDynamicExecutionProviderHash(shared_lib_path_it->second, provider_options, hash);
      }
    }
  }
  return false;
}
// Captures the Environment returned by GetEnv() and seeds the list of available training
// execution providers with the names returned by GetAvailableExecutionProviderNames().
ORTTrainingPythonEnv::ORTTrainingPythonEnv() : ort_env_(GetEnv()) {
  const auto& builtin_ep_names = GetAvailableExecutionProviderNames();
  available_training_eps_.assign(builtin_ep_names.begin(), builtin_ep_names.end());
}
// Returns the shared Environment pointer that was captured when this object was constructed.
std::shared_ptr<Environment> ORTTrainingPythonEnv::GetORTEnv() const {
return ort_env_;
}
std::shared_ptr<IExecutionProvider> ORTTrainingPythonEnv::GetExecutionProviderInstance(const std::string& provider_type,
size_t hash) {
auto it = execution_provider_instances_map_.find(GetExecutionProviderMapKey(provider_type, hash));
return it == execution_provider_instances_map_.end() ? nullptr : it->second;
}
// Stores `execution_provider` in the instance cache under the key derived from
// `provider_type` and `hash`, taking ownership of the provider.
void ORTTrainingPythonEnv::AddExecutionProvider(const std::string& provider_type,
                                                size_t hash,
                                                std::unique_ptr<IExecutionProvider> execution_provider) {
  auto cache_key = GetExecutionProviderMapKey(provider_type, hash);
  execution_provider_instances_map_.insert({std::move(cache_key), std::move(execution_provider)});
}
// Records an external execution provider (its shared library path and default options) and
// makes sure its type name appears once in the list of available training providers.
void ORTTrainingPythonEnv::RegisterExtExecutionProviderInfo(const std::string& provider_type,
                                                            const std::string& provider_lib_path,
                                                            const ProviderOptions& default_options) {
  ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});

  const auto eps_end = available_training_eps_.end();
  const bool already_listed =
      std::find(available_training_eps_.begin(), eps_end, provider_type) != eps_end;
  if (!already_listed) {
    available_training_eps_.push_back(provider_type);
  }
}
// Returns the list of provider type names seeded in the constructor and extended by
// RegisterExtExecutionProviderInfo(). The reference is to the member vector itself.
const std::vector<std::string>& ORTTrainingPythonEnv::GetAvailableTrainingExecutionProviderTypes() {
return available_training_eps_;
}
// Builds the instance-cache key: the provider type name immediately followed by the decimal
// representation of the option hash.
std::string ORTTrainingPythonEnv::GetExecutionProviderMapKey(const std::string& provider_type,
                                                             size_t hash) {
  return provider_type + std::to_string(hash);
}
// Drops every entry from the cached execution provider instance map.
void ORTTrainingPythonEnv::ClearExecutionProviderInstances() {
execution_provider_instances_map_.clear();
}
namespace {
// This class provides a static shell for on-demand and thread-safe construction
// of ORTTrainingPythonEnv object for both Inference and Training python layers.
// ORTTrainingPythonEnv class contains instances of execution providers that have been
// instantiated for training purposes. It depends on the Environment singleton to which it
// holds a shared_ptr instance.
//
// 1) we make this class a singleton that is a function local static. The function local statics
// are constructed when the function is called the very first time. This fact has several important
// properties.
// - First, it is constructed before it is first needed possibly by another static object
// and destroyed after that object is destroyed.
// - Second, it is constructed in a thread safe manner.
// - Last, this order of construction/destruction is enforced across the compilation units, as opposed
// to the static objects that are simply declared in order in a single unit, but whose lifespan is
// unconnected to that of objects in other compilation units. This is achieved automatically by the
// run-time, which uses atexit() to build a chain.
// 2) This ORTTrainingPythonEnv is currently owned by a unique_ptr unlike the Environment singleton. This is
// because we currently do not see a need to refer to it by any of the Python objects or by other singletons.
// With this change this singleton is properly destroyed after python module is unloaded, but before the Environment.
// HOWEVER, because it holds instances of execution providers, we want to make sure that those instances are destroyed
// before the EP DLLs they depend on are unloaded, so that the EP destructors can run.
// This static is destroyed when this compilation unit is unloaded and it generally happens
// AFTER EP dlls are unloaded. To mitigate that, we clear EP instances using python `atexit` (different from C atexit())
// mechanism which takes place after all python objects are GCed but before any DLLs are unloaded or
// runtime starts destroying globals.
// 3) We guard against singleton resurrection attempts to detect code that runs when it should not
// and make necessary adjustments.
// For all the related details and why it is needed see "Modern C++ design" by A. Alexandrescu Chapter 6.
// Holder for the single ORTTrainingPythonEnv instance, created on first use through a
// function-local static inside Instance().
class TrainingEnvInitialzer {
public:
// Returns the training environment, constructing the holder on the first call.
// Throws (via ORT_THROW) if the holder has already been destroyed.
static ORTTrainingPythonEnv& Instance() {
// Guard against attempts to resurrect the singleton
if (TrainingEnvInitialzer::destroyed) {
ORT_THROW("Detected an attempt to resurrect destroyed Training Environment");
}
static TrainingEnvInitialzer training_env_holder;
return training_env_holder.Get();
}
private:
// Requires InitArray() to succeed, tags telemetry with the Python language projection,
// then creates the training environment, in that order.
TrainingEnvInitialzer() {
ORT_ENFORCE(InitArray());
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
ort_training_env_ = std::make_unique<ORTTrainingPythonEnv>();
}
// Records that the holder is gone so that later Instance() calls throw instead of
// touching a destroyed object.
~TrainingEnvInitialzer() {
destroyed = true;
}
// Dereferences the owned environment; the pointer is set in the constructor.
ORTTrainingPythonEnv& Get() noexcept {
return *ort_training_env_;
}
std::unique_ptr<ORTTrainingPythonEnv> ort_training_env_;
// Set to true by the destructor; checked by Instance().
static bool destroyed;
};
bool TrainingEnvInitialzer::destroyed = false;
} // namespace
// Returns the training environment singleton, creating it on first use. Propagates the
// exception raised by TrainingEnvInitialzer::Instance() if the singleton was already destroyed.
ORTTrainingPythonEnv& GetTrainingEnv() {
return TrainingEnvInitialzer::Instance();
}
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
const ProviderOptionsMap& original_provider_options_map,
ProviderOptionsMap& merged_options) {
auto& training_env = GetTrainingEnv();
for (auto& provider_type : provider_types) {
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
if (it == training_env.ext_execution_provider_info_map_.end()) {
// nothing changed.
if (original_provider_options_map.find(provider_type) != original_provider_options_map.end())
merged_options.insert({provider_type, original_provider_options_map.at(provider_type)});
} else {
ProviderOptions options = it->second.second;
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
auto original_map_it = original_provider_options_map.find(provider_type);
if (original_map_it != original_provider_options_map.end()) {
for (auto [k, v] : original_map_it->second) {
options.insert({k, v});
}
}
merged_options[provider_type] = options;
}
}
}
// Creates an execution provider for training by forwarding all arguments unchanged to
// CreateExecutionProviderInstance() and returning its result.
std::unique_ptr<IExecutionProvider> CreateTrainingEP(
const SessionOptions& session_options,
const std::string& provider_type,
const ProviderOptionsMap& provider_options_map) {
// TODO(leca): REVIEW: No allocators are initialized
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
}
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
const ProviderOptionsMap& provider_options_map,
const SessionOptions& session_options) {
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
// resolve provider options, because the hash key of ep depends on provider options.
ProviderOptionsMap merged_options;
ResolveExtraProviderOptions({provider_type}, provider_options_map, merged_options);
// search in environment
size_t hash;
if (GetProviderInstanceHash(provider_type, merged_options, hash)) {
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
if (!cached_provider_instance) {
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
if (ep) {
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
}
}
return cached_provider_instance;
} else {
// the EP doesn't support cache, register the instance to session
auto ep = CreateTrainingEP(session_options, provider_type, merged_options);
return ep;
}
}
// Registers an execution provider with `sess` for every entry of `provider_types`, in order.
// Each provider is obtained through GetOrCreateExecutionProvider() using the session's own
// options; types for which no provider is returned are skipped. A failed registration is
// raised through OrtPybindThrowIfError.
void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
                                           const ProviderOptionsMap& provider_options_map) {
  // Iterate by const reference: the names are only read, so no per-iteration string copy.
  for (const auto& provider_type : provider_types) {
    auto ep = GetOrCreateExecutionProvider(provider_type, provider_options_map, sess->GetSessionOptions());
    if (ep) {
      OrtPybindThrowIfError(sess->RegisterExecutionProvider(ep));
    }
  }
}
// Definition of the `onnxruntime_pybind11_state` Python module for training builds: registers
// exceptions, creates the training environment singleton, adds the shared and training-specific
// bindings, and installs a Python `atexit` hook that clears cached execution providers.
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
m.doc() = "pybind11 stateful interface to ORTTraining";
RegisterExceptions(m);
// Instantiate singletons
GetTrainingEnv();
addGlobalMethods(m);
// Sessions created from Python register their providers through the training-aware
// ORTTrainingRegisterExecutionProviders function.
addObjectMethods(m, ORTTrainingRegisterExecutionProviders);
addOrtValueMethods(m);
addSparseTensorMethods(m);
addIoBindingMethods(m);
#if !defined(__APPLE__) && \
(!defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS))
// NOTE(review): tmp_options is not referenced again in this block — presumably it is
// constructed for its side effects; confirm before removing.
Ort::SessionOptions tmp_options;
// A failure to initialize the providers shared library is only logged as a warning;
// module initialization continues.
if (!InitProvidersSharedLibrary()) {
const logging::Logger& default_logger = logging::LoggingManager::DefaultLogger();
LOGS(default_logger, WARNING) << "Init provider bridge failed.";
}
#endif
addObjectMethodsForTraining(m);
#ifdef ENABLE_LAZY_TENSOR
addObjectMethodsForLazyTensor(m);
#endif
// Records an external execution provider library (name, shared library path, default
// options) with the training environment.
m.def("_register_provider_lib", [](const std::string& name,
const std::string& provider_shared_lib_path,
const ProviderOptions& default_options) {
GetTrainingEnv().RegisterExtExecutionProviderInfo(name, provider_shared_lib_path, default_options);
});
m.def(
"get_available_providers", []() -> const std::vector<std::string>& { return GetTrainingEnv().GetAvailableTrainingExecutionProviderTypes(); },
"Return list of available Execution Providers in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
// Version and build strings come from the ORT_VERSION / ORT_BUILD_INFO macros.
m.def("get_version_string", []() -> std::string { return ORT_VERSION; });
m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; });
m.def(
"clear_training_ep_instances", []() -> void {
GetTrainingEnv().ClearExecutionProviderInstances();
},
"Clean the execution provider instances used in ort training module.");
m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
// See documentation for class TrainingEnvInitialzer earlier in this module
// for an explanation as to why this is needed.
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() {
GetTrainingEnv().ClearExecutionProviderInstances();
}));
}
} // namespace python
} // namespace onnxruntime