diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 9fe5c1bd72..564557f856 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -1679,7 +1679,7 @@ Status PipelineTrainingSession::BuildLossAndLossScaling( std::string& loss_name, optional& loss_scale_input_name, optional& mixed_precision_config_result) { - const bool last_pipeline_stage = pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size; + const bool last_pipeline_stage = pipeline_stage_id == -1 || (pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size); const bool enable_loss_scale = is_mixed_precision_enabled_ && mixed_precision_config.value().mixed_precision_type == MixedPrecisionDataType::FP16; // Enable loss scale if mixed precision is enabled AND at pipeline's last stage if pipeline is used.