mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Glue parallel training (#4550)
add mpi size, rank python API add single node parallel training example
This commit is contained in:
parent
9a6db9b9f4
commit
6260d073b3
19 changed files with 281 additions and 150 deletions
|
|
@ -1,14 +1,36 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "mpi_setup.h"
|
||||
#include "mpi_context.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
MPIContext::MPIContext(int w_rank, int l_rank, int w_size, int l_size) : world_rank(w_rank), local_rank(l_rank), world_size(w_size), local_size(l_size) {}
|
||||
MPIContext::MPIContext() :
|
||||
world_rank_(0),
|
||||
local_rank_(0),
|
||||
world_size_(1),
|
||||
local_size_(1)
|
||||
{
|
||||
#if defined(USE_NCCL)
|
||||
setup_mpi();
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
MPIContext setup_mpi() {
|
||||
MPIContext::~MPIContext() {
|
||||
#if defined(USE_NCCL)
|
||||
#ifndef _WIN32
|
||||
shutdown_mpi();
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
const MPIContext& MPIContext::GetInstance() {
|
||||
static MPIContext context;
|
||||
return context;
|
||||
}
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
void MPIContext::setup_mpi() {
|
||||
// setup MPI amd horovod
|
||||
int is_mpi_initialized = 0;
|
||||
MPI_Initialized(&is_mpi_initialized);
|
||||
|
|
@ -27,11 +49,6 @@ MPIContext setup_mpi() {
|
|||
|
||||
MPI_Allgather(&world_rank, 1, MPI_INT, ranks, 1, MPI_INT, MPI_COMM_WORLD);
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
using namespace horovod::common;
|
||||
horovod_init(ranks, world_size);
|
||||
#endif
|
||||
|
||||
//Get local rank and size
|
||||
int local_rank;
|
||||
int local_size;
|
||||
|
|
@ -49,14 +66,13 @@ MPIContext setup_mpi() {
|
|||
|
||||
printf("Using cuda local_rank: %d, world_rank: %d, world_size: %d, local_size: %d\n(version: %s)\n",
|
||||
local_rank, world_rank, world_size, local_size, version);
|
||||
return MPIContext(world_rank, local_rank, world_size, local_size);
|
||||
this->world_rank_ = world_rank;
|
||||
this->local_rank_ = local_rank;
|
||||
this->world_size_ = world_size;
|
||||
this->local_size_ = local_size;
|
||||
}
|
||||
|
||||
void shutdown_mpi() {
|
||||
#ifdef USE_HOROVOD
|
||||
horovod::common::horovod_shutdown();
|
||||
#endif
|
||||
|
||||
void MPIContext::shutdown_mpi() {
|
||||
int is_mpi_initialized = 0;
|
||||
MPI_Initialized(&is_mpi_initialized);
|
||||
if (!is_mpi_initialized)
|
||||
60
orttraining/orttraining/core/framework/mpi_context.h
Normal file
60
orttraining/orttraining/core/framework/mpi_context.h
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
#pragma once
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
#include <mpi.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
#define MPI_CHECK(condition) \
|
||||
do { \
|
||||
int error = (condition); \
|
||||
ORT_ENFORCE( \
|
||||
error == MPI_SUCCESS, \
|
||||
"MPI Error at: ", \
|
||||
__FILE__, \
|
||||
":", \
|
||||
__LINE__, \
|
||||
": ", \
|
||||
error); \
|
||||
} while (0)
|
||||
|
||||
class MPIContext {
|
||||
// https://stackoverflow.com/questions/1008019/c-singleton-design-pattern
|
||||
public:
|
||||
static const MPIContext& GetInstance();
|
||||
|
||||
MPIContext(MPIContext const&) = delete;
|
||||
void operator=(MPIContext const&) = delete;
|
||||
|
||||
// within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi().
|
||||
~MPIContext();
|
||||
|
||||
int GetWorldRank() const { return world_rank_; }
|
||||
int GetLocalRank() const { return local_rank_; }
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
int GetLocalSize() const { return local_size_; }
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction
|
||||
// shutdown_mpi shall be called specifically in user code.
|
||||
static void shutdown_mpi();
|
||||
#endif
|
||||
|
||||
private:
|
||||
MPIContext();
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
void setup_mpi();
|
||||
#endif
|
||||
int world_rank_;
|
||||
int local_rank_;
|
||||
int world_size_;
|
||||
int local_size_;
|
||||
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
#include <mpi.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_HOROVOD
|
||||
#include "orttraining/core/graph/horovod_adapters.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
#define MPI_CHECK(condition) \
|
||||
do { \
|
||||
int error = (condition); \
|
||||
ORT_ENFORCE( \
|
||||
error == MPI_SUCCESS, \
|
||||
"MPI Error at: ", \
|
||||
__FILE__, \
|
||||
":", \
|
||||
__LINE__, \
|
||||
": ", \
|
||||
error); \
|
||||
} while (0)
|
||||
|
||||
struct MPIContext {
|
||||
MPIContext(int world_rank = 0, int local_rank = 0, int world_size = 1, int local_size = 1);
|
||||
int world_rank;
|
||||
int local_rank;
|
||||
int world_size;
|
||||
int local_size;
|
||||
};
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
MPIContext setup_mpi();
|
||||
void shutdown_mpi();
|
||||
#endif
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
#include "core/providers/cuda/cuda_allocator.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/framework/tensorboard/event_writer.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "orttraining/models/runner/constant.h"
|
||||
#include "orttraining/models/runner/training_runner.h"
|
||||
#include "orttraining/models/runner/training_util.h"
|
||||
|
|
@ -529,22 +529,20 @@ void setup_training_params(BertParameters& params) {
|
|||
params.model_actual_running_graph_path = model_name_base + ORT_TSTR("_bw_running.onnx");
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
params.mpi_context = setup_mpi();
|
||||
|
||||
if (params.pipeline_parallel_size > 1) {
|
||||
auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(params.mpi_context.world_rank));
|
||||
auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(MPIContext::GetInstance().GetWorldRank()));
|
||||
params.model_with_loss_func_path = pipeline_model_name_base + ORT_TSTR("_with_cost.onnx");
|
||||
params.model_with_training_graph_path = pipeline_model_name_base + ORT_TSTR("_bw.onnx");
|
||||
params.model_actual_running_graph_path = pipeline_model_name_base + ORT_TSTR("_bw_running.onnx");
|
||||
}
|
||||
ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size);
|
||||
ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size);
|
||||
if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) {
|
||||
ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize());
|
||||
ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize());
|
||||
if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) {
|
||||
LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible";
|
||||
return;
|
||||
}
|
||||
|
||||
auto data_group_size = params.mpi_context.world_size / (params.horizontal_parallel_size * params.pipeline_parallel_size);
|
||||
auto data_group_size = MPIContext::GetInstance().GetWorldSize() / (params.horizontal_parallel_size * params.pipeline_parallel_size);
|
||||
ORT_ENFORCE(data_group_size > 0, "Insufficient processes lead to zero-way data parallelism, which should be at least one-way.");
|
||||
if (data_group_size != params.data_parallel_size) {
|
||||
LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to "
|
||||
|
|
@ -558,7 +556,7 @@ void setup_training_params(BertParameters& params) {
|
|||
#endif
|
||||
|
||||
#ifdef USE_CUDA
|
||||
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(params.mpi_context.local_rank);
|
||||
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(MPIContext::GetInstance().GetLocalRank());
|
||||
size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
|
||||
if (params.cuda_mem_limit_in_gb > 0)
|
||||
cuda_mem_limit = static_cast<size_t>(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024);
|
||||
|
|
@ -628,7 +626,7 @@ void setup_training_params(BertParameters& params) {
|
|||
|
||||
if (params.dump_fetches) {
|
||||
std::ostringstream filename;
|
||||
filename << "./fetch_dumps/rank_" << params.mpi_context.world_rank << "_step_" << step << ".txt";
|
||||
filename << "./fetch_dumps/rank_" << MPIContext::GetInstance().GetWorldRank() << "_step_" << step << ".txt";
|
||||
ofstream ofs(filename.str());
|
||||
for (size_t i = 0; i < fetch_names.size(); ++i) {
|
||||
TrainingUtil::PrintTensor(fetch_names[i], fetches[i].Get<Tensor>(), ofs);
|
||||
|
|
@ -733,7 +731,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
|
|||
BertParameters params_for_phase;
|
||||
while (GetParametersForPhase(runner->GetRound(), params, params_for_phase)) {
|
||||
ORT_RETURN_IF_ERROR(runner->UpdateParams(params_for_phase));
|
||||
auto rank_in_data_parallel_group = params_for_phase.mpi_context.world_rank / params_for_phase.horizontal_parallel_size;
|
||||
auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params_for_phase.horizontal_parallel_size;
|
||||
auto training_data_loader = onnxruntime::make_unique<DataLoader>(params_for_phase.input_name_map,
|
||||
params_for_phase.train_data_dir,
|
||||
max_num_files_preload,
|
||||
|
|
@ -742,7 +740,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
|
|||
|
||||
auto test_data_loader = std::unique_ptr<DataLoader>{};
|
||||
// Evaluation is only done in device #0
|
||||
if (params_for_phase.mpi_context.world_rank == 0) {
|
||||
if (MPIContext::GetInstance().GetWorldRank() == 0) {
|
||||
test_data_loader = onnxruntime::make_unique<DataLoader>(params_for_phase.input_name_map,
|
||||
params_for_phase.test_data_dir,
|
||||
max_num_files_preload);
|
||||
|
|
@ -759,7 +757,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env)
|
|||
ORT_RETURN_IF_ERROR(runner->ResetLossScaler());
|
||||
}
|
||||
|
||||
if (params_for_phase.mpi_context.world_rank == 0) {
|
||||
if (MPIContext::GetInstance().GetWorldRank() == 0) {
|
||||
// Pass in empty dataloader to disable evaluation in EndTraining
|
||||
// to avoid a redundant synchronization caused by Tensorboard's SummaryMerge Op.
|
||||
ORT_RETURN_IF_ERROR(runner->EndTraining(nullptr));
|
||||
|
|
@ -807,9 +805,13 @@ int main(int argc, char* argv[]) {
|
|||
RETURN_IF_FAIL(RunTraining(params, *env));
|
||||
}
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
shutdown_mpi();
|
||||
#if defined(USE_NCCL)
|
||||
#ifdef _WIN32
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
|
||||
// call shutdown_mpi() here instead.
|
||||
MPIContext::shutdown_mpi();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "core/session/environment.h"
|
||||
#include "core/framework/random_seed.h"
|
||||
#include "core/providers/cuda/cuda_allocator.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "orttraining/core/framework/tensorboard/event_writer.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/models/runner/constant.h"
|
||||
|
|
@ -294,15 +294,14 @@ void setup_training_params(GPT2Parameters& params) {
|
|||
/*label_name*/ "labels"});
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
params.mpi_context = setup_mpi();
|
||||
ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size);
|
||||
ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size);
|
||||
if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) {
|
||||
ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize());
|
||||
ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize());
|
||||
if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) {
|
||||
LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible";
|
||||
return;
|
||||
}
|
||||
|
||||
auto data_group_size = params.mpi_context.world_size / params.horizontal_parallel_size;
|
||||
auto data_group_size = MPIContext::GetInstance().GetWorldSize() / params.horizontal_parallel_size;
|
||||
if (data_group_size != params.data_parallel_size) {
|
||||
LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to "
|
||||
<< data_group_size << std::endl;
|
||||
|
|
@ -341,7 +340,7 @@ void setup_training_params(GPT2Parameters& params) {
|
|||
params.model_type = "gpt2";
|
||||
|
||||
#ifdef USE_CUDA
|
||||
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(params.mpi_context.local_rank);
|
||||
OrtDevice::DeviceId device_id = static_cast<OrtDevice::DeviceId>(MPIContext::GetInstance().GetLocalRank());
|
||||
params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id));
|
||||
params.input_allocator = std::make_shared<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
|
||||
#endif
|
||||
|
|
@ -417,7 +416,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
|
|||
auto runner = onnxruntime::make_unique<TrainingRunner>(params, env);
|
||||
ORT_RETURN_IF_ERROR(runner->Initialize());
|
||||
|
||||
auto rank_in_data_parallel_group = params.mpi_context.world_rank / params.horizontal_parallel_size;
|
||||
auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params.horizontal_parallel_size;
|
||||
auto training_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
|
||||
params.train_data_dir,
|
||||
max_num_files_preload,
|
||||
|
|
@ -426,7 +425,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
|
|||
|
||||
std::unique_ptr<DataLoader> test_data_loader;
|
||||
// Evaluation is only done in device #0
|
||||
if (params.mpi_context.world_rank == 0) {
|
||||
if (MPIContext::GetInstance().GetWorldRank() == 0) {
|
||||
test_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
|
||||
params.test_data_dir,
|
||||
max_num_files_preload);
|
||||
|
|
@ -441,7 +440,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env)
|
|||
ORT_RETURN_IF_ERROR(runner->Run(training_data_loader.get(), test_data_loader.get(), mapped_dimensions));
|
||||
|
||||
// only test and save trained model on device #0
|
||||
if (params.mpi_context.world_rank == 0) {
|
||||
if (MPIContext::GetInstance().GetWorldRank() == 0) {
|
||||
test_data_loader = onnxruntime::make_unique<DataLoader>(params.input_name_map,
|
||||
params.test_data_dir,
|
||||
max_num_files_preload);
|
||||
|
|
@ -478,9 +477,13 @@ int main(int argc, char* argv[]) {
|
|||
RETURN_IF_FAIL(RunTraining(params, *env));
|
||||
}
|
||||
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
shutdown_mpi();
|
||||
#if defined(USE_NCCL)
|
||||
#ifdef _WIN32
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
|
||||
// call shutdown_mpi() here instead.
|
||||
MPIContext::shutdown_mpi();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ void setup_training_params(TrainingRunner::Parameters& params) {
|
|||
};
|
||||
|
||||
std::shared_ptr<EventWriter> tensorboard;
|
||||
if (!params.log_dir.empty() && params.mpi_context.world_rank == 0)
|
||||
if (!params.log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0)
|
||||
tensorboard = std::make_shared<EventWriter>(params.log_dir);
|
||||
|
||||
params.post_evaluation_callback = [tensorboard](size_t num_samples, size_t step, const std::string /**/) {
|
||||
|
|
@ -183,12 +183,12 @@ int main(int argc, char* args[]) {
|
|||
setup_training_params(params);
|
||||
|
||||
// setup data
|
||||
auto device_count = params.mpi_context.world_size;
|
||||
auto device_count = MPIContext::GetInstance().GetWorldSize();
|
||||
std::vector<string> feeds{"X", "labels"};
|
||||
auto trainingData = std::make_shared<DataSet>(feeds);
|
||||
auto testData = std::make_shared<DataSet>(feeds);
|
||||
std::string mnist_data_path = ToMBString(params.train_data_dir);
|
||||
PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, params.mpi_context.world_rank /* shard_to_load */, device_count /* total_shards */);
|
||||
PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, MPIContext::GetInstance().GetWorldRank() /* shard_to_load */, device_count /* total_shards */);
|
||||
|
||||
if (testData->NumSamples() == 0) {
|
||||
printf("Warning: No data loaded - run cancelled.\n");
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include "core/common/logging/sinks/clog_sink.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/framework/tensorboard/event_writer.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "orttraining/models/runner/constant.h"
|
||||
#include "orttraining/models/runner/training_runner.h"
|
||||
#include "orttraining/models/runner/training_util.h"
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ Status TrainingRunner::Initialize() {
|
|||
if (params_.pipeline_parallel_size > 1 && !params_.pipeline_stage_paths.empty()) {
|
||||
// Pipeline partition happens outside ORT. We just load the result of partitioning forward graph.
|
||||
// Backward graph will be generated using ORT's graph transformers.
|
||||
ORT_ENFORCE(static_cast<size_t>(params_.mpi_context.world_size) == params_.pipeline_stage_paths.size());
|
||||
ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[params_.mpi_context.world_rank]));
|
||||
ORT_ENFORCE(static_cast<size_t>(MPIContext::GetInstance().GetWorldSize()) == params_.pipeline_stage_paths.size());
|
||||
ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[MPIContext::GetInstance().GetWorldRank()]));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(session_.Load(params_.model_path));
|
||||
}
|
||||
|
|
@ -98,10 +98,10 @@ Status TrainingRunner::Initialize() {
|
|||
|
||||
config.gradient_accumulation_steps = params_.gradient_accumulation_steps;
|
||||
|
||||
config.distributed_config.world_rank = params_.mpi_context.world_rank;
|
||||
config.distributed_config.world_size = params_.mpi_context.world_size;
|
||||
config.distributed_config.local_size = params_.mpi_context.local_size;
|
||||
config.distributed_config.local_rank = params_.mpi_context.local_rank;
|
||||
config.distributed_config.world_rank = MPIContext::GetInstance().GetWorldRank();
|
||||
config.distributed_config.world_size = MPIContext::GetInstance().GetWorldSize();
|
||||
config.distributed_config.local_size = MPIContext::GetInstance().GetLocalSize();
|
||||
config.distributed_config.local_rank = MPIContext::GetInstance().GetLocalRank();
|
||||
config.distributed_config.data_parallel_size = params_.data_parallel_size;
|
||||
config.distributed_config.horizontal_parallel_size = params_.horizontal_parallel_size;
|
||||
config.distributed_config.pipeline_parallel_size = params_.pipeline_parallel_size;
|
||||
|
|
@ -115,7 +115,7 @@ Status TrainingRunner::Initialize() {
|
|||
}
|
||||
|
||||
// always configure the loss function
|
||||
if (params_.pipeline_parallel_size == 1 || params_.mpi_context.world_rank == params_.mpi_context.world_size - 1) {
|
||||
if (params_.pipeline_parallel_size == 1 || MPIContext::GetInstance().GetWorldRank() == MPIContext::GetInstance().GetWorldSize() - 1) {
|
||||
TrainingSession::TrainingConfiguration::LossFunctionConfiguration lf{};
|
||||
lf.loss_function_info = params_.loss_func_info;
|
||||
|
||||
|
|
@ -337,7 +337,7 @@ Status TrainingRunner::Initialize() {
|
|||
|
||||
Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions) {
|
||||
if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) {
|
||||
if (MPIContext::GetInstance().GetWorldRank() == 0 && !params_.model_actual_running_graph_path.empty()) {
|
||||
session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD);
|
||||
}
|
||||
|
||||
|
|
@ -810,7 +810,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
|
|||
Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions) {
|
||||
const bool enable_checkpoint_saving =
|
||||
params_.mpi_context.world_rank == 0 &&
|
||||
MPIContext::GetInstance().GetWorldRank() == 0 &&
|
||||
checkpoint_registry_ && params_.checkpoint_period > 0;
|
||||
|
||||
std::unique_ptr<perftest::utils::ICPUUsage> cpu_usage_calculator;
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include "core/framework/ml_value.h"
|
||||
#include "core/providers/providers.h"
|
||||
#include "orttraining/core/framework/checkpoint_registry.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/models/runner/data_loader.h"
|
||||
|
|
@ -95,7 +95,6 @@ class TrainingRunner {
|
|||
bool use_gist = false;
|
||||
// Whether we collect execution profile trace during this run.
|
||||
bool use_profiler = false;
|
||||
MPIContext mpi_context;
|
||||
bool skip_evaluation = false;
|
||||
bool dump_fetches = false;
|
||||
bool dump_convergence_metrics = false;
|
||||
|
|
@ -120,7 +119,7 @@ class TrainingRunner {
|
|||
float cuda_mem_limit_in_gb = -1.0f;
|
||||
|
||||
bool EnableTensorboard() const {
|
||||
return !is_perf_test && !log_dir.empty() && mpi_context.world_rank == 0;
|
||||
return !is_perf_test && !log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0;
|
||||
}
|
||||
|
||||
bool UseCuda() const {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
#include "core/session/environment.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -19,6 +19,7 @@ namespace python {
|
|||
namespace py = pybind11;
|
||||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::logging;
|
||||
using namespace onnxruntime::training;
|
||||
|
||||
struct TrainingParameters {
|
||||
std::string loss_output_name;
|
||||
|
|
@ -65,25 +66,10 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
|
||||
auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
|
||||
if (data_group_size != parameters.data_parallel_size) {
|
||||
std::cout << "WARNING: data_parallel_size is not correct, tuned automatically to "
|
||||
<< data_group_size << std::endl;
|
||||
LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
|
||||
<< data_group_size;
|
||||
parameters.data_parallel_size = data_group_size;
|
||||
}
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
// this condition block is temporary.
|
||||
// For now, nccl allreduce kernel only implements for allreduce_post_accumulation
|
||||
// hovorod allreduce kernel only implements for not allreduce_post_accumulation.
|
||||
bool use_nccl = parameters.allreduce_post_accumulation;
|
||||
if (!use_nccl && parameters.world_size > 1) {
|
||||
auto mpi_context = training::setup_mpi();
|
||||
std::cout << "mpi_context.world_rank: " << mpi_context.world_rank << std::endl;
|
||||
std::cout << "mpi_context.local_rank: " << mpi_context.local_rank << std::endl;
|
||||
std::cout << "mpi_context.world_size: " << mpi_context.world_size << std::endl;
|
||||
std::cout << "mpi_context.local_size: " << mpi_context.local_size << std::endl;
|
||||
parameters.local_size = mpi_context.local_size;
|
||||
parameters.local_rank = mpi_context.local_rank;
|
||||
}
|
||||
#endif
|
||||
|
||||
training::TrainingSession::TrainingConfiguration config{};
|
||||
config.weight_names_to_train = parameters.weights_to_train;
|
||||
|
|
@ -158,6 +144,28 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
return python_config_result;
|
||||
}
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) {
|
||||
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank();
|
||||
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank();
|
||||
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize();
|
||||
LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize();
|
||||
|
||||
parameters.local_rank = MPIContext::GetInstance().GetLocalRank();
|
||||
parameters.local_size = MPIContext::GetInstance().GetLocalSize();
|
||||
if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) {
|
||||
if (parameters.world_rank != 0)
|
||||
LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank();
|
||||
parameters.world_rank = MPIContext::GetInstance().GetWorldRank();
|
||||
}
|
||||
if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) {
|
||||
if (parameters.world_size != 1)
|
||||
LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize();
|
||||
parameters.world_size = MPIContext::GetInstance().GetWorldSize();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void addObjectMethodsForTraining(py::module& m) {
|
||||
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
|
||||
parameters.def(py::init())
|
||||
|
|
@ -181,6 +189,13 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
|
||||
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad);
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
|
||||
m.def("get_mpi_context_local_size", []() -> int { return MPIContext::GetInstance().GetLocalSize(); });
|
||||
m.def("get_mpi_context_world_rank", []() -> int { return MPIContext::GetInstance().GetWorldRank(); });
|
||||
m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); });
|
||||
#endif
|
||||
|
||||
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
|
||||
config_result.def(py::init())
|
||||
.def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object {
|
||||
|
|
@ -200,13 +215,21 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
|
||||
}))
|
||||
.def("finalize", [](py::object) {
|
||||
#if defined(USE_NCCL) || defined(USE_HOROVOD)
|
||||
training::shutdown_mpi();
|
||||
#if defined(USE_NCCL)
|
||||
#ifdef _WIN32
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices
|
||||
// shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction
|
||||
// call shutdown_mpi() here instead.
|
||||
MPIContext::shutdown_mpi();
|
||||
#endif
|
||||
#endif
|
||||
})
|
||||
.def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) {
|
||||
OrtPybindThrowIfError(sess->Load(path));
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
|
||||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(sess, parameters);
|
||||
|
||||
std::vector<std::string> provider_types = {};
|
||||
|
|
@ -218,6 +241,9 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
std::istringstream buffer(serialized_model);
|
||||
OrtPybindThrowIfError(sess->Load(buffer));
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
CopyMPIContextToTrainingParameters(parameters, sess->GetLogger());
|
||||
#endif
|
||||
const auto config_result = ConfigureSessionForTraining(sess, parameters);
|
||||
|
||||
std::vector<std::string> provider_types = {};
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from transformers import (
|
|||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size, get_mpi_context_world_rank, get_mpi_context_world_size
|
||||
|
||||
from orttraining_transformer_trainer import ORTTransformerTrainer
|
||||
|
||||
|
|
@ -95,13 +96,27 @@ class ORTGlueTest(unittest.TestCase):
|
|||
assert_allclose(results['loss'], expected_loss, rtol=self.rtol)
|
||||
|
||||
def test_bert_with_mrpc(self):
|
||||
expected_acc = 0.8553921568627451
|
||||
expected_f1 = 0.8970331588132635
|
||||
expected_acc_and_f1 = 0.8762126578380043
|
||||
expected_loss = 0.42737212419217707
|
||||
if self.local_rank == -1:
|
||||
expected_acc = 0.8553921568627451
|
||||
expected_f1 = 0.8970331588132635
|
||||
expected_acc_and_f1 = 0.8762126578380043
|
||||
expected_loss = 0.42737212419217707
|
||||
elif self.local_rank == 0:
|
||||
expected_acc = 0.8308823529411765
|
||||
expected_f1 = 0.881646655231561
|
||||
expected_acc_and_f1 = 0.8562645040863688
|
||||
expected_loss = 0.42491564023144107
|
||||
|
||||
for use_new_api in [True, False]:
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api)
|
||||
if self.local_rank == -1:
|
||||
# not parallel case, we can run both new and old api tests
|
||||
for use_new_api in [True, False]:
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api)
|
||||
else:
|
||||
# with parallel training, TrainingArguments can only be created once (due to its cached _setup_devices)
|
||||
# thus we can only choose one test case to run.
|
||||
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=True)
|
||||
|
||||
if self.local_rank in [-1, 0]:
|
||||
assert_allclose(results['acc'], expected_acc, rtol=self.rtol)
|
||||
assert_allclose(results['f1'], expected_f1, rtol=self.rtol)
|
||||
assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol)
|
||||
|
|
@ -122,6 +137,14 @@ class ORTGlueTest(unittest.TestCase):
|
|||
|
||||
def model_to_desc(self, model_name, model):
|
||||
if model_name.startswith('bert') or model_name.startswith('xlnet'):
|
||||
new_model_desc = {
|
||||
'inputs': [
|
||||
('input_ids', ['batch', 'max_seq_len_in_batch'],),
|
||||
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
|
||||
('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
|
||||
('labels', ['batch', ],)],
|
||||
'outputs': [('loss', [], True),
|
||||
('logits', ['batch', 2])]}
|
||||
model_desc = ModelDescription([
|
||||
IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size),
|
||||
IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
|
||||
|
|
@ -130,6 +153,13 @@ class ORTGlueTest(unittest.TestCase):
|
|||
IODescription('loss', [], torch.float32),
|
||||
IODescription('logits', ['batch', 2], torch.float32)])
|
||||
elif model_name.startswith('roberta'):
|
||||
new_model_desc = {
|
||||
'inputs': [
|
||||
('input_ids', ['batch', 'max_seq_len_in_batch'],),
|
||||
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
|
||||
('labels', ['batch', ],)],
|
||||
'outputs': [('loss', [], True),
|
||||
('logits', ['batch', 2])]}
|
||||
model_desc = ModelDescription([
|
||||
IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size),
|
||||
IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
|
||||
|
|
@ -139,16 +169,19 @@ class ORTGlueTest(unittest.TestCase):
|
|||
else:
|
||||
raise RuntimeError("unsupported base model name {}.".format(model_name))
|
||||
|
||||
return model_desc
|
||||
return model_desc, new_model_desc
|
||||
|
||||
def run_glue(self, model_name, task_name, fp16, use_new_api):
|
||||
model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
|
||||
data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=self.data_dir + "/" + task_name,
|
||||
data_args = GlueDataTrainingArguments(
|
||||
task_name=task_name, data_dir=os.path.join(self.data_dir, task_name),
|
||||
max_seq_length=self.max_seq_length)
|
||||
|
||||
training_args = TrainingArguments(output_dir=self.output_dir + "/" + task_name, do_train=True, do_eval=True,
|
||||
training_args = TrainingArguments(
|
||||
output_dir=os.path.join(self.output_dir, task_name), do_train=True, do_eval=True,
|
||||
per_gpu_train_batch_size=self.train_batch_size,
|
||||
learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,local_rank=self.local_rank,
|
||||
learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,
|
||||
local_rank=self.local_rank,
|
||||
overwrite_output_dir=self.overwrite_output_dir, gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
fp16=fp16, logging_steps=self.logging_steps)
|
||||
|
||||
|
|
@ -214,11 +247,12 @@ class ORTGlueTest(unittest.TestCase):
|
|||
preds = np.squeeze(p.predictions)
|
||||
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
|
||||
|
||||
model_desc = self.model_to_desc(model_name, model)
|
||||
model_desc, new_model_desc = self.model_to_desc(model_name, model)
|
||||
# Initialize the ORTTrainer within ORTTransformerTrainer
|
||||
trainer = ORTTransformerTrainer(
|
||||
model=model,
|
||||
model_desc=model_desc,
|
||||
new_model_desc=new_model_desc,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
|
@ -247,4 +281,26 @@ class ORTGlueTest(unittest.TestCase):
|
|||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
if get_mpi_context_world_size() > 1:
|
||||
# mpi launch
|
||||
|
||||
print("mpirun launch")
|
||||
# TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl")
|
||||
# pytorch expects following environment settings (which would be set if launched with torch.distributed.launch).
|
||||
|
||||
local_rank = get_mpi_context_local_rank()
|
||||
print("get_mpi_context_local_rank(): ", local_rank)
|
||||
os.environ['RANK'] = str(local_rank)
|
||||
os.environ['WORLD_SIZE'] = str(get_mpi_context_world_size())
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '29500'
|
||||
|
||||
from onnxruntime.capi._pybind_state import set_cuda_device_id
|
||||
set_cuda_device_id(local_rank)
|
||||
|
||||
test = ORTGlueTest()
|
||||
test.setUp()
|
||||
test.local_rank = local_rank
|
||||
test.test_bert_with_mrpc()
|
||||
else:
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -214,6 +214,7 @@ class ORTMultipleChoiceTest(unittest.TestCase):
|
|||
trainer = ORTTransformerTrainer(
|
||||
model=model,
|
||||
model_desc=model_desc,
|
||||
new_model_desc=None,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
# from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
|
|
@ -101,6 +101,7 @@ class ORTTransformerTrainer:
|
|||
self,
|
||||
model: PreTrainedModel,
|
||||
model_desc: ModelDescription,
|
||||
new_model_desc: dict,
|
||||
args: TrainingArguments,
|
||||
train_dataset: Dataset,
|
||||
eval_dataset: Dataset,
|
||||
|
|
@ -112,6 +113,7 @@ class ORTTransformerTrainer:
|
|||
|
||||
self.model = model
|
||||
self.model_desc = model_desc
|
||||
self.new_model_desc = new_model_desc
|
||||
self.args = args
|
||||
self.data_collator = DefaultDataCollator()
|
||||
self.train_dataset = train_dataset
|
||||
|
|
@ -171,15 +173,6 @@ class ORTTransformerTrainer:
|
|||
num_train_epochs = self.args.num_train_epochs
|
||||
|
||||
if self.use_new_api:
|
||||
model_desc = {
|
||||
'inputs': [
|
||||
('input_ids', ['batch', 'max_seq_len_in_batch'],),
|
||||
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
|
||||
('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
|
||||
('labels', ['batch', ],)],
|
||||
'outputs': [('loss', [], True),
|
||||
('logits', ['batch', 2])]}
|
||||
|
||||
lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total))
|
||||
|
||||
loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None
|
||||
|
|
@ -205,7 +198,7 @@ class ORTTransformerTrainer:
|
|||
]
|
||||
|
||||
optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
|
||||
self.model = orttrainer.ORTTrainer(self.model, model_desc, optim_config, options=options)
|
||||
self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
|
||||
else:
|
||||
def map_optimizer_attributes(name):
|
||||
no_decay = "bias" in name or "LayerNorm.weight" in name
|
||||
|
|
@ -222,7 +215,6 @@ class ORTTransformerTrainer:
|
|||
learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
|
||||
device=self.args.device,
|
||||
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
|
||||
world_rank=0, world_size=1, # only support single GPU cases
|
||||
use_mixed_precision=self.args.fp16,
|
||||
allreduce_post_accumulation=True,
|
||||
get_lr_this_step=get_lr_this_step,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#include "nccl_common.h"
|
||||
#include <mpi.h>
|
||||
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include <mpi.h>
|
||||
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include <limits>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "orttraining/core/framework/mpi_context.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
|
|||
|
|
@ -1101,6 +1101,12 @@ def run_training_python_frontend_e2e_tests(cwd):
|
|||
# frontend tests are to be added here:
|
||||
log.info("Running python frontend e2e tests.")
|
||||
|
||||
import torch
|
||||
ngpus = torch.cuda.device_count()
|
||||
if ngpus > 1:
|
||||
log.debug('RUN: mpirun -n {} {} orttraining_run_glue.py'.format(ngpus, sys.executable))
|
||||
run_subprocess(['mpirun', '-n', str(ngpus), sys.executable, 'orttraining_run_glue.py'], cwd=cwd)
|
||||
|
||||
# with orttraining_run_glue.py.
|
||||
# 1. we like to force to use single GPU (with CUDA_VISIBLE_DEVICES)
|
||||
# for fine-tune tests.
|
||||
|
|
|
|||
|
|
@ -114,6 +114,17 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
|
|||
${PYTHON_EXE} -m pip install transformers==v2.10.0
|
||||
# transformers requires sklearn
|
||||
${PYTHON_EXE} -m pip install sklearn
|
||||
|
||||
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
|
||||
echo "install openmpi"
|
||||
curl -fsSL https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz -O
|
||||
tar zxf openmpi-4.0.0.tar.gz
|
||||
cd openmpi-4.0.0
|
||||
./configure --enable-orterun-prefix-by-default
|
||||
make all
|
||||
make install
|
||||
ldconfig
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ if [ "$OS_VERSION" = "16.04" ]; then
|
|||
libicu55 \
|
||||
libtinfo-dev \
|
||||
libtool \
|
||||
mpich libmpich-dev \
|
||||
openssh-server \
|
||||
aria2 \
|
||||
bzip2 \
|
||||
unzip \
|
||||
|
|
@ -80,7 +80,7 @@ else # ubuntu18.04
|
|||
libicu60 \
|
||||
libtinfo-dev \
|
||||
libtool \
|
||||
mpich libmpich-dev \
|
||||
openssh-server \
|
||||
aria2 \
|
||||
bzip2 \
|
||||
unzip \
|
||||
|
|
|
|||
Loading…
Reference in a new issue