From de6e66f3d49eebec36884f488789bb429fee0c92 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 8 Mar 2021 21:12:33 +0800 Subject: [PATCH] Fix loss scaling when running ORTTrainer with BERT under mixed-precision mode (#6932) * Fix missed Loss scale * not to dump --- orttraining/orttraining/core/session/training_session.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 9fe5c1bd72..564557f856 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -1679,7 +1679,7 @@ Status PipelineTrainingSession::BuildLossAndLossScaling( std::string& loss_name, optional<std::string>& loss_scale_input_name, optional<TrainingConfigurationResult::MixedPrecisionConfigurationResult>& mixed_precision_config_result) { - const bool last_pipeline_stage = pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size; + const bool last_pipeline_stage = pipeline_stage_id == -1 || (pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size); const bool enable_loss_scale = is_mixed_precision_enabled_ && mixed_precision_config.value().mixed_precision_type == MixedPrecisionDataType::FP16; // Enable loss scale if mixed precision is enabled AND at pipeline's last stage if pipeline is used.