diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 43287f8b7..bf713b58a 100644
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1479,7 +1479,7 @@ class Trainer:
         num_train_tokens = None
 
         # Since the actual `max_steps` depends on if we have a dataloader length, number of epochs, or max_steps,
-        # we keep it defined here first as the "base case" if we don't have a dataloader length later
+        # we keep it defined here first as the "base case" if we don't have a dataloader length later (Case 1)
         max_steps = args.max_steps
         # If max_steps is negative, we use the number of epochs to determine the number of total steps later
         epoch_based = max_steps < 0
@@ -1497,8 +1497,8 @@ class Trainer:
             len_dataloader = len(train_dataloader)
             num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
             num_examples = self.num_examples(train_dataloader)
+            # Case 3, we have a length but are using epochs, we can extrapolate the number of steps
             if epoch_based:
-                # Since we have a dataloader length and are utilizing epochs, we can extrapolate the number of steps
                 max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
             else:
                 # Otherwise, we calculate the number of epochs from max_steps & update steps
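
For reference, below is a minimal standalone sketch of the step-derivation logic the updated comments describe. The function name `resolve_max_steps` and its parameters are illustrative stand-ins for `Trainer` state (`args.max_steps`, `args.num_train_epochs`, `len(train_dataloader)`, `args.gradient_accumulation_steps`), not part of the actual API.

import math

def resolve_max_steps(max_steps, num_train_epochs, len_dataloader, gradient_accumulation_steps):
    """Illustrative sketch of the Case 1 / Case 3 logic in the diff above."""
    # Case 1: a non-negative max_steps is the "base case" and wins outright;
    # a negative value means training length is derived from the epoch count
    epoch_based = max_steps < 0
    if epoch_based and len_dataloader is not None:
        # One optimizer update happens every `gradient_accumulation_steps` batches
        num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)
        # Case 3: a sized dataloader plus epoch-based training lets us
        # extrapolate the total number of update steps
        max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)
    return max_steps

# 1000 batches with accumulation over 4 -> 250 updates/epoch; 3 epochs -> 750 steps
assert resolve_max_steps(-1, 3, 1000, 4) == 750
# An explicit max_steps takes precedence (Case 1)
assert resolve_max_steps(500, 3, 1000, 4) == 500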