Glue parallel training (#4550)

add mpi size, rank python API

add single node parallel training example
This commit is contained in:
liqunfu 2020-08-21 21:24:27 -07:00 committed by GitHub
parent 9a6db9b9f4
commit 6260d073b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 281 additions and 150 deletions

View file

@ -1,14 +1,36 @@
#include <stdio.h>
#include <stdlib.h>
#include "mpi_setup.h"
#include "mpi_context.h"
namespace onnxruntime {
namespace training {
MPIContext::MPIContext(int w_rank, int l_rank, int w_size, int l_size) : world_rank(w_rank), local_rank(l_rank), world_size(w_size), local_size(l_size) {}
MPIContext::MPIContext() :
world_rank_(0),
local_rank_(0),
world_size_(1),
local_size_(1)
{
#if defined(USE_NCCL)
setup_mpi();
#endif
}
#if defined(USE_NCCL) || defined(USE_HOROVOD)
MPIContext setup_mpi() {
MPIContext::~MPIContext() {
#if defined(USE_NCCL)
#ifndef _WIN32
shutdown_mpi();
#endif
#endif
}
const MPIContext& MPIContext::GetInstance() {
static MPIContext context;
return context;
}
#if defined(USE_NCCL)
void MPIContext::setup_mpi() {
// setup MPI amd horovod
int is_mpi_initialized = 0;
MPI_Initialized(&is_mpi_initialized);
@ -27,11 +49,6 @@ MPIContext setup_mpi() {
MPI_Allgather(&world_rank, 1, MPI_INT, ranks, 1, MPI_INT, MPI_COMM_WORLD);
#ifdef USE_HOROVOD
using namespace horovod::common;
horovod_init(ranks, world_size);
#endif
//Get local rank and size
int local_rank;
int local_size;
@ -49,14 +66,13 @@ MPIContext setup_mpi() {
printf("Using cuda local_rank: %d, world_rank: %d, world_size: %d, local_size: %d\n(version: %s)\n",
local_rank, world_rank, world_size, local_size, version);
return MPIContext(world_rank, local_rank, world_size, local_size);
this->world_rank_ = world_rank;
this->local_rank_ = local_rank;
this->world_size_ = world_size;
this->local_size_ = local_size;
}
void shutdown_mpi() {
#ifdef USE_HOROVOD
horovod::common::horovod_shutdown();
#endif
void MPIContext::shutdown_mpi() {
int is_mpi_initialized = 0;
MPI_Initialized(&is_mpi_initialized);
if (!is_mpi_initialized)

View file

@ -0,0 +1,60 @@
#pragma once
#if defined(USE_NCCL)
#include <mpi.h>
#endif
namespace onnxruntime {
namespace training {
#define MPI_CHECK(condition) \
do { \
int error = (condition); \
ORT_ENFORCE( \
error == MPI_SUCCESS, \
"MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
} while (0)
class MPIContext {
// https://stackoverflow.com/questions/1008019/c-singleton-design-pattern
public:
static const MPIContext& GetInstance();
MPIContext(MPIContext const&) = delete;
void operator=(MPIContext const&) = delete;
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
~MPIContext();
int GetWorldRank() const { return world_rank_; }
int GetLocalRank() const { return local_rank_; }
int GetWorldSize() const { return world_size_; }
int GetLocalSize() const { return local_size_; }
#if defined(USE_NCCL)
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
// shutdown_mpi shall be called specifically in user code.
static void shutdown_mpi();
#endif
private:
MPIContext();
#if defined(USE_NCCL)
void setup_mpi();
#endif
int world_rank_;
int local_rank_;
int world_size_;
int local_size_;
};
} // namespace training
} // namespace onnxruntime

View file

@ -1,41 +0,0 @@
#pragma once
#if defined(USE_NCCL) || defined(USE_HOROVOD)
#include <mpi.h>
#endif
#ifdef USE_HOROVOD
#include "orttraining/core/graph/horovod_adapters.h"
#endif
namespace onnxruntime {
namespace training {
#define MPI_CHECK(condition) \
do { \
int error = (condition); \
ORT_ENFORCE( \
error == MPI_SUCCESS, \
"MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
} while (0)
struct MPIContext {
MPIContext(int world_rank = 0, int local_rank = 0, int world_size = 1, int local_size = 1);
int world_rank;
int local_rank;
int world_size;
int local_size;
};
#if defined(USE_NCCL) || defined(USE_HOROVOD)
MPIContext setup_mpi();
void shutdown_mpi();
#endif
} // namespace training
} // namespace onnxruntime

View file

@ -12,7 +12,7 @@
#include "core/providers/cuda/cuda_allocator.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/framework/tensorboard/event_writer.h"
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/models/runner/constant.h"
#include "orttraining/models/runner/training_runner.h"
#include "orttraining/models/runner/training_util.h"
@ -529,22 +529,20 @@ void setup_training_params(BertParameters& params) {
params.model_actual_running_graph_path = model_name_base + ORT_TSTR("_bw_running.onnx");
#if defined(USE_NCCL) || defined(USE_HOROVOD)
params.mpi_context = setup_mpi();
if (params.pipeline_parallel_size > 1) {
auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(params.mpi_context.world_rank));
auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(MPIContext::GetInstance().GetWorldRank()));
params.model_with_loss_func_path = pipeline_model_name_base + ORT_TSTR("_with_cost.onnx");
params.model_with_training_graph_path = pipeline_model_name_base + ORT_TSTR("_bw.onnx");
params.model_actual_running_graph_path = pipeline_model_name_base + ORT_TSTR("_bw_running.onnx");
}
ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size);
ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size);
if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) {
ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize());
ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize());
if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) {
LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible";
return;
}
auto data_group_size = params.mpi_context.world_size / (params.horizontal_parallel_size * params.pipeline_parallel_size);
auto data_group_size = MPIContext::GetInstance().GetWorldSize() / (params.horizontal_parallel_size * params.pipeline_parallel_size);
ORT_ENFORCE(data_group_size > 0, "Insufficient processes lead to zero-way data parallelism, which should be at least one-way.");
if (data_group_size != params.data_parallel_size) {
LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to "
@ -558,7 +556,7 @@ void setup_training_params(BertParameters& params) {
#endif
#ifdef USE_CUDA
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(params.mpi_context.local_rank);
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(MPIContext::GetInstance().GetLocalRank());
size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
if (params.cuda_mem_limit_in_gb > 0)
cuda_mem_limit = static_cast<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
@ -628,7 +626,7 @@ void setup_training_params(BertParameters& params) {
if (params.dump_fetches) {
std::ostringstream filename;
filename << "./fetch_dumps/rank_" << params.mpi_context.world_rank << "_step_" << step << ".txt";
filename << "./fetch_dumps/rank_" << MPIContext::GetInstance().GetWorldRank() << "_step_" << step << ".txt";
ofstream ofs(filename.str());
for (size_t i = 0; i < fetch_names.size(); ++i) {
TrainingUtil::PrintTensor(fetch_names[i], fetches[i].Get<Tensor>(), ofs);
@ -733,7 +731,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
BertParameters params_for_phase;
while (GetParametersForPhase(runner->GetRound(), params, params_for_phase)) {
ORT_RETURN_IF_ERROR(runner->UpdateParams(params_for_phase));
auto rank_in_data_parallel_group = params_for_phase.mpi_context.world_rank / params_for_phase.horizontal_parallel_size;
auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params_for_phase.horizontal_parallel_size;
auto training_data_loader = onnxruntime::make_unique<DataLoader>(params_for_phase.input_name_map,
params_for_phase.train_data_dir,
max_num_files_preload,
@ -742,7 +740,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
auto test_data_loader = std::unique_ptr<DataLoader>{};
// Evaluation is only done in device #0
if (params_for_phase.mpi_context.world_rank == 0) {
if (MPIContext::GetInstance().GetWorldRank() == 0) {
test_data_loader = onnxruntime::make_unique<DataLoader>(params_for_phase.input_name_map,
params_for_phase.test_data_dir,
max_num_files_preload);
@ -759,7 +757,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
ORT_RETURN_IF_ERROR(runner->ResetLossScaler());
}
if (params_for_phase.mpi_context.world_rank == 0) {
if (MPIContext::GetInstance().GetWorldRank() == 0) {
// Pass in empty dataloader to disable evaluation in EndTraining
// to avoid a redundant synchronization caused by Tensorboard's SummaryMerge Op.
ORT_RETURN_IF_ERROR(runner->EndTraining(nullptr));
@ -807,9 +805,13 @@ int main(int argc, char* argv[]) {
RETURN_IF_FAIL(RunTraining(params, *env));
}
#if defined(USE_NCCL) || defined(USE_HOROVOD)
shutdown_mpi();
#if defined(USE_NCCL)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
return 0;
}

View file

@ -9,7 +9,7 @@
#include "core/session/environment.h"
#include "core/framework/random_seed.h"
#include "core/providers/cuda/cuda_allocator.h"
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/core/framework/tensorboard/event_writer.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/models/runner/constant.h"
@ -294,15 +294,14 @@ void setup_training_params(GPT2Parameters& params) {
/*label_name*/ "labels"});
#if defined(USE_NCCL) || defined(USE_HOROVOD)
params.mpi_context = setup_mpi();
ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size);
ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size);
if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) {
ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize());
ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize());
if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) {
LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible";
return;
}
auto data_group_size = params.mpi_context.world_size / params.horizontal_parallel_size;
auto data_group_size = MPIContext::GetInstance().GetWorldSize() / params.horizontal_parallel_size;
if (data_group_size != params.data_parallel_size) {
LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to "
<< data_group_size << std::endl;
@ -341,7 +340,7 @@ void setup_training_params(GPT2Parameters& params) {
params.model_type = "gpt2";
#ifdef USE_CUDA
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(params.mpi_context.local_rank);
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(MPIContext::GetInstance().GetLocalRank());
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id));
params.input_allocator = std::make_shared<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
#endif
@ -417,7 +416,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
auto runner = onnxruntime::make_unique<TrainingRunner>(params, env);
ORT_RETURN_IF_ERROR(runner->Initialize());
auto rank_in_data_parallel_group = params.mpi_context.world_rank / params.horizontal_parallel_size;
auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params.horizontal_parallel_size;
auto training_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
params.train_data_dir,
max_num_files_preload,
@ -426,7 +425,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
std::unique_ptr<DataLoader> test_data_loader;
// Evaluation is only done in device #0
if (params.mpi_context.world_rank == 0) {
if (MPIContext::GetInstance().GetWorldRank() == 0) {
test_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
params.test_data_dir,
max_num_files_preload);
@ -441,7 +440,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
ORT_RETURN_IF_ERROR(runner->Run(training_data_loader.get(), test_data_loader.get(), mapped_dimensions));
// only test and save trained model on device #0
if (params.mpi_context.world_rank == 0) {
if (MPIContext::GetInstance().GetWorldRank() == 0) {
test_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
params.test_data_dir,
max_num_files_preload);
@ -478,9 +477,13 @@ int main(int argc, char* argv[]) {
RETURN_IF_FAIL(RunTraining(params, *env));
}
#if defined(USE_NCCL) || defined(USE_HOROVOD)
shutdown_mpi();
#if defined(USE_NCCL)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
return 0;
}

View file

@ -143,7 +143,7 @@ void setup_training_params(TrainingRunner::Parameters& params) {
};
std::shared_ptr<EventWriter> tensorboard;
if (!params.log_dir.empty() && params.mpi_context.world_rank == 0)
if (!params.log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0)
tensorboard = std::make_shared<EventWriter>(params.log_dir);
params.post_evaluation_callback = [tensorboard](size_t num_samples, size_t step, const std::string /**/) {
@ -183,12 +183,12 @@ int main(int argc, char* args[]) {
setup_training_params(params);
// setup data
auto device_count = params.mpi_context.world_size;
auto device_count = MPIContext::GetInstance().GetWorldSize();
std::vector<string> feeds{"X", "labels"};
auto trainingData = std::make_shared<DataSet>(feeds);
auto testData = std::make_shared<DataSet>(feeds);
std::string mnist_data_path = ToMBString(params.train_data_dir);
PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, params.mpi_context.world_rank /* shard_to_load */, device_count /* total_shards */);
PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, MPIContext::GetInstance().GetWorldRank() /* shard_to_load */, device_count /* total_shards */);
if (testData->NumSamples() == 0) {
printf("Warning: No data loaded - run cancelled.\n");

View file

@ -10,7 +10,7 @@
#include "core/common/logging/sinks/clog_sink.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/framework/tensorboard/event_writer.h"
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/models/runner/constant.h"
#include "orttraining/models/runner/training_runner.h"
#include "orttraining/models/runner/training_util.h"

View file

@ -79,8 +79,8 @@ Status TrainingRunner::Initialize() {
if (params_.pipeline_parallel_size > 1 && !params_.pipeline_stage_paths.empty()) {
// Pipeline partition happens outside ORT. We just load the result of partitioning forward graph.
// Backward graph will be generated using ORT's graph transformers.
ORT_ENFORCE(static_cast<size_t>(params_.mpi_context.world_size) == params_.pipeline_stage_paths.size());
ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[params_.mpi_context.world_rank]));
ORT_ENFORCE(static_cast<size_t>(MPIContext::GetInstance().GetWorldSize()) == params_.pipeline_stage_paths.size());
ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[MPIContext::GetInstance().GetWorldRank()]));
} else {
ORT_RETURN_IF_ERROR(session_.Load(params_.model_path));
}
@ -98,10 +98,10 @@ Status TrainingRunner::Initialize() {
config.gradient_accumulation_steps = params_.gradient_accumulation_steps;
config.distributed_config.world_rank = params_.mpi_context.world_rank;
config.distributed_config.world_size = params_.mpi_context.world_size;
config.distributed_config.local_size = params_.mpi_context.local_size;
config.distributed_config.local_rank = params_.mpi_context.local_rank;
config.distributed_config.world_rank = MPIContext::GetInstance().GetWorldRank();
config.distributed_config.world_size = MPIContext::GetInstance().GetWorldSize();
config.distributed_config.local_size = MPIContext::GetInstance().GetLocalSize();
config.distributed_config.local_rank = MPIContext::GetInstance().GetLocalRank();
config.distributed_config.data_parallel_size = params_.data_parallel_size;
config.distributed_config.horizontal_parallel_size = params_.horizontal_parallel_size;
config.distributed_config.pipeline_parallel_size = params_.pipeline_parallel_size;
@ -115,7 +115,7 @@ Status TrainingRunner::Initialize() {
}
// always configure the loss function
if (params_.pipeline_parallel_size == 1 || params_.mpi_context.world_rank == params_.mpi_context.world_size - 1) {
if (params_.pipeline_parallel_size == 1 || MPIContext::GetInstance().GetWorldRank() == MPIContext::GetInstance().GetWorldSize() - 1) {
TrainingSession::TrainingConfiguration::LossFunctionConfiguration lf{};
lf.loss_function_info = params_.loss_func_info;
@ -337,7 +337,7 @@ Status TrainingRunner::Initialize() {
Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions) {
if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) {
if (MPIContext::GetInstance().GetWorldRank() == 0 && !params_.model_actual_running_graph_path.empty()) {
session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD);
}
@ -810,7 +810,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions) {
const bool enable_checkpoint_saving =
params_.mpi_context.world_rank == 0 &&
MPIContext::GetInstance().GetWorldRank() == 0 &&
checkpoint_registry_ && params_.checkpoint_period > 0;
std::unique_ptr<perftest::utils::ICPUUsage> cpu_usage_calculator;

View file

@ -10,7 +10,7 @@
#include "core/framework/ml_value.h"
#include "core/providers/providers.h"
#include "orttraining/core/framework/checkpoint_registry.h"
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/models/runner/data_loader.h"
@ -95,7 +95,6 @@ class TrainingRunner {
bool use_gist = false;
// Whether we collect execution profile trace during this run.
bool use_profiler = false;
MPIContext mpi_context;
bool skip_evaluation = false;
bool dump_fetches = false;
bool dump_convergence_metrics = false;
@ -120,7 +119,7 @@ class TrainingRunner {
float cuda_mem_limit_in_gb = -1.0f;
bool EnableTensorboard() const {
return !is_perf_test && !log_dir.empty() && mpi_context.world_rank == 0;
return !is_perf_test && !log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0;
}
bool UseCuda() const {

View file

@ -11,7 +11,7 @@
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
#include "python/onnxruntime_pybind_mlvalue.h"
namespace onnxruntime {
@ -19,6 +19,7 @@ namespace python {
namespace py = pybind11;
using namespace onnxruntime;
using namespace onnxruntime::logging;
using namespace onnxruntime::training;
struct TrainingParameters {
std::string loss_output_name;
@ -65,25 +66,10 @@ TrainingConfigurationResult ConfigureSessionForTraining(
auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
if (data_group_size != parameters.data_parallel_size) {
std::cout << "WARNING: data_parallel_size is not correct, tuned automatically to "
<< data_group_size << std::endl;
LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
<< data_group_size;
parameters.data_parallel_size = data_group_size;
}
#if defined(USE_NCCL) || defined(USE_HOROVOD)
// this condition block is temporary.
// For now, nccl allreduce kernel only implements for allreduce_post_accumulation
// hovorod allreduce kernel only implements for not allreduce_post_accumulation.
bool use_nccl = parameters.allreduce_post_accumulation;
if (!use_nccl && parameters.world_size > 1) {
auto mpi_context = training::setup_mpi();
std::cout << "mpi_context.world_rank: " << mpi_context.world_rank << std::endl;
std::cout << "mpi_context.local_rank: " << mpi_context.local_rank << std::endl;
std::cout << "mpi_context.world_size: " << mpi_context.world_size << std::endl;
std::cout << "mpi_context.local_size: " << mpi_context.local_size << std::endl;
parameters.local_size = mpi_context.local_size;
parameters.local_rank = mpi_context.local_rank;
}
#endif
training::TrainingSession::TrainingConfiguration config{};
config.weight_names_to_train = parameters.weights_to_train;
@ -158,6 +144,28 @@ TrainingConfigurationResult ConfigureSessionForTraining(
return python_config_result;
}
#if defined(USE_NCCL)
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();
parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
parameters.local_size = MPIContext::GetInstance().GetLocalSize();
if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
if (parameters.world_rank != 0)
LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
}
if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
if (parameters.world_size != 1)
LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
parameters.world_size = MPIContext::GetInstance().GetWorldSize();
}
}
#endif
void addObjectMethodsForTraining(py::module& m) {
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
@ -181,6 +189,13 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad);
#if defined(USE_NCCL)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
m.def("get_mpi_context_local_size", []() -> int { return MPIContext::GetInstance().GetLocalSize(); });
m.def("get_mpi_context_world_rank", []() -> int { return MPIContext::GetInstance().GetWorldRank(); });
m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); });
#endif
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
config_result.def(py::init())
.def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
@ -200,13 +215,21 @@ void addObjectMethodsForTraining(py::module& m) {
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
}))
.def("finalize", [](py::object) {
#if defined(USE_NCCL) || defined(USE_HOROVOD)
training::shutdown_mpi();
#if defined(USE_NCCL)
#ifdef _WIN32
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
// call shutdown_mpi() here instead.
MPIContext::shutdown_mpi();
#endif
#endif
})
.def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
OrtPybindThrowIfError(sess->Load(path));
#if defined(USE_NCCL)
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(sess, parameters);
std::vector<std::string> provider_types = {};
@ -218,6 +241,9 @@ void addObjectMethodsForTraining(py::module& m) {
std::istringstream buffer(serialized_model);
OrtPybindThrowIfError(sess->Load(buffer));
#if defined(USE_NCCL)
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
#endif
const auto config_result = ConfigureSessionForTraining(sess, parameters);
std::vector<std::string> provider_types = {};

View file

@ -25,6 +25,7 @@ from transformers import (
import onnxruntime
from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size, get_mpi_context_world_rank, get_mpi_context_world_size
from orttraining_transformer_trainer import ORTTransformerTrainer
@ -95,13 +96,27 @@ class ORTGlueTest(unittest.TestCase):
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
def test_bert_with_mrpc(self):
expected_acc = 0.8553921568627451
expected_f1 = 0.8970331588132635
expected_acc_and_f1 = 0.8762126578380043
expected_loss = 0.42737212419217707
if self.local_rank == -1:
expected_acc = 0.8553921568627451
expected_f1 = 0.8970331588132635
expected_acc_and_f1 = 0.8762126578380043
expected_loss = 0.42737212419217707
elif self.local_rank == 0:
expected_acc = 0.8308823529411765
expected_f1 = 0.881646655231561
expected_acc_and_f1 = 0.8562645040863688
expected_loss = 0.42491564023144107
for use_new_api in [True, False]:
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api)
if self.local_rank == -1:
# not parallel case, we can run both new and old api tests
for use_new_api in [True, False]:
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api)
else:
# with parallel training, TrainingArguments can only be created once (due to its cached _setup_devices)
# thus we can only choose one test case to run.
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=True)
if self.local_rank in [-1, 0]:
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
@ -122,6 +137,14 @@ class ORTGlueTest(unittest.TestCase):
def model_to_desc(self, model_name, model):
if model_name.startswith('bert') or model_name.startswith('xlnet'):
new_model_desc = {
'inputs': [
('input_ids', ['batch', 'max_seq_len_in_batch'],),
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
('labels', ['batch', ],)],
'outputs': [('loss', [], True),
('logits', ['batch', 2])]}
model_desc = ModelDescription([
IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size),
IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
@ -130,6 +153,13 @@ class ORTGlueTest(unittest.TestCase):
IODescription('loss', [], torch.float32),
IODescription('logits', ['batch', 2], torch.float32)])
elif model_name.startswith('roberta'):
new_model_desc = {
'inputs': [
('input_ids', ['batch', 'max_seq_len_in_batch'],),
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
('labels', ['batch', ],)],
'outputs': [('loss', [], True),
('logits', ['batch', 2])]}
model_desc = ModelDescription([
IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size),
IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
@ -139,16 +169,19 @@ class ORTGlueTest(unittest.TestCase):
else:
raise RuntimeError("unsupported base model name {}.".format(model_name))
return model_desc
return model_desc, new_model_desc
def run_glue(self, model_name, task_name, fp16, use_new_api):
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=self.data_dir + "/" + task_name,
data_args = GlueDataTrainingArguments(
task_name=task_name, data_dir=os.path.join(self.data_dir, task_name),
max_seq_length=self.max_seq_length)
training_args = TrainingArguments(output_dir=self.output_dir + "/" + task_name, do_train=True, do_eval=True,
training_args = TrainingArguments(
output_dir=os.path.join(self.output_dir, task_name), do_train=True, do_eval=True,
per_gpu_train_batch_size=self.train_batch_size,
learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,local_rank=self.local_rank,
learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,
local_rank=self.local_rank,
overwrite_output_dir=self.overwrite_output_dir, gradient_accumulation_steps=self.gradient_accumulation_steps,
fp16=fp16, logging_steps=self.logging_steps)
@ -214,11 +247,12 @@ class ORTGlueTest(unittest.TestCase):
preds = np.squeeze(p.predictions)
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
model_desc = self.model_to_desc(model_name, model)
model_desc, new_model_desc = self.model_to_desc(model_name, model)
# Initialize the ORTTrainer within ORTTransformerTrainer
trainer = ORTTransformerTrainer(
model=model,
model_desc=model_desc,
new_model_desc=new_model_desc,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
@ -247,4 +281,26 @@ class ORTGlueTest(unittest.TestCase):
return results
if __name__ == "__main__":
unittest.main()
if get_mpi_context_world_size() > 1:
# mpi launch
print("mpirun launch")
# TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl")
# pytorch expects following environment settings (which would be set if launched with torch.distributed.launch).
local_rank = get_mpi_context_local_rank()
print("get_mpi_context_local_rank(): ", local_rank)
os.environ['RANK'] = str(local_rank)
os.environ['WORLD_SIZE'] = str(get_mpi_context_world_size())
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
from onnxruntime.capi._pybind_state import set_cuda_device_id
set_cuda_device_id(local_rank)
test = ORTGlueTest()
test.setUp()
test.local_rank = local_rank
test.test_bert_with_mrpc()
else:
unittest.main()

View file

@ -214,6 +214,7 @@ class ORTMultipleChoiceTest(unittest.TestCase):
trainer = ORTTransformerTrainer(
model=model,
model_desc=model_desc,
new_model_desc=None,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,

View file

@ -12,7 +12,7 @@ import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
# from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm import tqdm, trange
@ -101,6 +101,7 @@ class ORTTransformerTrainer:
self,
model: PreTrainedModel,
model_desc: ModelDescription,
new_model_desc: dict,
args: TrainingArguments,
train_dataset: Dataset,
eval_dataset: Dataset,
@ -112,6 +113,7 @@ class ORTTransformerTrainer:
self.model = model
self.model_desc = model_desc
self.new_model_desc = new_model_desc
self.args = args
self.data_collator = DefaultDataCollator()
self.train_dataset = train_dataset
@ -171,15 +173,6 @@ class ORTTransformerTrainer:
num_train_epochs = self.args.num_train_epochs
if self.use_new_api:
model_desc = {
'inputs': [
('input_ids', ['batch', 'max_seq_len_in_batch'],),
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
('labels', ['batch', ],)],
'outputs': [('loss', [], True),
('logits', ['batch', 2])]}
lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total))
loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
@ -205,7 +198,7 @@ class ORTTransformerTrainer:
]
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
self.model = orttrainer.ORTTrainer(self.model, model_desc, optim_config, options=options)
self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
else:
def map_optimizer_attributes(name):
no_decay = "bias" in name or "LayerNorm.weight" in name
@ -222,7 +215,6 @@ class ORTTransformerTrainer:
learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
device=self.args.device,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
world_rank=0, world_size=1, # only support single GPU cases
use_mixed_precision=self.args.fp16,
allreduce_post_accumulation=True,
get_lr_this_step=get_lr_this_step,

View file

@ -4,7 +4,7 @@
#include "nccl_common.h"
#include <mpi.h>
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
namespace onnxruntime {
namespace cuda {

View file

@ -9,7 +9,7 @@
#include "core/providers/cuda/cuda_common.h"
#include <mpi.h>
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
namespace onnxruntime {
namespace cuda {

View file

@ -10,7 +10,7 @@
#include <limits>
#include <mpi.h>
#include "orttraining/core/framework/mpi_setup.h"
#include "orttraining/core/framework/mpi_context.h"
namespace onnxruntime {
namespace cuda {

View file

@ -1101,6 +1101,12 @@ def run_training_python_frontend_e2e_tests(cwd):
# frontend tests are to be added here:
log.info("Running python frontend e2e tests.")
import torch
ngpus = torch.cuda.device_count()
if ngpus > 1:
log.debug('RUN: mpirun -n {} {} orttraining_run_glue.py'.format(ngpus, sys.executable))
run_subprocess(['mpirun', '-n', str(ngpus), sys.executable, 'orttraining_run_glue.py'], cwd=cwd)
# with orttraining_run_glue.py.
# 1. we like to force to use single GPU (with CUDA_VISIBLE_DEVICES)
# for fine-tune tests.

View file

@ -114,6 +114,17 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
${PYTHON_EXE} -m pip install transformers==v2.10.0
# transformers requires sklearn
${PYTHON_EXE} -m pip install sklearn
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
echo "install openmpi"
curl -fsSL https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz -O
tar zxf openmpi-4.0.0.tar.gz
cd openmpi-4.0.0
./configure --enable-orterun-prefix-by-default
make all
make install
ldconfig
fi
fi
fi

View file

@ -47,7 +47,7 @@ if [ "$OS_VERSION" = "16.04" ]; then
libicu55 \
libtinfo-dev \
libtool \
mpich libmpich-dev \
openssh-server \
aria2 \
bzip2 \
unzip \
@ -80,7 +80,7 @@ else # ubuntu18.04
libicu60 \
libtinfo-dev \
libtool \
mpich libmpich-dev \
openssh-server \
aria2 \
bzip2 \
unzip \