fix: fix gradient accumulate step for learning rate (#27667)

This commit is contained in:
Phuc Van Phan 2023-12-07 13:59:26 +07:00 committed by GitHub
parent f84d85ba67
commit 0410a29a2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -640,7 +640,7 @@ def main():
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(vectorized_datasets["train"]),
total_train_steps,
training_args.warmup_steps,
training_args.learning_rate,
)