mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Fix loss scaling when running ORTTrainer with BERT under mixed-precision mode (#6932)
* Fix missed loss scale
* Do not dump
This commit is contained in:
parent
601e04fb27
commit
de6e66f3d4
1 changed file with 1 addition and 1 deletion
|
|
@@ -1679,7 +1679,7 @@ Status PipelineTrainingSession::BuildLossAndLossScaling(
|
|||
std::string& loss_name,
|
||||
optional<std::string>& loss_scale_input_name,
|
||||
optional<TrainingConfigurationResult::MixedPrecisionConfigurationResult>& mixed_precision_config_result) {
|
||||
const bool last_pipeline_stage = pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size;
|
||||
const bool last_pipeline_stage = pipeline_stage_id == -1 || (pipeline_stage_id + 1 == distributed_config.value().pipeline_parallel_size);
|
||||
const bool enable_loss_scale = is_mixed_precision_enabled_ &&
|
||||
mixed_precision_config.value().mixed_precision_type == MixedPrecisionDataType::FP16;
|
||||
// Enable loss scale if mixed precision is enabled AND at pipeline's last stage if pipeline is used.
|
||||
|
|
|
|||
Loading…
Reference in a new issue