Fix loss scaling when running ORTTrainer with BERT under mixed-precision mode (#6932)

* Fix missing loss scale

* Do not dump
This commit is contained in:
Wei-Sheng Chin 2021-03-08 21:12:33 +08:00 committed by GitHub
parent 601e04fb27
commit de6e66f3d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1679,7 +1679,7 @@ Status PipelineTrainingSession::BuildLossAndLossScaling(
std::string& loss_name,
optional<std::string>& loss_scale_input_name,
optional<TrainingConfigurationResult::MixedPrecisionConfigurationResult>& mixed_precision_config_result) {
const bool last_pipeline_stage = pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size;
const bool last_pipeline_stage = pipeline_stage_id == -1 || (pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size);
const bool enable_loss_scale = is_mixed_precision_enabled_ &&
mixed_precision_config.value().mixed_precision_type == MixedPrecisionDataType::FP16;
// Enable loss scale if mixed precision is enabled AND at pipeline's last stage if pipeline is used.