Trainer Refactor: Part 1 (#35567)

* start

* So far: 30%

* Small fix

* Continuing update

* Continuing

* Forgot to check if not None

* Continuing refactor

* Fix if else

* Fix ref

* Should make tests pass

* Keep grad norm same

* Document

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Err instead of info for logging RNG state error

* Separate out to func

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Zach Mueller 2025-01-29 09:50:54 -05:00 committed by GitHub
parent 23d782ead2
commit 5d257111c1
3 changed files with 147 additions and 132 deletions

View file

@@ -41,7 +41,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Ty
# isort: off
from .integrations import (
get_reporting_integration_callbacks,
hp_params,
)
# isort: on
@@ -108,6 +107,7 @@ from .trainer_pt_utils import (
nested_xla_mesh_reduce,
reissue_pt_warnings,
remove_dummy_checkpoint,
set_rng_state_for_device,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
@@ -2219,46 +2219,25 @@ class Trainer:
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
(
num_train_epochs,
num_update_steps_per_epoch,
num_examples,
num_train_samples,
epoch_based,
len_dataloader,
max_steps,
) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
len_dataloader = None
num_train_tokens = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
)
if self.args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
# If going by epochs, multiply tokens linearly
if len_dataloader is not None and epoch_based:
num_train_tokens *= args.num_train_epochs
# Otherwise, since it's steps, we just multiply by grad accum
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
num_train_tokens *= args.gradient_accumulation_steps
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
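The branch above computes the token count once and then scales it: by the number of epochs when training is epoch-based, or by the gradient accumulation steps when it is driven by max_steps. A minimal standalone sketch of that accounting, using hypothetical estimate_train_tokens/count_tokens names in place of the Trainer's self.num_tokens:

def estimate_train_tokens(count_tokens, dataloader, num_train_epochs,
                          gradient_accumulation_steps, epoch_based, max_steps, len_dataloader):
    # Count tokens once: over the whole dataloader when epoch-based, else over max_steps batches.
    tokens = count_tokens(dataloader, None if epoch_based else max_steps)
    if len_dataloader is not None and epoch_based:
        # Epoch-based: the dataloader is traversed num_train_epochs times.
        return tokens * num_train_epochs
    # Step-based: scale by gradient accumulation, mirroring the branch above.
    return tokens * gradient_accumulation_steps

# Toy usage: a fake counter returning 1,000 tokens per batch, 8 batches, 3 epochs.
fake_counter = lambda dl, n: 1000 * (len(dl) if n is None else n)
print(estimate_train_tokens(fake_counter, list(range(8)), num_train_epochs=3,
                            gradient_accumulation_steps=4, epoch_based=True,
                            max_steps=-1, len_dataloader=8))  # 24000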
@@ -2293,21 +2272,7 @@ class Trainer:
self.state.train_batch_size = self._train_batch_size
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
self.state.compute_steps(args, max_steps)
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
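The single self.state.compute_steps(args, max_steps) call above folds together the three near-identical blocks it replaces: a logging, eval, or save interval given as a float below 1 is treated as a proportion of max_steps. A worked example with hypothetical values:

import math

max_steps = 1000
steps_args = {"logging_steps": 0.05, "eval_steps": 200, "save_steps": 0.25}

resolved = {
    name: (math.ceil(max_steps * value) if value < 1 else value)
    for name, value in steps_args.items()
    if value is not None
}
print(resolved)  # {'logging_steps': 50, 'eval_steps': 200, 'save_steps': 250}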
@@ -2420,25 +2385,7 @@ class Trainer:
)
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
if self.hp_name is not None and self._trial is not None:
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
# parameter to Train when using DDP.
self.state.trial_name = self.hp_name(self._trial)
if trial is not None:
assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
self.state.trial_params = hp_params(assignments)
else:
self.state.trial_params = None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
self.state.init_training_references(self, train_dataloader, max_steps, num_train_epochs, trial)
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
@@ -2495,10 +2442,7 @@ class Trainer:
step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.accelerator.gradient_state._set_sync_gradients(True)
self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
@@ -2565,8 +2509,6 @@ class Trainer:
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping
if is_sagemaker_mp_enabled() and args.fp16:
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
@@ -2598,8 +2540,7 @@ class Trainer:
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
if not self.accelerator.optimizer_step_was_skipped:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
@@ -3119,52 +3060,19 @@ class Trainer:
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
else:
try:
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_xla_available():
xm.set_rng_state(checkpoint_rng_state["xla"])
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
if torch.cuda.is_available():
set_rng_state_for_device("GPU", torch.cuda, checkpoint_rng_state, is_distributed)
if is_torch_npu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
else:
try:
torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
if is_torch_mlu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
else:
try:
torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
if is_torch_musa_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
else:
try:
torch.musa.set_rng_state(checkpoint_rng_state["musa"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
def _determine_best_metric(self, metrics, trial):
"""
@@ -5050,11 +4958,10 @@ class Trainer:
accelerator_config = self.args.accelerator_config.to_dict()
if is_accelerate_available("0.28.0"):
# Extract dataloader config params from accelerator config
dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
dataloader_config = DataLoaderConfiguration(
split_batches=accelerator_config.pop("split_batches"),
dispatch_batches=accelerator_config.pop("dispatch_batches"),
even_batches=accelerator_config.pop("even_batches"),
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
**{param: accelerator_config.pop(param) for param in dataloader_params}
)
if is_accelerate_available("1.1.0"):
dataloader_config.data_seed = self.args.data_seed
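The dict comprehension above pops a fixed list of keys out of the accelerator config and forwards them as keyword arguments, leaving every other key behind for the Accelerator itself. A toy illustration of the pattern (the config values here are made up):

accelerator_config = {
    "split_batches": False,
    "dispatch_batches": None,
    "even_batches": True,
    "use_seedable_sampler": True,
    "non_blocking": False,  # not a dataloader option, stays behind
}
dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
dataloader_kwargs = {param: accelerator_config.pop(param) for param in dataloader_params}

print(sorted(dataloader_kwargs))  # the four dataloader options
print(accelerator_config)         # {'non_blocking': False}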
@@ -5099,12 +5006,8 @@ class Trainer:
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
"activation_checkpointing", fsdp_plugin.activation_checkpointing
)
for param in ["limit_all_gathers", "activation_checkpointing"]:
setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
@@ -5186,3 +5089,63 @@ class Trainer:
num_items_in_batch = num_items_in_batch.item()
return batch_samples, num_items_in_batch
def set_initial_training_values(
self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int
):
"""
Calculates and returns the following values:
- `num_train_epochs`
- `num_update_steps_per_epoch`
- `num_examples`
- `num_train_samples`
- `epoch_based`
- `len_dataloader`
- `max_steps`
"""
# Case 1: we rely on `args.max_steps` first
max_steps = args.max_steps
# If max_steps is negative, we use the number of epochs to determine the number of total steps later
epoch_based = max_steps < 0
len_dataloader = len(dataloader) if has_length(dataloader) else None
# Case 2: We have a dataloader length and can extrapolate
if len_dataloader is not None:
num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
# Case 3: We have a length but are using epochs, we can extrapolate the number of steps
if epoch_based:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
# Now we figure out `num_examples`, `num_train_epochs`, and `train_samples`
if len_dataloader:
num_examples = self.num_examples(dataloader)
if args.max_steps > 0:
num_train_epochs = max_steps // num_update_steps_per_epoch + int(
max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = max_steps * total_train_batch_size
else:
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = self.num_examples(dataloader) * args.num_train_epochs
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
return (
num_train_epochs,
num_update_steps_per_epoch,
num_examples,
num_train_samples,
epoch_based,
len_dataloader,
max_steps,
)
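To make the arithmetic in set_initial_training_values concrete, here is a worked example with hypothetical numbers: 1000 batches per epoch, gradient accumulation of 4, 3 epochs requested, and max_steps left at -1 so training is epoch-based:

import math

len_dataloader = 1000
gradient_accumulation_steps = 4
requested_epochs = 3
max_steps = -1  # negative, so the step budget is derived from the epoch count

epoch_based = max_steps < 0                                                          # True
num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)   # 250
if epoch_based:
    max_steps = math.ceil(requested_epochs * num_update_steps_per_epoch)             # 750
num_train_epochs = math.ceil(requested_epochs)                                       # 3

print(num_update_steps_per_epoch, max_steps, num_train_epochs)  # 250 750 3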

View file

@@ -18,13 +18,14 @@ Callbacks to use with the Trainer class and customize the training loop.
import dataclasses
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
from .trainer_utils import HPSearchBackend, IntervalStrategy, SaveStrategy, has_length
from .training_args import TrainingArguments
from .utils import logging
@@ -150,6 +151,43 @@ class TrainerState:
text = f.read()
return cls(**json.loads(text))
def compute_steps(self, args, max_steps):
"""
Calculates and stores the absolute values for the logging,
eval, and save steps, converting any value below 1 from a
proportion of `max_steps` into an absolute step count.
"""
for step_kind in ("logging", "eval", "save"):
num_steps = getattr(args, f"{step_kind}_steps")
if num_steps is not None:
if num_steps < 1:
num_steps = math.ceil(max_steps * num_steps)
setattr(self, f"{step_kind}_steps", num_steps)
def init_training_references(self, trainer, train_dataloader, max_steps, num_train_epochs, trial):
"""
Stores the initial training references needed in `self`
"""
for attr in ("model", "optimizer", "lr_scheduler"):
setattr(self, attr, getattr(trainer, attr))
self.train_dataloader = train_dataloader
if trainer.hp_name is not None and trainer._trial is not None:
# use trainer._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing the trial
# parameter to `train` when using DDP.
self.trial_name = trainer.hp_name(trainer._trial)
self.trial_params = None
if trial is not None:
from transformers.integrations import hp_params
assignments = trial.assignments if trainer.hp_search_backend == HPSearchBackend.SIGOPT else trial
self.trial_params = hp_params(assignments)
self.max_steps = max_steps
self.num_train_epochs = num_train_epochs
self.is_local_process_zero = trainer.is_local_process_zero()
self.is_world_process_zero = trainer.is_world_process_zero()
class ExportableState:
"""

View file

@@ -1390,3 +1390,17 @@ class LayerWiseDummyScheduler(LRScheduler):
def _get_closed_form_lr(self):
return self.base_lrs
def set_rng_state_for_device(device_name, device_module, checkpoint_rng_state, is_distributed):
"""Helper to set RNG state for a specific device type (CUDA, NPU, MLU, MUSA)"""
device_state_key = device_name.lower()
err_template = "Didn't manage to set back the RNG states of the {backend} because of the following error:\n {exception}\nThis won't yield the same results as if the training had not been interrupted."
try:
if is_distributed:
device_module.random.set_rng_state_all(checkpoint_rng_state[device_state_key])
else:
device_module.random.set_rng_state(checkpoint_rng_state[device_state_key])
except Exception as e:
# Log error if setting RNG state fails
logger.error(err_template.format(backend=device_name, exception=e))
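A minimal usage sketch of the new helper, assuming a transformers install that includes this refactor and a checkpoint_rng_state dict shaped like the one the Trainer saves (lowercase device keys such as "cuda"):

import torch
from transformers.trainer_pt_utils import set_rng_state_for_device

# Save side (sketch): store the current CUDA RNG state under a lowercase key.
checkpoint_rng_state = {}
if torch.cuda.is_available():
    checkpoint_rng_state["cuda"] = torch.cuda.random.get_rng_state()

# Restore side (sketch): one call per backend replaces the old per-device try/except block.
if torch.cuda.is_available():
    set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed=False)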