Fix GA loss for Deepspeed (#35808)
* Fix GA loss for Deepspeed
* Turn off loss scaling in DeepSpeed engine by scale_wrt_gas
* Add comment linking to PR
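For context, the gradient accumulation (GA) loss bug is a double division by gradient_accumulation_steps: Trainer scales the loss once, and the DeepSpeed engine scales it again inside backward(). A minimal arithmetic sketch of the problem (illustrative numbers only, not code from the patch):

```python
# Illustrative numbers only; gas = gradient_accumulation_steps.
gas = 4
raw_loss = 8.0

trainer_scaled = raw_loss / gas        # Trainer's own scaling -> 2.0
double_scaled = trainer_scaled / gas   # DeepSpeed engine scaling again -> 0.5

# Gradients from the double-scaled loss are too small by a factor of `gas`,
# which is exactly what passing scale_wrt_gas=False avoids.
print(trainer_scaled, double_scaled)
```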
This commit is contained in:
parent: f3f6c86582
commit: 4ec425ffad
1 changed file with 5 additions and 0 deletions
@@ -3722,6 +3722,11 @@ class Trainer:
         if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
             loss = loss / self.args.gradient_accumulation_steps

+        # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+        # https://github.com/huggingface/transformers/pull/35808
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+            kwargs["scale_wrt_gas"] = False
+
         self.accelerator.backward(loss, **kwargs)

         return loss.detach()
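A self-contained sketch of the mechanism, using a stand-in engine class (an assumption based on the commit message: DeepSpeed's engine-side backward rescales the loss by the gradient accumulation steps unless scale_wrt_gas=False is forwarded to it; the class below is a mock, not the real DeepSpeed API surface):

```python
class FakeDeepSpeedEngine:
    """Stand-in for DeepSpeed's engine; names and behavior are illustrative."""

    def __init__(self, gradient_accumulation_steps):
        self.gas = gradient_accumulation_steps

    def backward(self, loss, scale_wrt_gas=True):
        # Mimics the engine-side scaling the patch turns off.
        if scale_wrt_gas:
            loss = loss / self.gas
        return loss  # the value that would actually be backpropagated

engine = FakeDeepSpeedEngine(gradient_accumulation_steps=4)
already_scaled = 8.0 / 4  # Trainer has divided by GAS once already

print(engine.backward(already_scaled))                       # 0.5 (double-scaled, wrong)
print(engine.backward(already_scaled, scale_wrt_gas=False))  # 2.0 (patched behavior)
```

In the patch itself the flag travels through kwargs: Trainer passes scale_wrt_gas=False to self.accelerator.backward(), which hands it on to the DeepSpeed engine when DistributedType.DEEPSPEED is active.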