diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 60a81f745..7c5d4b1c7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1797,6 +1797,9 @@ class TrainingArguments: " during training" ) + if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0 or 0 < self.warmup_steps <= 1: + raise ValueError("warmup_steps must be either 0 or > 1") + if isinstance(self.fsdp, bool): self.fsdp = "full_shard" if self.fsdp else "" if isinstance(self.fsdp, str):