From b2e4b091f08f1aaf21855d588c6c8d284baba9eb Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Sat, 30 Jul 2022 10:07:56 +0530
Subject: [PATCH] fix FSDP ShardedGradScaler (#18358)

renaming it
---
 src/transformers/trainer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d08e60137..59a1ca19a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -565,9 +565,11 @@ class Trainer:
                 self.scaler = ShardedGradScaler()
             elif self.fsdp is not None:
                 if self.amp_dtype == torch.float16:
-                    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                    from torch.distributed.fsdp.sharded_grad_scaler import (
+                        ShardedGradScaler as FSDPShardedGradScaler,
+                    )
 
-                    self.scaler = ShardedGradScaler()
+                    self.scaler = FSDPShardedGradScaler()
                 else:
                     self.do_grad_scaling = False
                     self.use_cuda_amp = False
@@ -1366,6 +1368,8 @@ class Trainer:
                 transformer_cls_to_wrap = get_module_class_from_name(
                     model, self.args.fsdp_transformer_layer_cls_to_wrap
                 )
+                if transformer_cls_to_wrap is None:
+                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
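
Note (not part of the patch itself): the rename matters because of Python function scoping.
An import statement inside a function makes the imported name local to the entire function,
so the unaliased `from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler`
in the FSDP branch shadows the module-level ShardedGradScaler used by the sharded_ddp branch
of the same method, which then fails with UnboundLocalError. Below is a minimal, self-contained
sketch of that pitfall with hypothetical names (Scaler, make_scaler, decimal.Decimal as a
stand-in import); none of these come from transformers.

    # scoping_sketch.py -- hypothetical illustration of the pitfall the patch fixes
    class Scaler:
        """Stands in for the module-level (FairScale-style) ShardedGradScaler."""

    def make_scaler(use_other: bool):
        if not use_other:
            # Raises UnboundLocalError: the import in the branch below makes
            # "Scaler" local to this whole function, shadowing the class above.
            return Scaler()
        from decimal import Decimal as Scaler  # rebinds "Scaler" locally
        return Scaler("1")

    def make_scaler_fixed(use_other: bool):
        if not use_other:
            return Scaler()  # resolves to the module-level class as intended
        # Importing under a distinct alias (as the patch does with
        # ShardedGradScaler -> FSDPShardedGradScaler) leaves "Scaler" untouched.
        from decimal import Decimal as OtherScaler
        return OtherScaler("1")

The second hunk is independent of the rename: it guards against
get_module_class_from_name returning None when fsdp_transformer_layer_cls_to_wrap does not
match any module class in the model, raising a clear error instead of passing None into the
auto-wrap policy.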