mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Fix a macro and memory regression (#4068)
onnxruntime_training_bert can run the following command again. ./onnxruntime_training_bert --model_name=bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm --num_train_steps=16 --train_batch_size=52 --mode=train --train_data_dir=/bert_data/128/books_wiki_en_corpus/train --test_data_dir=/bert_data/128/books_wiki_en_corpus/test --gradient_accumulation_steps=16 --optimizer=Lamb --learning_rate=3e-3 --max_seq_length=128 --max_predictions_per_seq=20 --warmup_ratio=0.2843 --warmup_mode=Poly --display_loss_steps=100 --use_mixed_precision=True --allreduce_in_fp16 --use_nccl
This commit is contained in:
parent
38d76cc904
commit
e951b29a0b
2 changed files with 139 additions and 80 deletions
|
|
@ -12,7 +12,7 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
#if !defined(NDEBUG) && defined(USE_CUDA) && !defined(_WIN32)
|
||||
#ifdef ENABLE_NVTX_PROFILE
|
||||
#include "core/profile/context.h"
|
||||
#endif
|
||||
#include "core/session/environment.h"
|
||||
|
|
@ -398,7 +398,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of the first waited event in forward pass.
|
||||
if (!pipeline_context_.forward_waited_event_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.forward_waited_event_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -415,7 +415,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of the second waited event in forward pass.
|
||||
if (!pipeline_context_.forward_waited_event_after_recv_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.forward_waited_event_after_recv_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -432,7 +432,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of first recorded event in forward pass.
|
||||
if (!pipeline_context_.forward_recorded_event_before_send_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.forward_recorded_event_before_send_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -449,7 +449,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of second recorded event in forward pass.
|
||||
if (!pipeline_context_.forward_recorded_event_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.forward_recorded_event_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -466,7 +466,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of first waited event in backward pass.
|
||||
if (!pipeline_context_.backward_waited_event_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.backward_waited_event_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -483,7 +483,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of second waited event in backward pass.
|
||||
if (!pipeline_context_.backward_waited_event_after_recv_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.backward_waited_event_after_recv_name);
|
||||
OrtValue event_id;
|
||||
const int64_t id =
|
||||
|
|
@ -500,7 +500,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of first recorded event in backward pass.
|
||||
if (!pipeline_context_.backward_recorded_event_before_send_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.backward_recorded_event_before_send_name);
|
||||
OrtValue event_id;
|
||||
int64_t id =
|
||||
|
|
@ -517,7 +517,7 @@ Status TrainingRunner::PrepareFeedNamesAndFeeds(const SessionMode mode,
|
|||
|
||||
// Create feed of second recorded event in backward pass.
|
||||
if (!pipeline_context_.backward_recorded_event_name.empty()) {
|
||||
ORT_ENFORCE(params_.pipeline_parallel_size > 1);
|
||||
ORT_RETURN_IF(params_.pipeline_parallel_size <= 1, "Internal event name should be empty if there is no pipeline.");
|
||||
feed_names.push_back(pipeline_context_.backward_recorded_event_name);
|
||||
OrtValue event_id;
|
||||
int64_t id =
|
||||
|
|
@ -630,25 +630,54 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
|
|||
}
|
||||
|
||||
// Launch synced session.Run on the main thread.
|
||||
Status TrainingRunner::RunWithUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
std::vector<MLValue>& fetches) {
|
||||
#if !defined(NDEBUG) && defined(USE_CUDA) && !defined(_WIN32)
|
||||
// Store the tag for the thread which runs session_.Run(...).
|
||||
// It will be used to name range in Nvidia's visual profiler.
|
||||
auto& profile_context = profile::Context::GetInstance();
|
||||
profile_context.SetThreadTag(
|
||||
std::this_thread::get_id(), std::to_string(step_));
|
||||
void TrainingRunner::RunWithUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
std::vector<MLValue>& fetches) {
|
||||
// Cyclically pick up a worker ID.
|
||||
const size_t worker_id = step_ % params_.pipeline_parallel_size;
|
||||
|
||||
// Wait for the previous work to finish its job.
|
||||
// Its resources cannot be overwritten while it is still working.
|
||||
pipeline_worker_pool_.Join(worker_id);
|
||||
|
||||
// Copy thread-used variable to thread-specific buffer to maintain their life.
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names = feed_names;
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds = feeds;
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names = fetch_names;
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetches = std::vector<MLValue>();
|
||||
|
||||
Status status = Status::OK();
|
||||
pipeline_worker_pool_.workers[worker_id] = std::thread([&](
|
||||
const size_t worker_id, const size_t step) {
|
||||
#ifdef ENABLE_NVTX_PROFILE
|
||||
// Store the tag for the thread which runs session_.Run(...).
|
||||
// It will be used to name range in Nvidia's visual profiler.
|
||||
auto& profile_context = profile::Context::GetInstance();
|
||||
profile_context.SetThreadTag(
|
||||
std::this_thread::get_id(), std::to_string(step));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(step);
|
||||
#endif
|
||||
// Sync launch of session. This model-update session runs on the main thread, so
|
||||
// no new async session will be launched until this model-update session is done.
|
||||
// This prevents the new sessions from using not-updated model.
|
||||
ORT_RETURN_IF_ERROR(session_.Run(RunOptions(),
|
||||
feed_names,
|
||||
feeds,
|
||||
fetch_names,
|
||||
&fetches));
|
||||
status = session_.Run(
|
||||
RunOptions(),
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds,
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
|
||||
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
|
||||
}, worker_id, step_);
|
||||
|
||||
// Wait for all workers to finish this round of pipeline parallelism.
|
||||
// The last batch in a pipeline collects gradients and updates the model.
|
||||
// We must join here because main thread needs to access thread-produced
|
||||
// fetches and those fetches must be ready.
|
||||
pipeline_worker_pool_.JoinAll();
|
||||
|
||||
// If the updating thread fails, we return with its error status.
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
|
||||
// Copy back from thread-specific buffer to main thread's memory.
|
||||
fetches = pipeline_worker_pool_.worker_states[worker_id].fetches;
|
||||
|
||||
if (loss_scaler_) {
|
||||
auto it = std::find(fetch_names.begin(), fetch_names.end(), opt_graph_outputs_[OptimizerOutputKey::GradientAllIsFinite]);
|
||||
|
|
@ -683,15 +712,13 @@ Status TrainingRunner::RunWithUpdate(VectorString& feed_names,
|
|||
++step_;
|
||||
// Add one after updating the model once.
|
||||
++weight_update_step_count_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Launch async session.Run on non-main thread.
|
||||
Status TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count) {
|
||||
void TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count) {
|
||||
// Cyclically pick up a worker ID.
|
||||
const size_t worker_id = step_ % params_.pipeline_parallel_size;
|
||||
|
||||
|
|
@ -708,33 +735,31 @@ Status TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
|
|||
|
||||
// Async launch of a session.
|
||||
pipeline_worker_pool_.workers[worker_id] = std::thread([&](
|
||||
const size_t worker_id, const size_t step) {
|
||||
#if !defined(NDEBUG) && defined(USE_CUDA) && !defined(_WIN32)
|
||||
const size_t worker_id, const size_t step) {
|
||||
#ifdef ENABLE_NVTX_PROFILE
|
||||
// Store the tag for the thread which runs session_.Run(...).
|
||||
// It will be used to name range in Nvidia's visual profiler.
|
||||
auto& profile_context = profile::Context::GetInstance();
|
||||
profile_context.SetThreadTag(
|
||||
std::this_thread::get_id(), std::to_string(step));
|
||||
std::this_thread::get_id(), std::to_string(step));
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(step);
|
||||
#endif
|
||||
// Dummy use of step to avoid warning when the code above is disabled.
|
||||
ORT_ENFORCE(step + 1 > 0);
|
||||
RunOptions run_options;
|
||||
run_options.only_execute_path_to_fetches = true;
|
||||
ORT_ENFORCE(session_.Run(
|
||||
run_options,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds,
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
|
||||
&(pipeline_worker_pool_.worker_states[worker_id].fetches)) == Status::OK());
|
||||
},
|
||||
worker_id, step_);
|
||||
auto status = session_.Run(
|
||||
run_options,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds,
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
|
||||
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
}, worker_id, step_);
|
||||
|
||||
// Add one after processing one batch.
|
||||
++step_;
|
||||
// Add one after computing one forward-backward path without applying optimizer.
|
||||
++gradient_accumulation_step_count;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
|
|
@ -807,29 +832,32 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
|
|||
|
||||
if (is_weight_update_step) {
|
||||
PrepareFeedNamesAndFeeds(ModelUpdateStep,
|
||||
training_data_loader,
|
||||
*training_data,
|
||||
lr_scheduler.get(),
|
||||
batch,
|
||||
feed_names,
|
||||
feeds);
|
||||
PrepareFetchNamesAndFetches(ModelUpdateStep,
|
||||
fetch_names,
|
||||
fetches);
|
||||
training_data_loader,
|
||||
*training_data,
|
||||
lr_scheduler.get(),
|
||||
batch,
|
||||
feed_names,
|
||||
feeds);
|
||||
ORT_RETURN_IF_ERROR(
|
||||
PrepareFetchNamesAndFetches(ModelUpdateStep,
|
||||
fetch_names,
|
||||
fetches));
|
||||
RunWithUpdate(feed_names, fetch_names, feeds, fetches);
|
||||
} else {
|
||||
PrepareFeedNamesAndFeeds(GradientAccumulateStep,
|
||||
training_data_loader,
|
||||
*training_data,
|
||||
lr_scheduler.get(),
|
||||
batch,
|
||||
feed_names,
|
||||
feeds);
|
||||
PrepareFetchNamesAndFetches(GradientAccumulateStep,
|
||||
fetch_names,
|
||||
fetches);
|
||||
training_data_loader,
|
||||
*training_data,
|
||||
lr_scheduler.get(),
|
||||
batch,
|
||||
feed_names,
|
||||
feeds);
|
||||
ORT_RETURN_IF_ERROR(
|
||||
PrepareFetchNamesAndFetches(GradientAccumulateStep,
|
||||
fetch_names,
|
||||
fetches));
|
||||
RunWithoutUpdate(feed_names, fetch_names, feeds,
|
||||
gradient_accumulation_step_count);
|
||||
gradient_accumulation_step_count);
|
||||
|
||||
}
|
||||
|
||||
// At this point, step_ has already been increased by 1.
|
||||
|
|
@ -1105,11 +1133,42 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
|
|||
fetch_names,
|
||||
fetches);
|
||||
|
||||
ORT_RETURN_IF_ERROR(session.Run(run_options,
|
||||
feed_names,
|
||||
feeds,
|
||||
fetch_names,
|
||||
&fetches));
|
||||
if (params_.pipeline_parallel_size == 1) {
|
||||
auto status = Status::OK();
|
||||
// When there is no pipeline, we always use the first thread
|
||||
// to launch session_.Run(...) to avoid multiple activation allocations.
|
||||
|
||||
// Always use the first thread to evaluate.
|
||||
const size_t worker_id = 0;
|
||||
// Wait for the previous work to finish its job.
|
||||
// Its resources cannot be overwritten while it is still working.
|
||||
pipeline_worker_pool_.Join(worker_id);
|
||||
// Declare Run(...)'s status in thread.
|
||||
// Launch Run(...).
|
||||
pipeline_worker_pool_.workers[worker_id] = std::thread([&]() {
|
||||
RunOptions run_options;
|
||||
run_options.only_execute_path_to_fetches = true;
|
||||
status = session.Run(
|
||||
run_options,
|
||||
feed_names,
|
||||
feeds,
|
||||
fetch_names,
|
||||
&fetches);
|
||||
});
|
||||
// Wait for Run(...) to finish.
|
||||
pipeline_worker_pool_.Join(worker_id);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
} else {
|
||||
// Training threads are fully used by pipeline stages.
|
||||
// Pipeline cannot reuse training threads to do evaluation.
|
||||
// Otherwise, a deadlock may happen.
|
||||
ORT_RETURN_IF_ERROR(session.Run(run_options,
|
||||
feed_names,
|
||||
feeds,
|
||||
fetch_names,
|
||||
&fetches));
|
||||
}
|
||||
|
||||
|
||||
// Assume that user-specified fetches are available only on the last pipeline stage.
|
||||
// When there is no pipeline, all pipeline_context_.pipeline_stage_id should be 0 and
|
||||
|
|
|
|||
|
|
@ -196,14 +196,14 @@ class TrainingRunner {
|
|||
Status PrepareFetchNamesAndFetches(const SessionMode mode,
|
||||
std::vector<std::string>& fetch_names,
|
||||
std::vector<MLValue>& fetches);
|
||||
Status RunWithUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
std::vector<MLValue>& fetches);
|
||||
Status RunWithoutUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count);
|
||||
void RunWithUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
std::vector<MLValue>& fetches);
|
||||
void RunWithoutUpdate(VectorString& feed_names,
|
||||
VectorString& fetch_names,
|
||||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count);
|
||||
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions);
|
||||
Status Evaluate(InferenceSession& session, IDataLoader& data_loader);
|
||||
|
|
|
|||
Loading…
Reference in a new issue