Fix GA loss for Deepspeed (#35808)

* Fix GA loss for Deepspeed

* Turn off loss scaling in DeepSpeed engine by scale_wrt_gas

* Add comment linking to PR
This commit is contained in:
張庭瑜 2025-01-23 18:45:02 +08:00 committed by GitHub
parent f3f6c86582
commit 4ec425ffad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3722,6 +3722,11 @@ class Trainer:
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
# https://github.com/huggingface/transformers/pull/35808
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs["scale_wrt_gas"] = False
self.accelerator.backward(loss, **kwargs)
return loss.detach()