onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc
2020-03-23 23:23:34 +00:00

245 lines
No EOL
11 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_state_common.h"
// pybind11/stl.h is needed to support std::unordered_set, etc.
#include <pybind11/stl.h>
#include "core/framework/random_seed.h"
#include "core/framework/session_options.h"
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/mpi_setup.h"
namespace onnxruntime {
namespace python {
namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
// BEGIN: forward declaration for stuff in onnxruntime_pybind_state
// These functions are only declared here; their definitions are expected to
// live in onnxruntime_pybind_state (not visible in this file).

// Called in this file right after a model has been loaded and configured for
// training; `provider_types` is passed as an empty list by the callers below.
void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types);
// Used by "get_state" to fill `obj` from `rtensor`. Presumably produces a numpy
// array (per the caller's comment) -- confirm against the definition.
void GetPyObjFromTensor(const Tensor& rtensor, py::object& obj, const DataTransferManager* data_transfer_manager = nullptr);
// Used by "load_state" to convert the Python `value` named `name_input` into
// `*p_mlvalue`. The caller checks PyErr_Occurred() afterwards, so this function
// may leave a Python error set instead of throwing -- confirm against the definition.
void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, AllocatorPtr alloc, const std::string& name_input,
py::object& value, OrtValue* p_mlvalue);
// END: forward declaration
// Plain bag of settings filled in from Python and consumed by
// ConfigureSessionForTraining() to build the session's training configuration.
struct TrainingParameters {
  // Loss output used as the training loss when mixed precision is disabled.
  std::string loss_output_name;
  // Copied to the configuration's weight_names_to_train / weight_names_to_not_train.
  std::unordered_set<std::string> weights_to_train;
  std::unordered_set<std::string> weights_not_to_train;
  // Copied verbatim to the configuration's immutable_weights.
  onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights;

  // optimizer
  // When empty, no optimizer configuration is attached to the session.
  std::string training_optimizer_name;
  // Not read by ConfigureSessionForTraining() in this file.
  std::string loss_scale_input_name;
  // Loss output used as the training loss when mixed precision is enabled.
  std::string scaled_loss_output_name;
  // Name of the optimizer's learning-rate input.
  std::string lr_params_feed_name{"Learning_Rate"};
  // weight name -> (attribute name -> value); every weight the optimizer
  // generator is asked about must have an entry here.
  std::unordered_map<std::string, std::unordered_map<std::string, float>> optimizer_attributes_map;
  bool use_fp16_moments{false};
  bool use_mixed_precision{false};
  // Currently also selects the NCCL all-reduce path (see ConfigureSessionForTraining()).
  bool allreduce_post_accumulation{false};
  // Not read by ConfigureSessionForTraining() in this file.
  float loss_scale{0.0f};

  // Distributed-run description, forwarded to the distributed configuration.
  int world_rank{0};
  int world_size{1};
  int local_rank{0};
  int local_size{1};
  int gradient_accumulation_steps{1};
  // Overwritten with world_size / horizontal_parallel_size if it disagrees.
  int data_parallel_size{1};
  int horizontal_parallel_size{1};
  bool partition_optimizer{false};
  // Only applied as the static random seed when strictly positive.
  int seed{-1};
};
// Builds a TrainingSession::TrainingConfiguration from `parameters` and applies
// it to `sess`.
//
// `parameters` is taken by non-const reference because it may be adjusted:
//   - data_parallel_size is overwritten when it does not equal
//     world_size / horizontal_parallel_size;
//   - local_rank / local_size are overwritten from the Horovod MPI context when
//     Horovod is set up here.
//
// Fails via ORT_ENFORCE when the parallel sizes are inconsistent, throws
// std::runtime_error when world_size is not divisible by
// horizontal_parallel_size, and reports a ConfigureForTraining() failure via
// OrtPybindThrowIfError.
//
// TODO: this method does not handle parallel optimization.
void ConfigureSessionForTraining(
    training::TrainingSession* sess, TrainingParameters& parameters) {
  //TODO tix, refactor the mpi related code to populate all fields correctly by default.
  // horizontal_parallel_size is used as a divisor below; zero would be undefined
  // behaviour and a negative value makes no sense as a group size.
  ORT_ENFORCE(parameters.horizontal_parallel_size > 0,
              "horizontal_parallel_size must be positive, got ", parameters.horizontal_parallel_size);
  ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size);
  ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size);

  if (parameters.world_size % parameters.horizontal_parallel_size != 0) {
    throw std::runtime_error("Cannot split horizontal parallel group because world_size is not divisible");
  }

  // The data-parallel group size is fully determined by the two values above;
  // a mismatching user-provided value is corrected rather than rejected.
  auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
  if (data_group_size != parameters.data_parallel_size) {
    std::cout << "WARNING: data_parallel_size is not correct, tuned automatically to "
              << data_group_size << std::endl;
    parameters.data_parallel_size = data_group_size;
  }

#ifdef USE_HOROVOD
  // this condition block is temporary.
  // For now, nccl allreduce kernel only implements for allreduce_post_accumulation
  // horovod allreduce kernel only implements for not allreduce_post_accumulation.
  bool use_nccl = parameters.allreduce_post_accumulation;
  if (!use_nccl && parameters.world_size > 1) {
    auto mpi_context = training::setup_horovod();
    std::cout << "mpi_context.world_rank: " << mpi_context.world_rank << std::endl;
    std::cout << "mpi_context.local_rank: " << mpi_context.local_rank << std::endl;
    std::cout << "mpi_context.world_size: " << mpi_context.world_size << std::endl;
    std::cout << "mpi_context.local_size: " << mpi_context.local_size << std::endl;
    parameters.local_size = mpi_context.local_size;
    parameters.local_rank = mpi_context.local_rank;
  }
#endif

  training::TrainingSession::TrainingConfiguration config{};
  config.weight_names_to_train = parameters.weights_to_train;
  config.weight_names_to_not_train = parameters.weights_not_to_train;
  config.immutable_weights = parameters.immutable_weights;

  config.set_gradients_as_graph_outputs = true;

  config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;

  config.distributed_config.world_rank = parameters.world_rank;
  config.distributed_config.world_size = parameters.world_size;
  config.distributed_config.local_rank = parameters.local_rank;
  config.distributed_config.local_size = parameters.local_size;
  config.distributed_config.data_parallel_size = parameters.data_parallel_size;
  config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size;

  if (parameters.use_mixed_precision) {
    training::TrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{};
    mp.add_loss_scaling = false;
    mp.use_fp16_initializers = true;

    config.mixed_precision_config = mp;
  }

  // With mixed precision the scaled loss output is the one to train against.
  config.loss_name =
      parameters.use_mixed_precision ? parameters.scaled_loss_output_name : parameters.loss_output_name;

  if (!parameters.training_optimizer_name.empty()) {
    training::TrainingSession::TrainingConfiguration::OptimizerConfiguration opt{};
    opt.name = parameters.training_optimizer_name;
    opt.learning_rate_input_name = parameters.lr_params_feed_name;
    // NOTE(review): the generator captures `parameters` by reference; this is
    // only safe if the session does not keep the callback beyond
    // ConfigureForTraining() -- confirm.
    opt.weight_attributes_generator = [&parameters](const std::string& weight_name) {
      const auto it = parameters.optimizer_attributes_map.find(weight_name);
      ORT_ENFORCE(
          it != parameters.optimizer_attributes_map.end(),
          "Failed to find attribute map for weight ", weight_name);
      return it->second;
    };
    opt.use_fp16_moments = parameters.use_fp16_moments;
    opt.do_all_reduce_in_fp16 = true;
    // TODO: this mapping is temporary.
    // For now, nccl allreduce kernel only implements for allreduce_post_accumulation
    // horovod allreduce kernel only implements for not allreduce_post_accumulation.
    // eventually we will have one all reduce kernel and let opt to have
    // an allreduce_post_accumulation option and remove the use_nccl option.
    opt.use_nccl = parameters.allreduce_post_accumulation;
    opt.partition_optimizer = parameters.partition_optimizer;

    config.optimizer_config = opt;
  }

  // NOTE(review): a seed of 0 is ignored along with the -1 default; confirm
  // that 0 is not meant to be a usable seed.
  if (parameters.seed > 0) {
    utils::SetStaticRandomSeed(static_cast<uint32_t>(parameters.seed));
    std::cout << "Random seed is set to " << parameters.seed << std::endl;
  }

  training::TrainingSession::TrainingConfigurationResult config_result{};

  OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
}
void addObjectMethodsForTraining(py::module& m) {
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
.def_readwrite("weights_to_train", &TrainingParameters::weights_to_train)
.def_readwrite("loss_scale_input_name", &TrainingParameters::loss_scale_input_name)
.def_readwrite("scaled_loss_output_name", &TrainingParameters::scaled_loss_output_name)
.def_readwrite("training_optimizer_name", &TrainingParameters::training_optimizer_name)
.def_readwrite("lr_params_feed_name", &TrainingParameters::lr_params_feed_name)
.def_readwrite("optimizer_attributes_map", &TrainingParameters::optimizer_attributes_map)
.def_readwrite("use_fp16_moments", &TrainingParameters::use_fp16_moments)
.def_readwrite("use_mixed_precision", &TrainingParameters::use_mixed_precision)
.def_readwrite("allreduce_post_accumulation", &TrainingParameters::allreduce_post_accumulation)
.def_readwrite("loss_scale", &TrainingParameters::loss_scale)
.def_readwrite("world_rank", &TrainingParameters::world_rank)
.def_readwrite("world_size", &TrainingParameters::world_size)
.def_readwrite("gradient_accumulation_steps", &TrainingParameters::gradient_accumulation_steps)
.def_readwrite("partition_optimizer", &TrainingParameters::partition_optimizer);
py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
training_session.def(py::init([](const SessionOptions& so) {
Environment& env = get_env();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
}))
.def("finalize", [](py::object) {
#ifdef USE_HOROVOD
training::shutdown_horovod();
#endif
})
.def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
OrtPybindThrowIfError(sess->Load(path));
ConfigureSessionForTraining(sess, parameters);
std::vector<std::string> provider_types = {};
InitializeSession(sess, provider_types);
})
.def("read_bytes", [](onnxruntime::training::TrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->Load(buffer));
ConfigureSessionForTraining(sess, parameters);
std::vector<std::string> provider_types = {};
InitializeSession(sess, provider_types);
})
.def("get_state", [](onnxruntime::training::TrainingSession* sess) {
NameMLValMap state_tensors;
ORT_THROW_IF_ERROR(sess->GetStateTensors(state_tensors));
auto& data_transfer_manager = sess->GetDataTransferManager();
//convert to numpy array
std::map<std::string, py::object> rmap;
for (auto& kv : state_tensors) {
if (kv.second.IsTensor()) {
py::object obj;
const Tensor& rtensor = kv.second.Get<Tensor>();
GetPyObjFromTensor(rtensor, obj, &data_transfer_manager);
rmap.insert({kv.first, obj});
} else {
throw std::runtime_error("Non tensor type in session state tensors is not expected.");
}
}
return rmap;
})
.def("load_state", [](onnxruntime::training::TrainingSession* sess, std::unordered_map<std::string, py::object>& state, bool strict) {
NameMLValMap state_tensors;
for (auto initializer : state) {
OrtValue ml_value;
auto px = sess->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value);
if (PyErr_Occurred()) {
PyObject *ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyObject* pStr = PyObject_Str(ptype);
std::string sType = py::reinterpret_borrow<py::str>(pStr);
Py_XDECREF(pStr);
pStr = PyObject_Str(pvalue);
sType += ": ";
sType += py::reinterpret_borrow<py::str>(pStr);
Py_XDECREF(pStr);
throw std::runtime_error(sType);
}
state_tensors.insert(std::make_pair(initializer.first, ml_value));
}
ORT_THROW_IF_ERROR(sess->SetStateTensors(state_tensors, strict));
});
}
} // namespace python
} // namespace onnxruntime