onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc
Ashwini Khade ea7bbd667d
fix headers for training apis (#14350)
### Description
Minor refactor PR for fixing header placement for training apis
2023-01-19 10:26:53 -08:00

1068 lines
57 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_state_common.h"
// pybind11/stl.h is needed to support std::unordered_set, etc.
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#ifdef ENABLE_TRAINING_APIS
#include <google/protobuf/io/zero_copy_stream_impl.h>
#endif
#include "core/common/parse_string.h"
#include "core/graph/model.h"
#include "core/session/environment.h"
#include "core/dlpack/dlpack_converter.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/agent/training_agent.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
#include "orttraining/core/framework/gradient_graph_builder.h"
#include "orttraining/core/framework/ortmodule_graph_builder.h"
#include "orttraining/core/graph/gradient_definition_registry.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "orttraining/python/orttraining_pybind_common.h"
#include "orttraining/core/optimizer/graph_transformer_utils.h"
#include "core/framework/stream_execution_context.h"
#ifdef ENABLE_TRAINING_TORCH_INTEROP
#include "orttraining/core/framework/torch/custom_function_register.h"
#endif
#ifdef ENABLE_TRAINING_APIS
#include "orttraining/training_api/checkpoint.h"
#include "orttraining/training_api/lr_scheduler.h"
#endif
PYBIND11_MAKE_OPAQUE(onnxruntime::OrtValueCache);
namespace onnxruntime {
namespace python {
namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
using namespace onnxruntime::training;
Environment& GetTrainingORTEnv();
ORTTrainingPythonEnv& GetTrainingEnv();
// Builds the provider options passed to InitializeSession() for `provider_types`, appending to
// `merged_options`.
//
// For each provider type (index j):
//   - If the type is registered in the training env's ext_execution_provider_info_map_, the entry
//     is that registration's option set (it->second.second), plus kExecutionProviderSharedLibraryPath
//     set to the registered path (it->second.first), plus any user-supplied options at index j.
//     insert() keeps existing keys, so user options never override the registered ones.
//   - Otherwise, the user-supplied options at index j are forwarded only if present and non-empty.
//
// NOTE(review): for non-registered providers without (non-empty) user options nothing is appended,
// so merged_options can be shorter than provider_types and later entries shift index. Confirm that
// the consumer of merged_options tolerates this.
void ResolveExtraProviderOptions(const std::vector<std::string>& provider_types,
                                 const ProviderOptionsVector& original_provider_options_vector,
                                 ProviderOptionsVector& merged_options) {
  auto& training_env = GetTrainingEnv();
  const auto& ext_provider_info_map = training_env.ext_execution_provider_info_map_;
  for (std::size_t j = 0; j < provider_types.size(); ++j) {
    // User options are matched to provider types positionally.
    const bool has_user_options = j < original_provider_options_vector.size() &&
                                  !original_provider_options_vector[j].empty();
    const auto it = ext_provider_info_map.find(provider_types[j]);
    if (it == ext_provider_info_map.end()) {
      if (has_user_options) {
        merged_options.push_back(original_provider_options_vector[j]);
      }
      continue;
    }
    ProviderOptions options = it->second.second;
    options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
    if (has_user_options) {
      // Iterate by const reference: the original structured binding copied every key/value pair.
      for (const auto& user_option : original_provider_options_vector[j]) {
        options.insert(user_option);
      }
    }
    merged_options.push_back(std::move(options));
  }
}
#ifdef ENABLE_TRAINING_APIS
namespace {
// Returns the execution providers handed to the training_api Module and Optimizer for `device`.
//   - CPU device: an empty list (no explicit provider is created).
//   - GPU device (only when built with USE_CUDA): a single CUDA provider bound to device.Id(),
//     or an empty list if the CUDA provider factory could not be created.
//   - Any other device type: throws via ORT_THROW.
std::vector<std::shared_ptr<IExecutionProvider>>
GetExecutionProvidersForTrainingApis(OrtDevice device) {
  std::vector<std::shared_ptr<IExecutionProvider>> providers;
  if (device.Type() == OrtDevice::CPU) {
    return providers;
  }
#ifdef USE_CUDA
  if (device.Type() == OrtDevice::GPU) {
    OrtCUDAProviderOptions cuda_options{};
    cuda_options.device_id = device.Id();
    auto cuda_factory = CudaProviderFactoryCreator::Create(&cuda_options);
    if (cuda_factory) {
      providers.push_back(cuda_factory->CreateProvider());
    }
    return providers;
  }
#endif
  ORT_THROW("Unsupported device type: ", device.Type());
}
}  // namespace
#endif
// Python-configurable settings for the legacy TrainingSession, exposed to Python as
// "TrainingParameters". ConfigureSessionForTraining() copies these into a
// PipelineTrainingSession::TrainingConfiguration.
// NOTE: this is an aggregate; keep member order stable, since it defines aggregate-initialization order.
struct TrainingParameters {
// Name of the graph output used as the loss (copied to config.loss_name).
std::string loss_output_name;
// Initializer names to train / not to train (copied to config.weight_names_to_train / _to_not_train).
std::unordered_set<std::string> weights_to_train;
std::unordered_set<std::string> weights_not_to_train;
// Copied as-is to config.immutable_weights.
onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights;
// optimizer
// Optimizer name; when empty, no optimizer_config is set on the training configuration.
std::string training_optimizer_name;
// Name used as the optimizer's learning-rate input (opt.learning_rate_input_name).
std::string lr_params_feed_name = "Learning_Rate";
// Per-weight float / int64 optimizer attributes keyed by weight name. The attribute generators
// ORT_ENFORCE that every weight they are queried for has an entry here.
std::unordered_map<std::string, std::unordered_map<std::string, float>> optimizer_attributes_map;
std::unordered_map<std::string, std::unordered_map<std::string, int64_t>> optimizer_int_attributes_map;
// Initial optimizer state per weight; filled from Python via set_optimizer_initial_state() and
// applied to config.init_optimizer_states only when non-empty.
onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state;
// Tensor-slicing description forwarded to config.distributed_config.
std::unordered_map<std::string, std::vector<int>> sliced_schema;
std::unordered_map<std::string, int> sliced_axes;
std::vector<std::string> sliced_tensor_names;
// Maps to opt.use_mixed_precision_moments.
bool use_fp16_moments = false;
// When true, a mixed-precision config with use_mixed_precision_initializers = true is set.
bool use_mixed_precision = false;
// Also selects opt.use_nccl; when false and world_size > 1 (USE_MPI builds), ranks are read from MPI.
bool allreduce_post_accumulation = false;
// NOTE(review): exposed to Python but not read by ConfigureSessionForTraining in this file; confirm usage.
float loss_scale = 0.0f;
// Distributed topology. With USE_MPI, CopyMPIContextToTrainingParameters may overwrite
// world_rank/world_size/local_rank/local_size from the MPI context.
int world_rank = 0;
int world_size = 1;
int local_rank = 0;
int local_size = 1;
int gradient_accumulation_steps = 1;
// Data (D), horizontal (H) and pipeline (P) parallel degrees. Each must be <= world_size.
// D may be recomputed as world_size / (H * P) when D*H*P != world_size.
int data_parallel_size = 1;
int horizontal_parallel_size = 1;
int pipeline_parallel_size = 1;
int num_pipeline_micro_batches = 1;
// Converted with onnxruntime::training::ZeROConfig for opt.deepspeed_zero.
int deepspeed_zero_stage = 0;
bool enable_grad_norm_clip = true;
// Forwarded to config.gradient_graph_config.
bool set_gradients_as_graph_outputs = false;
bool use_memory_efficient_gradient = false;
// Pipeline cut specification, expected to be non-empty when pipeline_parallel_size > 1.
// Parsed as: groups separated by ',', edges in a group separated by ':', each edge either
// "<producer>" or "<producer>-<consumer>[/<consumer>...]".
std::string pipeline_cut_info_string = {};
// recompute
bool attn_dropout_recompute = false;
bool gelu_recompute = false;
bool transformer_layer_recompute = false;
int number_recompute_layers = 0;
// When an optimizer is configured, selects the Adasum reduction type
// (GpuHierarchicalReduction with USE_CUDA, otherwise CpuReduction).
bool enable_adasum = false;
// transformation
// Forwarded to config.graph_transformer_config.propagate_cast_ops_config.
int propagate_cast_ops_level = 1;
std::vector<std::string> propagate_cast_ops_allow;
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy propagate_cast_ops_strategy =
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill;
// graph dumping
// Optional output paths for intermediate models; only applied to the config when non-empty.
std::string model_after_graph_transforms_path;
std::string model_with_gradient_graph_path;
std::string model_with_training_graph_path;
};
// Python-facing subset of PipelineTrainingSession::TrainingConfigurationResult returned by
// ConfigureSessionForTraining().
struct TrainingConfigurationResult {
// Set only when the session produced a mixed-precision config result; otherwise empty.
optional<std::string> loss_scale_input_name;
};
#ifdef ENABLE_TRAINING_APIS
// Thin wrapper over internal C++ Optimizer
// Thin wrapper over internal C++ Optimizer.
// Builds a training_api Optimizer from an optimizer model URI and the named parameters of an
// existing Module, using default SessionOptions and the shared training Environment.
// `model` must be non-null; it is only read during construction (NamedParameters()).
struct PyOptimizer {
  PyOptimizer(const std::string& optimizer_model_uri,
              onnxruntime::training::api::Module* model,
              std::vector<std::shared_ptr<IExecutionProvider>> provider)
      // make_shared: one allocation for object + control block (make_unique -> shared_ptr needed two).
      // `provider` is a by-value sink, so it is moved instead of copied (avoids atomic refcount churn).
      : optimizer_(std::make_shared<onnxruntime::training::api::Optimizer>(
            optimizer_model_uri, model->NamedParameters(), onnxruntime::SessionOptions(),
            GetTrainingORTEnv(), std::move(provider))) {
  }
  std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
};
#endif
struct PyGradientGraphBuilder {
std::unique_ptr<GradientGraphBuilder> builder;
std::shared_ptr<Model> model;
std::unique_ptr<logging::Logger> logger;
std::unique_ptr<GradientGraphConfiguration> gradient_graph_config;
PyGradientGraphBuilder(std::unique_ptr<GradientGraphBuilder> builder_, std::shared_ptr<Model> model_, std::unique_ptr<logging::Logger> logger_, std::unique_ptr<GradientGraphConfiguration> gradient_graph_config_)
: builder(std::move(builder_)), model(std::move(model_)), logger(std::move(logger_)), gradient_graph_config(std::move(gradient_graph_config_)) {}
};
// Translates the Python-facing TrainingParameters into a PipelineTrainingSession
// TrainingConfiguration and applies it to `sess` via ConfigureForTraining().
//
// Side effects on `parameters` (the caller's object, taken by non-const reference):
//   - data_parallel_size is rewritten to world_size / (H * P) when D*H*P != world_size.
//   - pipeline_cut_info_string is parsed WITHOUT being modified, so the same parameters object
//     can be reused for another session.
//
// Throws (ORT_ENFORCE / OrtPybindThrowIfError) on inconsistent parallel sizes, a missing or
// malformed pipeline cut specification, or a failing ConfigureForTraining() call.
//
// NOTE: the optimizer attribute generators placed in the config capture `parameters` by reference,
// so they are only valid while `parameters` is alive.
// TODO: this method does not handle parallel optimization.
TrainingConfigurationResult ConfigureSessionForTraining(
    training::PipelineTrainingSession* sess, TrainingParameters& parameters) {
  // TODO tix, refactor the mpi related code to populate all fields correctly by default.
  ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size);
  ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size);
  ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size);
  // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks.
  if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) {
    // Guard the modulo/division below against a zero (or negative) H or P.
    ORT_ENFORCE(parameters.horizontal_parallel_size > 0 && parameters.pipeline_parallel_size > 0,
                "horizontal_parallel_size and pipeline_parallel_size must be positive. horizontal_parallel_size: ",
                parameters.horizontal_parallel_size, ", pipeline_parallel_size: ", parameters.pipeline_parallel_size);
    const int h_times_p = parameters.horizontal_parallel_size * parameters.pipeline_parallel_size;
    // `%` and `*` share precedence, so the product must be parenthesized (held in h_times_p);
    // `world_size % H * P` would evaluate as `(world_size % H) * P`.
    ORT_ENFORCE(parameters.world_size % h_times_p == 0,
                "D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP.");
    // Remember the user-supplied D so the warning reports the original value, not the corrected one.
    const int original_data_parallel_size = parameters.data_parallel_size;
    const auto new_data_parallel_size = parameters.world_size / h_times_p;
    parameters.data_parallel_size = new_data_parallel_size;
    const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(original_data_parallel_size) +
                            ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size);
    LOGS(*(sess->GetLogger()), WARNING) << msg;
  }
  training::PipelineTrainingSession::TrainingConfiguration config{};
  config.weight_names_to_train = parameters.weights_to_train;
  config.weight_names_to_not_train = parameters.weights_not_to_train;
  config.immutable_weights = parameters.immutable_weights;
  config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;
  config.distributed_config.world_rank = parameters.world_rank;
  config.distributed_config.world_size = parameters.world_size;
  config.distributed_config.local_rank = parameters.local_rank;
  config.distributed_config.local_size = parameters.local_size;
  config.distributed_config.data_parallel_size = parameters.data_parallel_size;
  config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size;
  config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size;
  config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches;
  config.distributed_config.sliced_schema = parameters.sliced_schema;
  config.distributed_config.sliced_axes = parameters.sliced_axes;
  config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names;
  if (parameters.use_mixed_precision) {
    training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{};
    mp.use_mixed_precision_initializers = true;
    config.mixed_precision_config = mp;
  }
  if (config.distributed_config.pipeline_parallel_size > 1) {
    training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config;
    // Currently don't support auto-partition. User needs to pass in cut information for pipeline
    pipeline_config.do_partition = true;
    // Checked in all build types (a debug-only assert let release builds parse an empty spec).
    ORT_ENFORCE(!parameters.pipeline_cut_info_string.empty(),
                "pipeline_cut_info_string must be provided when pipeline_parallel_size > 1.");
    // Splits `input_str` on each occurrence of the (non-empty) `delimiter`. Empty tokens are kept
    // and the trailing piece is always appended: "a,,b" -> {"a", "", "b"}, "" -> {""}.
    // The input is taken by const reference so the caller's string (notably the Python-visible
    // parameters.pipeline_cut_info_string) is not consumed by the parse.
    auto process_with_delimiter = [](const std::string& input_str, const std::string& delimiter) {
      std::vector<std::string> result;
      size_t start = 0;
      size_t pos = 0;
      while ((pos = input_str.find(delimiter, start)) != std::string::npos) {
        result.emplace_back(input_str.substr(start, pos - start));
        start = pos + delimiter.length();
      }
      // push the last split of substring into result.
      result.emplace_back(input_str.substr(start));
      return result;
    };
    // Cut info format: groups separated by ',', each group a ':'-separated list of edges, each edge
    // either "<producer>" or "<producer>-<consumer>[/<consumer>...]".
    auto process_cut_info = [&](const std::string& cut_info_string) {
      std::vector<PipelineTrainingSession::TrainingConfiguration::CutInfo> cut_list;
      const std::string group_delimiter = ",";
      const std::string edge_delimiter = ":";
      const std::string consumer_delimiter = "/";
      const std::string producer_consumer_delimiter = "-";
      const auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter);
      for (const auto& cut_info_group : cut_info_groups) {
        PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info;
        const auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter);
        for (const auto& cut_edge : cut_edges) {
          const auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
          if (process_edge.size() == 1) {
            PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
            cut_info.emplace_back(edge);
          } else {
            ORT_ENFORCE(process_edge.size() == 2, "Invalid pipeline cut edge: ", cut_edge);
            auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);
            PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
            cut_info.emplace_back(edge);
          }
        }
        cut_list.emplace_back(cut_info);
      }
      return cut_list;
    };
    pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string);
    config.pipeline_config = pipeline_config;
  }
  config.loss_name = parameters.loss_output_name;
  if (!parameters.training_optimizer_name.empty()) {
    training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{};
    opt.name = parameters.training_optimizer_name;
    opt.learning_rate_input_name = parameters.lr_params_feed_name;
    opt.weight_attributes_generator = [&parameters](const std::string& weight_name) {
      const auto it = parameters.optimizer_attributes_map.find(weight_name);
      ORT_ENFORCE(
          it != parameters.optimizer_attributes_map.end(),
          "Failed to find attribute map for weight ", weight_name);
      return it->second;
    };
    opt.weight_int_attributes_generator = [&parameters](const std::string& weight_name) {
      const auto it = parameters.optimizer_int_attributes_map.find(weight_name);
      ORT_ENFORCE(
          it != parameters.optimizer_int_attributes_map.end(),
          "Failed to find int attribute map for weight ", weight_name);
      return it->second;
    };
    opt.use_mixed_precision_moments = parameters.use_fp16_moments;
    opt.do_all_reduce_in_mixed_precision_type = true;
    // TODO: this mapping is temporary.
    // For now, nccl allreduce kernel only implements for allreduce_post_accumulation
    // hovorod allreduce kernel only implements for not allreduce_post_accumulation.
    // eventually we will have one all reduce kernel and let opt to have
    // an allreduce_post_accumulation option and remove the use_nccl option.
    opt.use_nccl = parameters.allreduce_post_accumulation;
    opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage);
    opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip;
    // TODO reduction types
    if (parameters.enable_adasum) {
#ifdef USE_CUDA
      opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction;
#else
      opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction;
#endif
    }
    config.optimizer_config = opt;
  }
  if (!parameters.optimizer_initial_state.empty()) {
    config.init_optimizer_states = parameters.optimizer_initial_state;
  }
  config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient;
  config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
  config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute;
  config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute;
  config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute;
  config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers;
  config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy;
  config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level;
  config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow;
  if (!parameters.model_after_graph_transforms_path.empty()) {
    config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path);
  }
  if (!parameters.model_with_gradient_graph_path.empty()) {
    config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path);
  }
  if (!parameters.model_with_training_graph_path.empty()) {
    config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path);
  }
  training::PipelineTrainingSession::TrainingConfigurationResult config_result{};
  OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
  TrainingConfigurationResult python_config_result{};
  if (config_result.mixed_precision_config_result.has_value()) {
    const auto& mp_config_result = config_result.mixed_precision_config_result.value();
    python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name;
  }
  return python_config_result;
}
#if defined(USE_MPI)
// Syncs the rank/size fields of `parameters` with the MPIContext singleton.
// local_rank and local_size are always taken from MPI. world_rank and world_size are replaced
// only when they differ from MPI, and a warning is logged when the replaced value was not the
// struct default (0 for world_rank, 1 for world_size). All four MPI values are logged at INFO.
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
  auto& mpi_context = MPIContext::GetInstance();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << mpi_context.GetWorldRank();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << mpi_context.GetLocalRank();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << mpi_context.GetWorldSize();
  LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << mpi_context.GetLocalSize();

  parameters.local_rank = mpi_context.GetLocalRank();
  parameters.local_size = mpi_context.GetLocalSize();

  const auto mpi_world_rank = mpi_context.GetWorldRank();
  if (parameters.world_rank != mpi_world_rank) {
    // Only warn when the user explicitly set a (wrong) non-default rank.
    if (parameters.world_rank != 0) {
      LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << mpi_world_rank;
    }
    parameters.world_rank = mpi_world_rank;
  }

  const auto mpi_world_size = mpi_context.GetWorldSize();
  if (parameters.world_size != mpi_world_size) {
    // Only warn when the user explicitly set a (wrong) non-default size.
    if (parameters.world_size != 1) {
      LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << mpi_world_size;
    }
    parameters.world_size = mpi_world_size;
  }
}
#endif
// Converts a nested map (outer key -> (name -> OrtValue)) into the same nested structure of Python
// objects, converting each tensor with GetPyObjFromTensor and `data_transfer_manager`.
// Every outer key gets an entry in the result, even when its inner map is empty.
// The tensor-ness of each value is checked only by assert (i.e. in debug builds).
std::unordered_map<std::string, std::unordered_map<std::string, py::object>> ConvertORTTensorMapToNumpy(std::unordered_map<std::string, NameMLValMap> c_tensor_state, const DataTransferManager& data_transfer_manager) {
  std::unordered_map<std::string, std::unordered_map<std::string, py::object>> converted;
  for (const auto& [group_name, named_values] : c_tensor_state) {
    // operator[] creates the (empty) inner map up front so empty groups are still reported.
    auto& py_group = converted[group_name];
    for (const auto& [value_name, ort_value] : named_values) {
      assert(ort_value.IsTensor());
      py::object py_value;
      GetPyObjFromTensor(ort_value.Get<Tensor>(), py_value, &data_transfer_manager);
      py_group.insert({value_name, py_value});
    }
  }
  return converted;
}
void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) {
py::class_<OrtValueCache, OrtValueCachePtr>(m, "OrtValueCache")
.def(py::init<>())
.def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) {
cache_ptr->emplace(node_arg_name, value);
})
.def("keys", [](const OrtValueCachePtr& cache_ptr) {
py::list keys;
for (auto kv : *cache_ptr.get()) {
keys.append(kv.first);
}
return keys;
})
.def("clear", [](const OrtValueCachePtr& cache_ptr) {
cache_ptr->clear();
})
.def("count", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
return cache_ptr->count(node_arg_name);
})
.def("remove", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name) {
const auto& num_entries_erased = cache_ptr->erase(node_arg_name);
ORT_ENFORCE(num_entries_erased == 1, "NodeArg not found in cache: ", node_arg_name);
});
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
.def_readwrite("weights_to_train", &TrainingParameters::weights_to_train)
.def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names)
.def_readwrite("training_optimizer_name", &TrainingParameters::training_optimizer_name)
.def_readwrite("lr_params_feed_name", &TrainingParameters::lr_params_feed_name)
.def_readwrite("optimizer_attributes_map", &TrainingParameters::optimizer_attributes_map)
.def_readwrite("optimizer_int_attributes_map", &TrainingParameters::optimizer_int_attributes_map)
.def_readwrite("sliced_schema", &TrainingParameters::sliced_schema)
.def_readwrite("sliced_axes", &TrainingParameters::sliced_axes)
.def_readwrite("use_fp16_moments", &TrainingParameters::use_fp16_moments)
.def_readwrite("use_mixed_precision", &TrainingParameters::use_mixed_precision)
.def_readwrite("allreduce_post_accumulation", &TrainingParameters::allreduce_post_accumulation)
.def_readwrite("loss_scale", &TrainingParameters::loss_scale)
.def_readwrite("world_rank", &TrainingParameters::world_rank)
.def_readwrite("world_size", &TrainingParameters::world_size)
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
.def_readwrite("pipeline_cut_info_string", &TrainingParameters::pipeline_cut_info_string)
.def_readwrite("num_pipeline_micro_batches", &TrainingParameters::num_pipeline_micro_batches)
.def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps)
.def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage)
.def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip)
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
.def_readwrite("use_memory_efficient_gradient", &TrainingParameters::use_memory_efficient_gradient)
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers)
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size)
.def("set_optimizer_initial_state",
[](TrainingParameters& parameters, const std::unordered_map<std::string, std::unordered_map<std::string, py::object>>& py_state) -> void {
onnxruntime::training::TrainingSession::OptimizerState optim_state;
for (const auto& weight_it : py_state) {
auto state = weight_it.second;
NameMLValMap state_tensors;
for (auto& initializer : state) {
OrtValue ml_value;
// InputDeflist is null because parameters havent been tied to session yet
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true);
ThrowIfPyErrOccured();
state_tensors.emplace(initializer.first, ml_value);
}
optim_state.emplace(weight_it.first, state_tensors);
}
parameters.optimizer_initial_state = optim_state;
})
.def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path)
.def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path)
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
.def_readwrite("enable_adasum", &TrainingParameters::enable_adasum)
.def_readwrite("propagate_cast_ops_level", &TrainingParameters::propagate_cast_ops_level)
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow);
#if defined(USE_MPI)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
m.def("get_mpi_context_local_size", []() -> int { return MPIContext::GetInstance().GetLocalSize(); });
m.def("get_mpi_context_world_rank", []() -> int { return MPIContext::GetInstance().GetWorldRank(); });
m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); });
#endif
m.def("register_forward_runner", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterForwardRunner(obj.ptr());
#else
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("register_backward_runner", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterBackwardRunner(obj.ptr());
#else
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("register_torch_autograd_function", [](std::string key, py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterTorchAutogradFunction(key, obj.ptr());
#else
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("unregister_python_functions", []() -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
// Release all custom python functions registered.
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.UnRegisterFunctions();
#endif
});
m.def("is_torch_interop_default_on", []() -> bool {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
return true;
#else
return false;
#endif
});
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
config_result.def(py::init())
.def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
if (result.loss_scale_input_name.has_value()) {
return py::str{result.loss_scale_input_name.value()};
}
return py::none();
});
// Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user
struct PyTrainingSession : public PyInferenceSession {
PyTrainingSession(Environment& env, const PySessionOptions& so)
: PyInferenceSession(std::make_unique<PipelineTrainingSession>(so.value, env)) {
}
};
py::class_<PyTrainingSession, PyInferenceSession> training_session(m, "TrainingSession");
training_session
.def(py::init([](const PySessionOptions& so) {
Environment& env = GetTrainingORTEnv();
return std::make_unique<PyTrainingSession>(env, so);
}))
.def(py::init([]() {
Environment& env = GetTrainingORTEnv();
return std::make_unique<PyTrainingSession>(env, GetDefaultCPUSessionOptions());
}))
.def("finalize", [](py::object) {
#if defined(USE_MPI)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
})
.def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector<std::string>& provider_types, const ProviderOptionsVector& provider_options) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer));
#if defined(USE_MPI)
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1)
CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle()), parameters);
ProviderOptionsVector merged_options;
ResolveExtraProviderOptions(provider_types, provider_options, merged_options);
InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options);
return config_result;
})
.def("get_state", [](PyTrainingSession* sess) {
NameMLValMap state_tensors;
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
// convert to numpy array
std::map<std::string, py::object> rmap;
for (auto& kv : state_tensors) {
if (kv.second.IsTensor()) {
py::object obj;
const Tensor& rtensor = kv.second.Get<Tensor>();
GetPyObjFromTensor(rtensor, obj, &data_transfer_manager);
rmap.insert({kv.first, obj});
} else {
throw std::runtime_error("Non tensor type in session state tensors is not expected.");
}
}
return rmap;
})
.def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) {
std::unordered_map<std::string, NameMLValMap> model_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager);
})
.def("get_optimizer_state", [](PyTrainingSession* sess) {
std::unordered_map<std::string, NameMLValMap> opt_state_tensors;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors));
auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager();
return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager);
})
.def("get_partition_info_map", [](PyTrainingSession* sess) {
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> part_info_map;
ORT_THROW_IF_ERROR(static_cast<TrainingSession*>(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map));
return part_info_map;
})
.def("load_state", [](PyTrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
NameMLValMap state_tensors;
for (auto initializer : state) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value);
ThrowIfPyErrOccured();
state_tensors.insert(std::make_pair(initializer.first, ml_value));
}
ORT_THROW_IF_ERROR(static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict));
})
.def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) {
return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
py::class_<PartialGraphExecutionState>(m, "PartialGraphExecutionState")
.def(py::init([]() {
return std::make_unique<PartialGraphExecutionState>();
}));
py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class used to run a ORTModule model.)pbdoc")
    // Builds a TrainingAgent over the InferenceSession held by `session`; the feed/fetch names,
    // per-output device lists and local rank are forwarded unchanged to the C++ constructor.
    .def(py::init([](PyInferenceSession* session, const std::vector<std::string>& fw_feed_names,
                     const std::vector<OrtDevice>& fw_outputs_device_info,
                     const std::vector<std::string>& bw_fetches_names,
                     const std::vector<OrtDevice>& bw_outputs_device_info,
                     int local_rank) {
      // pybind11 passes Python None as nullptr for pointer parameters; raise instead of crashing.
      ORT_ENFORCE(session, "Received a nullptr for expected pointer to class PyInferenceSession");
      return std::make_unique<TrainingAgent>(*session->GetSessionHandle(), fw_feed_names, fw_outputs_device_info,
                                             bw_fetches_names, bw_outputs_device_info, local_rank);
    }))
    // Runs the forward part with `feeds`, writing results into `fetches`; `state` is dereferenced
    // and passed by reference to TrainingAgent::RunForward, together with `cache`.
    .def("run_forward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state, OrtValueCachePtr cache) -> void {
      ORT_ENFORCE(state, "Received a nullptr for expected pointer to class PartialGraphExecutionState");
      Status status = agent->RunForward(feeds, fetches, *state, cache);
      if (!status.IsOK()) {
        throw std::runtime_error("Error in forward pass execution: " + status.ErrorMessage());
      }
    })
    // Runs the backward part with `feeds`, writing results into `fetches`; `state` is the object
    // that was passed to run_forward (TrainingAgent::RunBackward receives it by reference).
    .def("run_backward", [](TrainingAgent* agent, const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches, PartialGraphExecutionState* state) -> void {
      ORT_ENFORCE(state, "Received a nullptr for expected pointer to class PartialGraphExecutionState");
      Status status = agent->RunBackward(feeds, fetches, *state);
      if (!status.IsOK()) {
        throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage());
      }
    });
// Python enum for the cast-propagation strategy. It is declared arithmetic, and |, &, == and != are
// additionally bound to the C++ operator overloads defined for Strategy.
py::enum_<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(m, "PropagateCastOpsStrategy", py::module_local(), py::arithmetic{})
    .value("NONE", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::None)
    .value("INSERT_AND_REDUCE", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::InsertAndReduce)
    .value("FLOOD_FILL", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill)
    .def("__or__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
                                     GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator|))
    .def("__and__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
                                      GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator&))
    .def("__eq__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
                                     GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator==))
    // Python's `!=` dispatches to __ne__; "__neq__" is not a special method, so binding only that
    // name never participated in `!=`.
    .def("__ne__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
                                     GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator!=))
    // Kept so any code calling the old misspelled name explicitly continues to work.
    .def("__neq__", py::overload_cast<GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy,
                                      GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy>(&operator!=));
py::class_<GraphTransformerConfiguration::PropagateCastOpsConfiguration>
propagate_cast_ops_config(
m, "PropagateCastOpsConfiguration",
R"pbdoc(Propagate cast ops configuration.)pbdoc");
propagate_cast_ops_config.def(py::init())
.def_readwrite("strategy", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::strategy)
.def_readwrite("level", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::level)
.def_readwrite("allow", &GraphTransformerConfiguration::PropagateCastOpsConfiguration::allow);
py::class_<GraphTransformerConfiguration> graph_transformer_config(
m, "GraphTransformerConfiguration",
R"pbdoc(Graph transformer configuration.)pbdoc");
graph_transformer_config.def(py::init())
.def_readwrite("propagate_cast_ops_config", &GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<TrainingGraphTransformerConfiguration, GraphTransformerConfiguration> training_graph_transformer_config(
m, "TrainingGraphTransformerConfiguration",
R"pbdoc(Training Graph transformer configuration.)pbdoc");
training_graph_transformer_config.def(py::init())
.def_readwrite("enable_gelu_approximation", &TrainingGraphTransformerConfiguration::enable_gelu_approximation)
.def_readwrite("attn_dropout_recompute", &TrainingGraphTransformerConfiguration::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
m, "OrtModuleGraphBuilderConfiguration",
R"pbdoc(Configuration information for module graph builder.)pbdoc");
py::enum_<Severity>(m, "Severity", py::arithmetic(), py::module_local())
.value("VERBOSE", logging::Severity::kVERBOSE)
.value("INFO", logging::Severity::kINFO)
.value("WARNING", logging::Severity::kWARNING)
.value("ERROR", logging::Severity::kERROR)
.value("FATAL", logging::Severity::kFATAL);
module_graph_builder_config.def(py::init())
.def_readwrite("initializer_names", &OrtModuleGraphBuilderConfiguration::initializer_names)
.def_readwrite("initializer_names_to_train", &OrtModuleGraphBuilderConfiguration::initializer_names_to_train)
.def_readwrite("input_names_require_grad", &OrtModuleGraphBuilderConfiguration::input_names_require_grad)
.def_readwrite("use_memory_efficient_gradient",
&OrtModuleGraphBuilderConfiguration::use_memory_efficient_gradient)
.def_readwrite("build_gradient_graph", &OrtModuleGraphBuilderConfiguration::build_gradient_graph)
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config)
.def_readwrite("enable_caching", &OrtModuleGraphBuilderConfiguration::enable_caching)
.def_readwrite("loglevel", &OrtModuleGraphBuilderConfiguration::loglevel);
py::class_<GraphInfo> graph_info(m, "GraphInfo",
R"pbdoc(The information of split graphs for frontend.)pbdoc");
graph_info.def(py::init())
.def_readwrite("user_input_names", &GraphInfo::user_input_names)
.def_readwrite("user_input_grad_names", &GraphInfo::user_input_grad_names)
.def_readwrite("initializer_names", &GraphInfo::initializer_names)
.def_readwrite("initializer_names_to_train", &GraphInfo::initializer_names_to_train)
.def_readwrite("initializer_grad_names_to_train", &GraphInfo::initializer_grad_names_to_train)
.def_readwrite("user_output_names", &GraphInfo::user_output_names)
.def_readwrite("output_grad_indices_non_differentiable", &GraphInfo::output_grad_indices_non_differentiable)
.def_readwrite("output_grad_indices_require_full_shape", &GraphInfo::output_grad_indices_require_full_shape)
.def_readwrite("module_output_indices_requires_save_for_backward", &GraphInfo::module_output_indices_requires_save_for_backward)
.def_readwrite("frontier_node_arg_map", &GraphInfo::frontier_node_arg_map)
.def_readwrite("cached_node_arg_names", &GraphInfo::cached_node_arg_names)
.def_readwrite("module_output_gradient_name", &GraphInfo::module_output_gradient_name);
py::class_<OrtModuleGraphBuilder> ortmodule_graph_builder(m, "OrtModuleGraphBuilder");
ortmodule_graph_builder.def(py::init([]() { return std::make_unique<OrtModuleGraphBuilder>(); }))
.def("initialize",
[](OrtModuleGraphBuilder* ortmodule_graph_builder, const py::bytes& serialized_model,
const OrtModuleGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Initialize(buffer, config));
})
.def("build",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Build());
})
.def("build",
[](OrtModuleGraphBuilder* ortmodule_graph_builder,
const std::vector<std::vector<int64_t>>& input_shapes) {
ORT_THROW_IF_ERROR(ortmodule_graph_builder->Build(&input_shapes));
})
.def("get_gradient_model",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return py::bytes(ortmodule_graph_builder->GetGradientModel());
})
.def("get_forward_model",
[](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return py::bytes(ortmodule_graph_builder->GetForwardModel());
})
.def("get_graph_info", [](OrtModuleGraphBuilder* ortmodule_graph_builder) {
return ortmodule_graph_builder->GetGraphInfo();
});
// Provide a convenient and well-documented way to make a gradient graph.
// A gradient graph can also be obtained through ORTModule, but only via "private" fields and loosely documented
// APIs; this binding is the explicit, tested way to get it.
py::class_<PyGradientGraphBuilder> gradient_graph_builder(m, "GradientGraphBuilder", R"pbdoc(A utility for making a gradient graph that can be used to help train a model.)pbdoc");
// Set up methods to match the C++ `GradientGraphBuilder` interface.
gradient_graph_builder.def(py::init([](
                               const py::bytes& serialized_model,
                               const std::unordered_set<std::string>& y_node_arg_names,
                               const std::unordered_set<std::string>& x_node_arg_names,
                               const std::string& loss_node_arg_name) {
                          std::shared_ptr<Model> model;
                          auto logger_ptr = std::make_unique<logging::Logger>(logging::LoggingManager::DefaultLogger());
                          logger_ptr->SetSeverity(logging::Severity::kINFO);
                          ONNX_NAMESPACE::ModelProto model_proto;
                          std::istringstream model_istream(serialized_model);
                          ORT_THROW_IF_ERROR(Model::Load(model_istream, &model_proto));
                          ORT_THROW_IF_ERROR(Model::Load(model_proto, model, nullptr, *logger_ptr));
                          // The builder is given the config and logger by reference, so both live on the heap and
                          // are handed to PyGradientGraphBuilder to keep them alive alongside the builder.
                          auto gradient_graph_config_ptr = std::make_unique<GradientGraphConfiguration>();
                          gradient_graph_config_ptr->set_gradients_as_graph_outputs = true;
                          auto builder = std::make_unique<GradientGraphBuilder>(
                              &model->MainGraph(),
                              y_node_arg_names,
                              x_node_arg_names,
                              loss_node_arg_name,
                              *gradient_graph_config_ptr,
                              *logger_ptr);
                          return std::make_unique<PyGradientGraphBuilder>(std::move(builder), std::move(model), std::move(logger_ptr), std::move(gradient_graph_config_ptr));
                        }))
    .def("build", [](PyGradientGraphBuilder* gradient_graph_builder) {
      ORT_THROW_IF_ERROR(gradient_graph_builder->builder->Build());
    })
    .def("save", [](PyGradientGraphBuilder* gradient_graph_builder, const std::string& path) {
      ORT_THROW_IF_ERROR(Model::Save(*(gradient_graph_builder->model), path));
    })
    // Returns the (possibly gradient-augmented) model as serialized ModelProto bytes.
    .def("get_model", [](PyGradientGraphBuilder* gradient_graph_builder) {
      std::string model_str;
      // SerializeToString reports failure through its return value; surface it instead of
      // silently returning empty bytes.
      ORT_ENFORCE(gradient_graph_builder->model->ToProto().SerializeToString(&model_str),
                  "Serializing Model failed.");
      return py::bytes(model_str);
    });
py::class_<GradientNodeAttributeDefinition> gradient_node_attribute_definition(
m, "GradientNodeAttributeDefinition", R"pbdoc(Attribute definition for gradient graph nodes.)pbdoc");
gradient_node_attribute_definition.def(py::init())
.def_readwrite("name", &GradientNodeAttributeDefinition::name)
.def_readwrite("value_json", &GradientNodeAttributeDefinition::value_json)
.def_readwrite("dtype", &GradientNodeAttributeDefinition::dtype)
.def_readwrite("is_tensor", &GradientNodeAttributeDefinition::is_tensor);
py::class_<GradientNodeDefinition> gradient_node_definition(m, "GradientNodeDefinition",
R"pbdoc(Definition for gradient graph nodes.)pbdoc");
gradient_node_definition.def(py::init())
.def_readwrite("op_type", &GradientNodeDefinition::op_type)
.def_readwrite("domain", &GradientNodeDefinition::domain)
.def_readwrite("inputs", &GradientNodeDefinition::inputs)
.def_readwrite("outputs", &GradientNodeDefinition::outputs)
.def_readwrite("attributes", &GradientNodeDefinition::attributes);
// Registers `gradient_def` (a list of gradient node definitions) under `key` in
// GradientDefinitionRegistry::Instance().
m.def("register_gradient_definition",
      [](const std::string& key, const std::vector<GradientNodeDefinition>& gradient_def) -> void {
        GradientDefinitionRegistry::Instance().Register(key, gradient_def);
      });
// Stores `edges` as the stop-gradient edges for `key` via SetStopGradientEdgesForNode.
// Taken by const reference: the former const by-value parameter forced an extra copy of the
// set that pybind11 had already converted.
m.def("register_custom_stop_gradient_edges",
      [](const std::string& key, const std::unordered_set<size_t>& edges) -> void {
        GradientDefinitionRegistry::Instance().SetStopGradientEdgesForNode(key, edges);
      });
#ifdef ENABLE_TRAINING_APIS
py::class_<onnxruntime::training::api::Module> training_module(m, "Module", R"pbdoc(Training Module.)pbdoc");
training_module
.def(py::init([](const std::string& model_uri,
onnxruntime::training::api::CheckpointState& state,
std::optional<std::string> eval_model_uri,
OrtDevice device) {
onnxruntime::SessionOptions session_option;
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
return std::make_unique<onnxruntime::training::api::Module>(
model_uri,
state.module_checkpoint_state.named_parameters, session_option,
GetTrainingORTEnv(), provider, eval_model_uri);
}))
.def("train_step",
[](onnxruntime::training::api::Module* model,
const std::vector<OrtValue>& user_inputs, std::vector<OrtValue>& user_outputs) -> void {
ORT_THROW_IF_ERROR(model->TrainStep(user_inputs, user_outputs));
})
.def("eval_step",
[](onnxruntime::training::api::Module* model,
const std::vector<OrtValue>& user_inputs, std::vector<OrtValue>& user_outputs) -> void {
ORT_THROW_IF_ERROR(model->EvalStep(user_inputs, user_outputs));
})
.def("lazy_reset_grad",
[](onnxruntime::training::api::Module* model) -> void {
ORT_THROW_IF_ERROR(model->LazyResetGrad());
})
.def("copy_parameters_to_buffer",
[](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
})
.def("copy_buffer_to_parameters",
[](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
})
.def("get_parameters_size",
[](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
return model->GetParametersSize(trainable_only);
})
.def("save_checkpoint",
[](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
onnxruntime::training::api::CheckpointState state;
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state,
ToPathString(checkpoint_path)));
})
.def("export_model_for_inferencing",
[](onnxruntime::training::api::Module* model, const std::string& inference_model_path,
const std::vector<std::string>& graph_output_names) -> void {
ORT_ENFORCE(model, "Received a nullptr for expected pointer to class training::api::Module");
ORT_THROW_IF_ERROR(model->ExportModelForInferencing(inference_model_path,
graph_output_names));
});
py::class_<onnxruntime::training::api::CheckpointState>
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
checkpoint_state.def(py::init([](
const std::string& ckpt_uri) {
onnxruntime::training::api::CheckpointState state;
ORT_THROW_IF_ERROR(onnxruntime::training::api::LoadCheckpoint(ToPathString(ckpt_uri), state));
return state;
}));
py::class_<PyOptimizer>
    training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
// Creates a PyOptimizer from the optimizer model at `optimizer_model_uri`, bound to `model`, running on
// the execution providers returned for `device`.
training_optimizer.def(py::init([](
                           const std::string& optimizer_model_uri,
                           onnxruntime::training::api::Module* model,
                           OrtDevice device) {
                      std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
                      return std::make_unique<PyOptimizer>(
                          optimizer_model_uri,
                          model, provider);
                    }))
    .def("optimizer_step", [](PyOptimizer* optimizer) -> void {
      ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
    })
    .def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
      ORT_THROW_IF_ERROR(optimizer->optimizer_->SetLearningRate(lr));
    })
    .def("get_learning_rate", [](PyOptimizer* optimizer) -> float {
      return optimizer->optimizer_->GetLearningRate();
    });
py::class_<onnxruntime::training::api::LinearLRScheduler>
lr_scheduler(m, "LinearLRScheduler", R"pbdoc(Learning Rate Scheduler.)pbdoc");
lr_scheduler.def(py::init([](PyOptimizer* optimizer,
int64_t total_step_count,
int64_t warmup_step_count,
float initial_lr) {
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetInitialLearningRate(initial_lr));
return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
optimizer->optimizer_, warmup_step_count, total_step_count);
}))
.def("scheduler_step", [](onnxruntime::training::api::LinearLRScheduler* scheduler) -> void {
ORT_THROW_IF_ERROR(scheduler->Step());
});
m.def("save_checkpoint",
      [](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
         const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,
         const std::string& checkpoint_path) {
        // Each bytes object must hold one serialized TensorProto; a parse counts as successful only if
        // the stream also reaches eof afterwards.
        const auto parse_single = [](const py::bytes& serialized, TensorProto& destination) {
          std::istringstream proto_stream(serialized);
          ORT_ENFORCE(proto_stream.good(), "Broken tensor proto istream to read.");
          google::protobuf::io::IstreamInputStream zero_copy_input(&proto_stream);
          const bool parsed_ok = destination.ParseFromZeroCopyStream(&zero_copy_input) && proto_stream.eof();
          ORT_ENFORCE(parsed_ok, "Parse tensor proto failed.");
        };
        // Parses sources[k] into destinations[k] for every k; both vectors have the same length.
        const auto parse_into = [&parse_single](const std::vector<py::bytes>& sources,
                                                std::vector<TensorProto>& destinations) {
          size_t index = 0;
          for (const py::bytes& serialized : sources) {
            parse_single(serialized, destinations[index]);
            ++index;
          }
        };

        std::vector<TensorProto> trainable_tensor_protos(trainable_tensor_protos_pybytes.size());
        std::vector<TensorProto> non_trainable_tensor_protos(non_trainable_tensor_protos_pybytes.size());
        parse_into(trainable_tensor_protos_pybytes, trainable_tensor_protos);
        parse_into(non_trainable_tensor_protos_pybytes, non_trainable_tensor_protos);

        ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos,
                                                                      non_trainable_tensor_protos,
                                                                      ToPathString(checkpoint_path)));
      });
m.def("get_model_after_loading_checkpoint",
      [](const std::string& checkpoint_path, const py::bytes& serialized_model) {
        // Deserializes the model, applies the checkpoint at `checkpoint_path` to it through
        // LoadCheckpointToModel, and returns the updated model re-serialized as bytes.
        std::istringstream model_stream(serialized_model);
        ONNX_NAMESPACE::ModelProto model_proto;
        ORT_THROW_IF_ERROR(Model::Load(model_stream, &model_proto));

        const auto checkpoint_location = ToPathString(checkpoint_path);
        ORT_THROW_IF_ERROR(onnxruntime::training::api::LoadCheckpointToModel(checkpoint_location, model_proto));

        std::string serialized_result;
        const bool serialized_ok = model_proto.SerializeToString(&serialized_result);
        ORT_ENFORCE(serialized_ok, "Serializing Model failed.");
        return py::bytes(serialized_result);
      });
m.def("get_optimized_model",
      [](const py::bytes& serialized_model,
         const std::unordered_set<std::string>& graph_entities_that_require_gradients) {
        // Applies the pre-training graph transformers (all levels Level1..MaxLevel, on the CPU EP)
        // to the serialized model and returns the transformed model as serialized bytes.

        // Load the serialized model
        std::istringstream buffer(serialized_model);
        ONNX_NAMESPACE::ModelProto model_proto;
        ORT_THROW_IF_ERROR(Model::Load(buffer, &model_proto));

        // Get the ort model from ModelProto model
        auto logger_ptr = std::make_unique<logging::Logger>(logging::LoggingManager::DefaultLogger());
        logger_ptr->SetSeverity(logging::Severity::kINFO);
        std::shared_ptr<onnxruntime::Model> ort_model;
        ORT_THROW_IF_ERROR(Model::Load(model_proto, ort_model, nullptr, *logger_ptr));
        Graph& graph = ort_model->MainGraph();
        ORT_THROW_IF_ERROR(graph.Resolve());

        // Register the pretraining graph transformations. The manager is constructed with NumSteps = 2,
        // which (per the original intent) lets the transformations run twice.
        constexpr size_t NumSteps = 2;
        GraphTransformerManager graph_transformation_mgr{NumSteps};
        std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
            std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
        const auto add_transformers = [&cpu_execution_provider,
                                       &graph_transformation_mgr,
                                       &graph_entities_that_require_gradients](TransformerLevel level) {
          auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
              level, graph_entities_that_require_gradients, TrainingGraphTransformerConfiguration(),
              *cpu_execution_provider);
          for (auto& entry : transformers_to_register) {
            ORT_THROW_IF_ERROR(graph_transformation_mgr.Register(std::move(entry), level));
          }
          return Status::OK();
        };

        constexpr int first_level = static_cast<int>(TransformerLevel::Level1);
        constexpr int last_level = static_cast<int>(TransformerLevel::MaxLevel);

        // The loop bound already guarantees level <= MaxLevel, so every level is registered.
        for (int i = first_level; i <= last_level; i++) {
          ORT_THROW_IF_ERROR(add_transformers(static_cast<TransformerLevel>(i)));
        }

        // Run the graph transformations
        for (int i = first_level; i <= last_level; i++) {
          ORT_THROW_IF_ERROR(
              graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_ptr));
        }

        // Return the optimized model. Serialization failure is reported rather than returning empty bytes.
        std::string model_str;
        ORT_ENFORCE(ort_model->ToProto().SerializeToString(&model_str), "Serializing Model failed.");
        return py::bytes(model_str);
      });
#endif
}
} // namespace python
} // namespace onnxruntime