From 4ec425ffad56cdbedfb97ab2d11243e42889f71c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=B5=E5=BA=AD=E7=91=9C?= <76907817+timjeffrey10@users.noreply.github.com>
Date: Thu, 23 Jan 2025 18:45:02 +0800
Subject: [PATCH] Fix GA loss for Deepspeed (#35808)

* Fix GA loss for Deepspeed

* Turn off loss scaling in DeepSpeed engine by scale_wrt_gas

* Add comment linking to PR
---
 src/transformers/trainer.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f45ff46bd..00938a630 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3722,6 +3722,11 @@ class Trainer:
         if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
             loss = loss / self.args.gradient_accumulation_steps
 
+        # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+        # https://github.com/huggingface/transformers/pull/35808
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+            kwargs["scale_wrt_gas"] = False
+
         self.accelerator.backward(loss, **kwargs)
 
         return loss.detach()
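
Note for reviewers: below is a minimal arithmetic sketch of the double scaling this patch avoids. The variable names and numbers are illustrative assumptions, not the actual Trainer or DeepSpeed internals; the only real identifier used is the scale_wrt_gas kwarg, which Accelerate forwards to the DeepSpeed engine's backward call.

    # Illustrative sketch only (assumed names and values, not library code).
    gradient_accumulation_steps = 4
    raw_loss = 2.0

    # Trainer side: when the model does not accept loss kwargs and no custom
    # compute_loss_func is set, the Trainer pre-divides the loss by the number
    # of accumulation steps (the context lines in the hunk above).
    trainer_scaled = raw_loss / gradient_accumulation_steps  # 0.5

    # Without the patch, the DeepSpeed engine scales the loss w.r.t. gradient
    # accumulation steps again inside its own backward pass, so the effective
    # loss ends up divided twice.
    double_scaled = trainer_scaled / gradient_accumulation_steps  # 0.125

    # With the patch, kwargs["scale_wrt_gas"] = False is passed through
    # accelerator.backward(loss, **kwargs), the engine skips its own scaling,
    # and the loss is divided exactly once, as intended.
    patched = trainer_scaled  # 0.5

    assert double_scaled == raw_loss / gradient_accumulation_steps**2
    assert patched == raw_loss / gradient_accumulation_steps

The guard on self.accelerator.distributed_type == DistributedType.DEEPSPEED keeps the kwarg out of non-DeepSpeed backward paths, since scale_wrt_gas is specific to the DeepSpeed engine.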