diff --git a/orttraining/orttraining/core/framework/mpi_setup.cc b/orttraining/orttraining/core/framework/mpi_context.cc similarity index 72% rename from orttraining/orttraining/core/framework/mpi_setup.cc rename to orttraining/orttraining/core/framework/mpi_context.cc index bf69bb12d9..4990f11812 100644 --- a/orttraining/orttraining/core/framework/mpi_setup.cc +++ b/orttraining/orttraining/core/framework/mpi_context.cc @@ -1,14 +1,36 @@ #include #include -#include "mpi_setup.h" +#include "mpi_context.h" namespace onnxruntime { namespace training { -MPIContext::MPIContext(int w_rank, int l_rank, int w_size, int l_size) : world_rank(w_rank), local_rank(l_rank), world_size(w_size), local_size(l_size) {} +MPIContext::MPIContext() : +world_rank_(0), +local_rank_(0), +world_size_(1), +local_size_(1) +{ +#if defined(USE_NCCL) + setup_mpi(); +#endif +} -#if defined(USE_NCCL) || defined(USE_HOROVOD) -MPIContext setup_mpi() { +MPIContext::~MPIContext() { +#if defined(USE_NCCL) +#ifndef _WIN32 + shutdown_mpi(); +#endif +#endif +} + +const MPIContext& MPIContext::GetInstance() { + static MPIContext context; + return context; +} + +#if defined(USE_NCCL) +void MPIContext::setup_mpi() { // setup MPI amd horovod int is_mpi_initialized = 0; MPI_Initialized(&is_mpi_initialized); @@ -27,11 +49,6 @@ MPIContext setup_mpi() { MPI_Allgather(&world_rank, 1, MPI_INT, ranks, 1, MPI_INT, MPI_COMM_WORLD); -#ifdef USE_HOROVOD - using namespace horovod::common; - horovod_init(ranks, world_size); -#endif - //Get local rank and size int local_rank; int local_size; @@ -49,14 +66,13 @@ MPIContext setup_mpi() { printf("Using cuda local_rank: %d, world_rank: %d, world_size: %d, local_size: %d\n(version: %s)\n", local_rank, world_rank, world_size, local_size, version); - return MPIContext(world_rank, local_rank, world_size, local_size); + this->world_rank_ = world_rank; + this->local_rank_ = local_rank; + this->world_size_ = world_size; + this->local_size_ = local_size; } -void shutdown_mpi() { -#ifdef USE_HOROVOD - horovod::common::horovod_shutdown(); -#endif - +void MPIContext::shutdown_mpi() { int is_mpi_initialized = 0; MPI_Initialized(&is_mpi_initialized); if (!is_mpi_initialized) diff --git a/orttraining/orttraining/core/framework/mpi_context.h b/orttraining/orttraining/core/framework/mpi_context.h new file mode 100644 index 0000000000..6caa9f274e --- /dev/null +++ b/orttraining/orttraining/core/framework/mpi_context.h @@ -0,0 +1,60 @@ +#pragma once + +#if defined(USE_NCCL) +#include +#endif + +namespace onnxruntime { +namespace training { + +#define MPI_CHECK(condition) \ + do { \ + int error = (condition); \ + ORT_ENFORCE( \ + error == MPI_SUCCESS, \ + "MPI Error at: ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ": ", \ + error); \ + } while (0) + +class MPIContext { + // https://stackoverflow.com/questions/1008019/c-singleton-design-pattern + public: + static const MPIContext& GetInstance(); + + MPIContext(MPIContext const&) = delete; + void operator=(MPIContext const&) = delete; + + // within ~MPIContext() we need to check for _WIN32 before calling shutdown_mpi(). + ~MPIContext(); + + int GetWorldRank() const { return world_rank_; } + int GetLocalRank() const { return local_rank_; } + int GetWorldSize() const { return world_size_; } + int GetLocalSize() const { return local_size_; } + +#if defined(USE_NCCL) + // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices + // in case of _WIN32 we cannot call shutdown_mpi() in MPIContext destructor because of DllMain's restriction + // shutdown_mpi shall be called specifically in user code. + static void shutdown_mpi(); +#endif + + private: + MPIContext(); + +#if defined(USE_NCCL) + void setup_mpi(); +#endif + int world_rank_; + int local_rank_; + int world_size_; + int local_size_; + +}; + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/framework/mpi_setup.h b/orttraining/orttraining/core/framework/mpi_setup.h deleted file mode 100644 index 1d9f6515d6..0000000000 --- a/orttraining/orttraining/core/framework/mpi_setup.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#if defined(USE_NCCL) || defined(USE_HOROVOD) -#include -#endif - -#ifdef USE_HOROVOD -#include "orttraining/core/graph/horovod_adapters.h" -#endif - -namespace onnxruntime { -namespace training { - -#define MPI_CHECK(condition) \ - do { \ - int error = (condition); \ - ORT_ENFORCE( \ - error == MPI_SUCCESS, \ - "MPI Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - error); \ - } while (0) - -struct MPIContext { - MPIContext(int world_rank = 0, int local_rank = 0, int world_size = 1, int local_size = 1); - int world_rank; - int local_rank; - int world_size; - int local_size; -}; - -#if defined(USE_NCCL) || defined(USE_HOROVOD) -MPIContext setup_mpi(); -void shutdown_mpi(); -#endif - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 816d1018de..5a6d942da6 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -12,7 +12,7 @@ #include "core/providers/cuda/cuda_allocator.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/tensorboard/event_writer.h" -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" #include "orttraining/models/runner/constant.h" #include "orttraining/models/runner/training_runner.h" #include "orttraining/models/runner/training_util.h" @@ -529,22 +529,20 @@ void setup_training_params(BertParameters& params) { params.model_actual_running_graph_path = model_name_base + ORT_TSTR("_bw_running.onnx"); #if defined(USE_NCCL) || defined(USE_HOROVOD) - params.mpi_context = setup_mpi(); - if (params.pipeline_parallel_size > 1) { - auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(params.mpi_context.world_rank)); + auto pipeline_model_name_base = model_name_base + ToPathString(std::to_string(MPIContext::GetInstance().GetWorldRank())); params.model_with_loss_func_path = pipeline_model_name_base + ORT_TSTR("_with_cost.onnx"); params.model_with_training_graph_path = pipeline_model_name_base + ORT_TSTR("_bw.onnx"); params.model_actual_running_graph_path = pipeline_model_name_base + ORT_TSTR("_bw_running.onnx"); } - ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size); - ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size); - if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) { + ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize()); + ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize()); + if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) { LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible"; return; } - auto data_group_size = params.mpi_context.world_size / (params.horizontal_parallel_size * params.pipeline_parallel_size); + auto data_group_size = MPIContext::GetInstance().GetWorldSize() / (params.horizontal_parallel_size * params.pipeline_parallel_size); ORT_ENFORCE(data_group_size > 0, "Insufficient processes lead to zero-way data parallelism, which should be at least one-way."); if (data_group_size != params.data_parallel_size) { LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to " @@ -558,7 +556,7 @@ void setup_training_params(BertParameters& params) { #endif #ifdef USE_CUDA - OrtDevice::DeviceId device_id = static_cast(params.mpi_context.local_rank); + OrtDevice::DeviceId device_id = static_cast(MPIContext::GetInstance().GetLocalRank()); size_t cuda_mem_limit = std::numeric_limits::max(); if (params.cuda_mem_limit_in_gb > 0) cuda_mem_limit = static_cast(params.cuda_mem_limit_in_gb * 1024 * 1024 * 1024); @@ -628,7 +626,7 @@ void setup_training_params(BertParameters& params) { if (params.dump_fetches) { std::ostringstream filename; - filename << "./fetch_dumps/rank_" << params.mpi_context.world_rank << "_step_" << step << ".txt"; + filename << "./fetch_dumps/rank_" << MPIContext::GetInstance().GetWorldRank() << "_step_" << step << ".txt"; ofstream ofs(filename.str()); for (size_t i = 0; i < fetch_names.size(); ++i) { TrainingUtil::PrintTensor(fetch_names[i], fetches[i].Get(), ofs); @@ -733,7 +731,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env) BertParameters params_for_phase; while (GetParametersForPhase(runner->GetRound(), params, params_for_phase)) { ORT_RETURN_IF_ERROR(runner->UpdateParams(params_for_phase)); - auto rank_in_data_parallel_group = params_for_phase.mpi_context.world_rank / params_for_phase.horizontal_parallel_size; + auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params_for_phase.horizontal_parallel_size; auto training_data_loader = onnxruntime::make_unique(params_for_phase.input_name_map, params_for_phase.train_data_dir, max_num_files_preload, @@ -742,7 +740,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env) auto test_data_loader = std::unique_ptr{}; // Evaluation is only done in device #0 - if (params_for_phase.mpi_context.world_rank == 0) { + if (MPIContext::GetInstance().GetWorldRank() == 0) { test_data_loader = onnxruntime::make_unique(params_for_phase.input_name_map, params_for_phase.test_data_dir, max_num_files_preload); @@ -759,7 +757,7 @@ static Status RunTraining(const BertParameters& params, const Environment& env) ORT_RETURN_IF_ERROR(runner->ResetLossScaler()); } - if (params_for_phase.mpi_context.world_rank == 0) { + if (MPIContext::GetInstance().GetWorldRank() == 0) { // Pass in empty dataloader to disable evaluation in EndTraining // to avoid a redundant synchronization caused by Tensorboard's SummaryMerge Op. ORT_RETURN_IF_ERROR(runner->EndTraining(nullptr)); @@ -807,9 +805,13 @@ int main(int argc, char* argv[]) { RETURN_IF_FAIL(RunTraining(params, *env)); } -#if defined(USE_NCCL) || defined(USE_HOROVOD) - shutdown_mpi(); +#if defined(USE_NCCL) +#ifdef _WIN32 + // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices + // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction + // call shutdown_mpi() here instead. + MPIContext::shutdown_mpi(); +#endif #endif - return 0; } diff --git a/orttraining/orttraining/models/gpt2/main.cc b/orttraining/orttraining/models/gpt2/main.cc index 99214bb75d..d85690d026 100644 --- a/orttraining/orttraining/models/gpt2/main.cc +++ b/orttraining/orttraining/models/gpt2/main.cc @@ -9,7 +9,7 @@ #include "core/session/environment.h" #include "core/framework/random_seed.h" #include "core/providers/cuda/cuda_allocator.h" -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" #include "orttraining/core/framework/tensorboard/event_writer.h" #include "orttraining/core/session/training_session.h" #include "orttraining/models/runner/constant.h" @@ -294,15 +294,14 @@ void setup_training_params(GPT2Parameters& params) { /*label_name*/ "labels"}); #if defined(USE_NCCL) || defined(USE_HOROVOD) - params.mpi_context = setup_mpi(); - ORT_ENFORCE(params.horizontal_parallel_size <= params.mpi_context.world_size); - ORT_ENFORCE(params.data_parallel_size <= params.mpi_context.world_size); - if (params.mpi_context.world_size % params.horizontal_parallel_size != 0) { + ORT_ENFORCE(params.horizontal_parallel_size <= MPIContext::GetInstance().GetWorldSize()); + ORT_ENFORCE(params.data_parallel_size <= MPIContext::GetInstance().GetWorldSize()); + if (MPIContext::GetInstance().GetWorldSize() % params.horizontal_parallel_size != 0) { LOGS_DEFAULT(ERROR) << "Cannot split horizontal parallel group because world_size is not divisible"; return; } - auto data_group_size = params.mpi_context.world_size / params.horizontal_parallel_size; + auto data_group_size = MPIContext::GetInstance().GetWorldSize() / params.horizontal_parallel_size; if (data_group_size != params.data_parallel_size) { LOGS_DEFAULT(WARNING) << "WARNING: data_parallel_size is not correct, tuned automatically to " << data_group_size << std::endl; @@ -341,7 +340,7 @@ void setup_training_params(GPT2Parameters& params) { params.model_type = "gpt2"; #ifdef USE_CUDA - OrtDevice::DeviceId device_id = static_cast(params.mpi_context.local_rank); + OrtDevice::DeviceId device_id = static_cast(MPIContext::GetInstance().GetLocalRank()); params.providers.emplace(kCudaExecutionProvider, CreateExecutionProviderFactory_CUDA(device_id)); params.input_allocator = std::make_shared(device_id, CUDA_PINNED); #endif @@ -417,7 +416,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env) auto runner = onnxruntime::make_unique(params, env); ORT_RETURN_IF_ERROR(runner->Initialize()); - auto rank_in_data_parallel_group = params.mpi_context.world_rank / params.horizontal_parallel_size; + auto rank_in_data_parallel_group = MPIContext::GetInstance().GetWorldRank() / params.horizontal_parallel_size; auto training_data_loader = onnxruntime::make_unique(params.input_name_map, params.train_data_dir, max_num_files_preload, @@ -426,7 +425,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env) std::unique_ptr test_data_loader; // Evaluation is only done in device #0 - if (params.mpi_context.world_rank == 0) { + if (MPIContext::GetInstance().GetWorldRank() == 0) { test_data_loader = onnxruntime::make_unique(params.input_name_map, params.test_data_dir, max_num_files_preload); @@ -441,7 +440,7 @@ static Status RunTraining(const GPT2Parameters& params, const Environment& env) ORT_RETURN_IF_ERROR(runner->Run(training_data_loader.get(), test_data_loader.get(), mapped_dimensions)); // only test and save trained model on device #0 - if (params.mpi_context.world_rank == 0) { + if (MPIContext::GetInstance().GetWorldRank() == 0) { test_data_loader = onnxruntime::make_unique(params.input_name_map, params.test_data_dir, max_num_files_preload); @@ -478,9 +477,13 @@ int main(int argc, char* argv[]) { RETURN_IF_FAIL(RunTraining(params, *env)); } -#if defined(USE_NCCL) || defined(USE_HOROVOD) - shutdown_mpi(); +#if defined(USE_NCCL) +#ifdef _WIN32 + // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices + // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction + // call shutdown_mpi() here instead. + MPIContext::shutdown_mpi(); +#endif #endif - return 0; } diff --git a/orttraining/orttraining/models/mnist/main.cc b/orttraining/orttraining/models/mnist/main.cc index 1674b1b723..f3dfe8fe79 100644 --- a/orttraining/orttraining/models/mnist/main.cc +++ b/orttraining/orttraining/models/mnist/main.cc @@ -143,7 +143,7 @@ void setup_training_params(TrainingRunner::Parameters& params) { }; std::shared_ptr tensorboard; - if (!params.log_dir.empty() && params.mpi_context.world_rank == 0) + if (!params.log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0) tensorboard = std::make_shared(params.log_dir); params.post_evaluation_callback = [tensorboard](size_t num_samples, size_t step, const std::string /**/) { @@ -183,12 +183,12 @@ int main(int argc, char* args[]) { setup_training_params(params); // setup data - auto device_count = params.mpi_context.world_size; + auto device_count = MPIContext::GetInstance().GetWorldSize(); std::vector feeds{"X", "labels"}; auto trainingData = std::make_shared(feeds); auto testData = std::make_shared(feeds); std::string mnist_data_path = ToMBString(params.train_data_dir); - PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, params.mpi_context.world_rank /* shard_to_load */, device_count /* total_shards */); + PrepareMNISTData(mnist_data_path, IMAGE_DIMS, LABEL_DIMS, *trainingData, *testData, MPIContext::GetInstance().GetWorldRank() /* shard_to_load */, device_count /* total_shards */); if (testData->NumSamples() == 0) { printf("Warning: No data loaded - run cancelled.\n"); diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 7f423b9ee0..1429b94ac1 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -10,7 +10,7 @@ #include "core/common/logging/sinks/clog_sink.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/framework/tensorboard/event_writer.h" -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" #include "orttraining/models/runner/constant.h" #include "orttraining/models/runner/training_runner.h" #include "orttraining/models/runner/training_util.h" diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index bdc734dfed..151e28088f 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -79,8 +79,8 @@ Status TrainingRunner::Initialize() { if (params_.pipeline_parallel_size > 1 && !params_.pipeline_stage_paths.empty()) { // Pipeline partition happens outside ORT. We just load the result of partitioning forward graph. // Backward graph will be generated using ORT's graph transformers. - ORT_ENFORCE(static_cast(params_.mpi_context.world_size) == params_.pipeline_stage_paths.size()); - ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[params_.mpi_context.world_rank])); + ORT_ENFORCE(static_cast(MPIContext::GetInstance().GetWorldSize()) == params_.pipeline_stage_paths.size()); + ORT_RETURN_IF_ERROR(session_.Load(params_.pipeline_stage_paths[MPIContext::GetInstance().GetWorldRank()])); } else { ORT_RETURN_IF_ERROR(session_.Load(params_.model_path)); } @@ -98,10 +98,10 @@ Status TrainingRunner::Initialize() { config.gradient_accumulation_steps = params_.gradient_accumulation_steps; - config.distributed_config.world_rank = params_.mpi_context.world_rank; - config.distributed_config.world_size = params_.mpi_context.world_size; - config.distributed_config.local_size = params_.mpi_context.local_size; - config.distributed_config.local_rank = params_.mpi_context.local_rank; + config.distributed_config.world_rank = MPIContext::GetInstance().GetWorldRank(); + config.distributed_config.world_size = MPIContext::GetInstance().GetWorldSize(); + config.distributed_config.local_size = MPIContext::GetInstance().GetLocalSize(); + config.distributed_config.local_rank = MPIContext::GetInstance().GetLocalRank(); config.distributed_config.data_parallel_size = params_.data_parallel_size; config.distributed_config.horizontal_parallel_size = params_.horizontal_parallel_size; config.distributed_config.pipeline_parallel_size = params_.pipeline_parallel_size; @@ -115,7 +115,7 @@ Status TrainingRunner::Initialize() { } // always configure the loss function - if (params_.pipeline_parallel_size == 1 || params_.mpi_context.world_rank == params_.mpi_context.world_size - 1) { + if (params_.pipeline_parallel_size == 1 || MPIContext::GetInstance().GetWorldRank() == MPIContext::GetInstance().GetWorldSize() - 1) { TrainingSession::TrainingConfiguration::LossFunctionConfiguration lf{}; lf.loss_function_info = params_.loss_func_info; @@ -337,7 +337,7 @@ Status TrainingRunner::Initialize() { Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader, const MapStringToString& mapped_dimensions) { - if (params_.mpi_context.world_rank == 0 && !params_.model_actual_running_graph_path.empty()) { + if (MPIContext::GetInstance().GetWorldRank() == 0 && !params_.model_actual_running_graph_path.empty()) { session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD); } @@ -810,7 +810,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names, Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader, const MapStringToString& mapped_dimensions) { const bool enable_checkpoint_saving = - params_.mpi_context.world_rank == 0 && + MPIContext::GetInstance().GetWorldRank() == 0 && checkpoint_registry_ && params_.checkpoint_period > 0; std::unique_ptr cpu_usage_calculator; diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 52c64db377..18f6d54033 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -10,7 +10,7 @@ #include "core/framework/ml_value.h" #include "core/providers/providers.h" #include "orttraining/core/framework/checkpoint_registry.h" -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" #include "orttraining/core/graph/optimizer_config.h" #include "orttraining/core/session/training_session.h" #include "orttraining/models/runner/data_loader.h" @@ -95,7 +95,6 @@ class TrainingRunner { bool use_gist = false; // Whether we collect execution profile trace during this run. bool use_profiler = false; - MPIContext mpi_context; bool skip_evaluation = false; bool dump_fetches = false; bool dump_convergence_metrics = false; @@ -120,7 +119,7 @@ class TrainingRunner { float cuda_mem_limit_in_gb = -1.0f; bool EnableTensorboard() const { - return !is_perf_test && !log_dir.empty() && mpi_context.world_rank == 0; + return !is_perf_test && !log_dir.empty() && MPIContext::GetInstance().GetWorldRank() == 0; } bool UseCuda() const { diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 13097092d8..92cedaaead 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -11,7 +11,7 @@ #include "core/session/environment.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/graph/optimizer_config.h" -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" #include "python/onnxruntime_pybind_mlvalue.h" namespace onnxruntime { @@ -19,6 +19,7 @@ namespace python { namespace py = pybind11; using namespace onnxruntime; using namespace onnxruntime::logging; +using namespace onnxruntime::training; struct TrainingParameters { std::string loss_output_name; @@ -65,25 +66,10 @@ TrainingConfigurationResult ConfigureSessionForTraining( auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size; if (data_group_size != parameters.data_parallel_size) { - std::cout << "WARNING: data_parallel_size is not correct, tuned automatically to " - << data_group_size << std::endl; + LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to " + << data_group_size; parameters.data_parallel_size = data_group_size; } -#if defined(USE_NCCL) || defined(USE_HOROVOD) - // this condition block is temporary. - // For now, nccl allreduce kernel only implements for allreduce_post_accumulation - // hovorod allreduce kernel only implements for not allreduce_post_accumulation. - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) { - auto mpi_context = training::setup_mpi(); - std::cout << "mpi_context.world_rank: " << mpi_context.world_rank << std::endl; - std::cout << "mpi_context.local_rank: " << mpi_context.local_rank << std::endl; - std::cout << "mpi_context.world_size: " << mpi_context.world_size << std::endl; - std::cout << "mpi_context.local_size: " << mpi_context.local_size << std::endl; - parameters.local_size = mpi_context.local_size; - parameters.local_rank = mpi_context.local_rank; - } -#endif training::TrainingSession::TrainingConfiguration config{}; config.weight_names_to_train = parameters.weights_to_train; @@ -158,6 +144,28 @@ TrainingConfigurationResult ConfigureSessionForTraining( return python_config_result; } +#if defined(USE_NCCL) +void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) { + LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank(); + LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalRank(): " << MPIContext::GetInstance().GetLocalRank(); + LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldSize(): " << MPIContext::GetInstance().GetWorldSize(); + LOGS(*logger, INFO) << "MPIContext::GetInstance().GetLocalSize(): " << MPIContext::GetInstance().GetLocalSize(); + + parameters.local_rank = MPIContext::GetInstance().GetLocalRank(); + parameters.local_size = MPIContext::GetInstance().GetLocalSize(); + if (parameters.world_rank != MPIContext::GetInstance().GetWorldRank()) { + if (parameters.world_rank != 0) + LOGS(*logger, WARNING) << "TrainingParameters world_rank is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldRank(); + parameters.world_rank = MPIContext::GetInstance().GetWorldRank(); + } + if (parameters.world_size != MPIContext::GetInstance().GetWorldSize()) { + if (parameters.world_size != 1) + LOGS(*logger, WARNING) << "TrainingParameters world_size is not correct, tuned automatically to " << MPIContext::GetInstance().GetWorldSize(); + parameters.world_size = MPIContext::GetInstance().GetWorldSize(); + } +} +#endif + void addObjectMethodsForTraining(py::module& m) { py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) @@ -181,6 +189,13 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs) .def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad); +#if defined(USE_NCCL) + m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); }); + m.def("get_mpi_context_local_size", []() -> int { return MPIContext::GetInstance().GetLocalSize(); }); + m.def("get_mpi_context_world_rank", []() -> int { return MPIContext::GetInstance().GetWorldRank(); }); + m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); }); +#endif + py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); config_result.def(py::init()) .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object { @@ -200,13 +215,21 @@ void addObjectMethodsForTraining(py::module& m) { return onnxruntime::make_unique(GetDefaultCPUSessionOptions(), env); })) .def("finalize", [](py::object) { -#if defined(USE_NCCL) || defined(USE_HOROVOD) - training::shutdown_mpi(); +#if defined(USE_NCCL) +#ifdef _WIN32 + // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices + // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction + // call shutdown_mpi() here instead. + MPIContext::shutdown_mpi(); +#endif #endif }) .def("load_model", [](onnxruntime::training::TrainingSession* sess, const std::string& path, TrainingParameters& parameters) { OrtPybindThrowIfError(sess->Load(path)); +#if defined(USE_NCCL) + CopyMPIContextToTrainingParameters(parameters, sess->GetLogger()); +#endif const auto config_result = ConfigureSessionForTraining(sess, parameters); std::vector provider_types = {}; @@ -218,6 +241,9 @@ void addObjectMethodsForTraining(py::module& m) { std::istringstream buffer(serialized_model); OrtPybindThrowIfError(sess->Load(buffer)); +#if defined(USE_NCCL) + CopyMPIContextToTrainingParameters(parameters, sess->GetLogger()); +#endif const auto config_result = ConfigureSessionForTraining(sess, parameters); std::vector provider_types = {}; diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py index a0766552c2..960e105704 100644 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ b/orttraining/orttraining/test/python/orttraining_run_glue.py @@ -25,6 +25,7 @@ from transformers import ( import onnxruntime from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription +from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size, get_mpi_context_world_rank, get_mpi_context_world_size from orttraining_transformer_trainer import ORTTransformerTrainer @@ -95,13 +96,27 @@ class ORTGlueTest(unittest.TestCase): assert_allclose(results['loss'], expected_loss, rtol=self.rtol) def test_bert_with_mrpc(self): - expected_acc = 0.8553921568627451 - expected_f1 = 0.8970331588132635 - expected_acc_and_f1 = 0.8762126578380043 - expected_loss = 0.42737212419217707 + if self.local_rank == -1: + expected_acc = 0.8553921568627451 + expected_f1 = 0.8970331588132635 + expected_acc_and_f1 = 0.8762126578380043 + expected_loss = 0.42737212419217707 + elif self.local_rank == 0: + expected_acc = 0.8308823529411765 + expected_f1 = 0.881646655231561 + expected_acc_and_f1 = 0.8562645040863688 + expected_loss = 0.42491564023144107 - for use_new_api in [True, False]: - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api) + if self.local_rank == -1: + # not parallel case, we can run both new and old api tests + for use_new_api in [True, False]: + results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=use_new_api) + else: + # with parallel training, TrainingArguments can only be created once (due to its cached _setup_devices) + # thus we can only choose one test case to run. + results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False, use_new_api=True) + + if self.local_rank in [-1, 0]: assert_allclose(results['acc'], expected_acc, rtol=self.rtol) assert_allclose(results['f1'], expected_f1, rtol=self.rtol) assert_allclose(results['acc_and_f1'], expected_acc_and_f1, rtol=self.rtol) @@ -122,6 +137,14 @@ class ORTGlueTest(unittest.TestCase): def model_to_desc(self, model_name, model): if model_name.startswith('bert') or model_name.startswith('xlnet'): + new_model_desc = { + 'inputs': [ + ('input_ids', ['batch', 'max_seq_len_in_batch'],), + ('attention_mask', ['batch', 'max_seq_len_in_batch'],), + ('token_type_ids', ['batch', 'max_seq_len_in_batch'],), + ('labels', ['batch', ],)], + 'outputs': [('loss', [], True), + ('logits', ['batch', 2])]} model_desc = ModelDescription([ IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size), IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2), @@ -130,6 +153,13 @@ class ORTGlueTest(unittest.TestCase): IODescription('loss', [], torch.float32), IODescription('logits', ['batch', 2], torch.float32)]) elif model_name.startswith('roberta'): + new_model_desc = { + 'inputs': [ + ('input_ids', ['batch', 'max_seq_len_in_batch'],), + ('attention_mask', ['batch', 'max_seq_len_in_batch'],), + ('labels', ['batch', ],)], + 'outputs': [('loss', [], True), + ('logits', ['batch', 2])]} model_desc = ModelDescription([ IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size), IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2), @@ -139,16 +169,19 @@ class ORTGlueTest(unittest.TestCase): else: raise RuntimeError("unsupported base model name {}.".format(model_name)) - return model_desc + return model_desc, new_model_desc def run_glue(self, model_name, task_name, fp16, use_new_api): model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=self.data_dir + "/" + task_name, + data_args = GlueDataTrainingArguments( + task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length) - training_args = TrainingArguments(output_dir=self.output_dir + "/" + task_name, do_train=True, do_eval=True, + training_args = TrainingArguments( + output_dir=os.path.join(self.output_dir, task_name), do_train=True, do_eval=True, per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,local_rank=self.local_rank, + learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs, + local_rank=self.local_rank, overwrite_output_dir=self.overwrite_output_dir, gradient_accumulation_steps=self.gradient_accumulation_steps, fp16=fp16, logging_steps=self.logging_steps) @@ -214,11 +247,12 @@ class ORTGlueTest(unittest.TestCase): preds = np.squeeze(p.predictions) return glue_compute_metrics(data_args.task_name, preds, p.label_ids) - model_desc = self.model_to_desc(model_name, model) + model_desc, new_model_desc = self.model_to_desc(model_name, model) # Initialize the ORTTrainer within ORTTransformerTrainer trainer = ORTTransformerTrainer( model=model, model_desc=model_desc, + new_model_desc=new_model_desc, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, @@ -247,4 +281,26 @@ class ORTGlueTest(unittest.TestCase): return results if __name__ == "__main__": - unittest.main() + if get_mpi_context_world_size() > 1: + # mpi launch + + print("mpirun launch") + # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") + # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). + + local_rank = get_mpi_context_local_rank() + print("get_mpi_context_local_rank(): ", local_rank) + os.environ['RANK'] = str(local_rank) + os.environ['WORLD_SIZE'] = str(get_mpi_context_world_size()) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + + from onnxruntime.capi._pybind_state import set_cuda_device_id + set_cuda_device_id(local_rank) + + test = ORTGlueTest() + test.setUp() + test.local_rank = local_rank + test.test_bert_with_mrpc() + else: + unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py index c9a41341f9..1d9acadda2 100644 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py @@ -214,6 +214,7 @@ class ORTMultipleChoiceTest(unittest.TestCase): trainer = ORTTransformerTrainer( model=model, model_desc=model_desc, + new_model_desc=None, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py index 4ae5e80b06..7e15b05b55 100644 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py @@ -12,7 +12,7 @@ import torch from torch import nn from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset -# from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler from tqdm import tqdm, trange @@ -101,6 +101,7 @@ class ORTTransformerTrainer: self, model: PreTrainedModel, model_desc: ModelDescription, + new_model_desc: dict, args: TrainingArguments, train_dataset: Dataset, eval_dataset: Dataset, @@ -112,6 +113,7 @@ class ORTTransformerTrainer: self.model = model self.model_desc = model_desc + self.new_model_desc = new_model_desc self.args = args self.data_collator = DefaultDataCollator() self.train_dataset = train_dataset @@ -171,15 +173,6 @@ class ORTTransformerTrainer: num_train_epochs = self.args.num_train_epochs if self.use_new_api: - model_desc = { - 'inputs': [ - ('input_ids', ['batch', 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', 'max_seq_len_in_batch'],), - ('token_type_ids', ['batch', 'max_seq_len_in_batch'],), - ('labels', ['batch', ],)], - 'outputs': [('loss', [], True), - ('logits', ['batch', 2])]} - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total)) loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None @@ -205,7 +198,7 @@ class ORTTransformerTrainer: ] optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - self.model = orttrainer.ORTTrainer(self.model, model_desc, optim_config, options=options) + self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options) else: def map_optimizer_attributes(name): no_decay = "bias" in name or "LayerNorm.weight" in name @@ -222,7 +215,6 @@ class ORTTransformerTrainer: learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32), device=self.args.device, gradient_accumulation_steps=self.args.gradient_accumulation_steps, - world_rank=0, world_size=1, # only support single GPU cases use_mixed_precision=self.args.fp16, allreduce_post_accumulation=True, get_lr_this_step=get_lr_this_step, diff --git a/orttraining/orttraining/training_ops/cuda/collective/nccl_common.cc b/orttraining/orttraining/training_ops/cuda/collective/nccl_common.cc index c4e2a10f81..07356c83e1 100644 --- a/orttraining/orttraining/training_ops/cuda/collective/nccl_common.cc +++ b/orttraining/orttraining/training_ops/cuda/collective/nccl_common.cc @@ -4,7 +4,7 @@ #include "nccl_common.h" #include -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.cc b/orttraining/orttraining/training_ops/cuda/communication/recv.cc index 42be0c0de9..dcaa77c123 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.cc @@ -9,7 +9,7 @@ #include "core/providers/cuda/cuda_common.h" #include -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" namespace onnxruntime { namespace cuda { diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.cc b/orttraining/orttraining/training_ops/cuda/communication/send.cc index 13855d016a..aa087f3849 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/send.cc @@ -10,7 +10,7 @@ #include #include -#include "orttraining/core/framework/mpi_setup.h" +#include "orttraining/core/framework/mpi_context.h" namespace onnxruntime { namespace cuda { diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8b497e001c..2ed0ed87eb 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1101,6 +1101,12 @@ def run_training_python_frontend_e2e_tests(cwd): # frontend tests are to be added here: log.info("Running python frontend e2e tests.") + import torch + ngpus = torch.cuda.device_count() + if ngpus > 1: + log.debug('RUN: mpirun -n {} {} orttraining_run_glue.py'.format(ngpus, sys.executable)) + run_subprocess(['mpirun', '-n', str(ngpus), sys.executable, 'orttraining_run_glue.py'], cwd=cwd) + # with orttraining_run_glue.py. # 1. we like to force to use single GPU (with CUDA_VISIBLE_DEVICES) # for fine-tune tests. diff --git a/tools/ci_build/github/linux/docker/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_deps.sh index 53f5ffb029..0834253d9b 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_deps.sh @@ -114,6 +114,17 @@ elif [ $DEVICE_TYPE = "gpu" ]; then ${PYTHON_EXE} -m pip install transformers==v2.10.0 # transformers requires sklearn ${PYTHON_EXE} -m pip install sklearn + + if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then + echo "install openmpi" + curl -fsSL https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz -O + tar zxf openmpi-4.0.0.tar.gz + cd openmpi-4.0.0 + ./configure --enable-orterun-prefix-by-default + make all + make install + ldconfig + fi fi fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index a5d80ef786..6e70f424a9 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -47,7 +47,7 @@ if [ "$OS_VERSION" = "16.04" ]; then libicu55 \ libtinfo-dev \ libtool \ - mpich libmpich-dev \ + openssh-server \ aria2 \ bzip2 \ unzip \ @@ -80,7 +80,7 @@ else # ubuntu18.04 libicu60 \ libtinfo-dev \ libtool \ - mpich libmpich-dev \ + openssh-server \ aria2 \ bzip2 \ unzip \