Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Trainer Refactor: Part 1 (#35567)
* start
* So far: 30%
* Small fix
* Continuing update
* Continuing
* Forgot to check if not None
* Continuing refactor
* Fix if else
* Fix ref
* Should make tests pass
* Keep grad norm same
* Document
* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Err instead of info for logging RNG state error
* Seperate out to func

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in: parent 23d782ead2, commit 5d257111c1
3 changed files with 147 additions and 132 deletions

src/transformers/trainer.py
@@ -41,7 +41,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Ty
# isort: off
from .integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)

# isort: on
@@ -108,6 +107,7 @@ from .trainer_pt_utils import (
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
    remove_dummy_checkpoint,
    set_rng_state_for_device,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
@@ -2219,46 +2219,25 @@ class Trainer:
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
        (
            num_train_epochs,
            num_update_steps_per_epoch,
            num_examples,
            num_train_samples,
            epoch_based,
            len_dataloader,
            max_steps,
        ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)

        len_dataloader = None
        num_train_tokens = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
                if args.include_tokens_per_second:
                    num_train_tokens = (
                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                    )
        if self.args.include_tokens_per_second:
            num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
            # If going by epochs, multiply tokens linearly
            if len_dataloader is not None and epoch_based:
                num_train_tokens *= args.num_train_epochs
            # Otherwise since its steps, we just multiply by grad accum
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
                if args.include_tokens_per_second:
                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )
                num_train_tokens *= args.gradient_accumulation_steps

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
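To see what the new token accounting above computes, here is a minimal standalone sketch with made-up numbers; the scale_train_tokens helper is illustrative and not part of the Trainer. Epoch-based runs scale the per-pass token count by the number of epochs, step-based runs scale it by the gradient-accumulation factor.

# Illustrative only: mirrors the token scaling above with made-up numbers.
def scale_train_tokens(tokens_per_pass, epoch_based, num_train_epochs, gradient_accumulation_steps):
    if epoch_based:
        # Epoch-based: the dataset is seen once per epoch, so tokens grow linearly with epochs.
        return tokens_per_pass * num_train_epochs
    # Step-based: tokens are counted per optimizer step, so multiply by the accumulation factor.
    return tokens_per_pass * gradient_accumulation_steps

print(scale_train_tokens(1_000_000, True, 3, 4))   # 3000000
print(scale_train_tokens(1_000_000, False, 3, 4))  # 4000000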
@@ -2293,21 +2272,7 @@ class Trainer:
        self.state.train_batch_size = self._train_batch_size

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.state.compute_steps(args, max_steps)

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
@@ -2420,25 +2385,7 @@ class Trainer:
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
        self.state.init_training_references(self, train_dataloader, max_steps, num_train_epochs, trial)

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
@@ -2495,10 +2442,7 @@ class Trainer:
                    step += 1
                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                    # Since we perform prefetching, we need to manually set sync_gradients
                    if not do_sync_step:
                        self.accelerator.gradient_state._set_sync_gradients(False)
                    else:
                        self.accelerator.gradient_state._set_sync_gradients(True)
                    self.accelerator.gradient_state._set_sync_gradients(do_sync_step)

                    if self.args.include_num_input_tokens_seen:
                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
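The if/else collapses into a single _set_sync_gradients(do_sync_step) call; a quick, self-contained sketch of when do_sync_step is true, using hypothetical values for gradient_accumulation_steps and steps_in_epoch:

# Illustrative only: which micro-batch indices end an accumulation window and sync gradients.
gradient_accumulation_steps = 4
steps_in_epoch = 10  # hypothetical; the last window is shorter than the accumulation factor

for step in range(steps_in_epoch):
    do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
    print(step, do_sync_step)
# Syncs at steps 3 and 7 (full windows) and at step 9 (end of the epoch).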
@@ -2565,8 +2509,6 @@ class Trainer:

                        # Gradient clipping
                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
                            # deepspeed does its own clipping

                            if is_sagemaker_mp_enabled() and args.fp16:
                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                            elif self.use_apex:
@@ -2598,8 +2540,7 @@ class Trainer:

                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                        if optimizer_was_run:
                        if not self.accelerator.optimizer_step_was_skipped:
                            # Delay optimizer scheduling until metrics are generated
                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                self.lr_scheduler.step()
@@ -3119,52 +3060,19 @@ class Trainer:
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_xla_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
        if torch.cuda.is_available():
            set_rng_state_for_device("GPU", torch.cuda, checkpoint_rng_state, is_distributed)
        if is_torch_npu_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
            else:
                try:
                    torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
            set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
        if is_torch_mlu_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
            else:
                try:
                    torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
            set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)

        if is_torch_musa_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
            else:
                try:
                    torch.musa.set_rng_state(checkpoint_rng_state["musa"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
            set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)

    def _determine_best_metric(self, metrics, trial):
        """
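For context on the dict the restore code above consumes, here is a hedged sketch of how such a checkpoint_rng_state mapping could be assembled on the save side; the keys shown ("python", "numpy", "cpu", "cuda") are the ones the removed branches read, and this is not the Trainer's actual _save_rng_state implementation.

import random

import numpy as np
import torch

# Illustrative only: one way to capture the RNG states that the restore logic above expects.
checkpoint_rng_state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
    # One state per visible GPU; a distributed restore uses set_rng_state_all with this list.
    checkpoint_rng_state["cuda"] = torch.cuda.random.get_rng_state_all()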
@@ -5050,11 +4958,10 @@ class Trainer:
        accelerator_config = self.args.accelerator_config.to_dict()

        if is_accelerate_available("0.28.0"):
            # Extract dataloader config params from accelerator config
            dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
            dataloader_config = DataLoaderConfiguration(
                split_batches=accelerator_config.pop("split_batches"),
                dispatch_batches=accelerator_config.pop("dispatch_batches"),
                even_batches=accelerator_config.pop("even_batches"),
                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
                **{param: accelerator_config.pop(param) for param in dataloader_params}
            )
            if is_accelerate_available("1.1.0"):
                dataloader_config.data_seed = self.args.data_seed
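The four explicit keyword arguments become a single dict comprehension that pops each dataloader key out of the accelerator config; a minimal sketch of the pattern on a plain dict (the keys mirror the dataloader_params list above, the leftover key is illustrative):

# Illustrative only: pop a known subset of keys from a config dict and pass them on as kwargs.
accelerator_config = {
    "split_batches": False,
    "dispatch_batches": None,
    "even_batches": True,
    "use_seedable_sampler": True,
    "non_blocking": False,  # stays behind for whatever consumes the rest of the config
}
dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]

dataloader_kwargs = {param: accelerator_config.pop(param) for param in dataloader_params}
print(sorted(dataloader_kwargs))  # the four extracted keys
print(accelerator_config)         # {'non_blocking': False}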
@@ -5099,12 +5006,8 @@ class Trainer:
        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                "limit_all_gathers", fsdp_plugin.limit_all_gathers
            )
            fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
                "activation_checkpointing", fsdp_plugin.activation_checkpointing
            )
            for param in ["limit_all_gathers", "activation_checkpointing"]:
                setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
                raise ValueError(
                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
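The two repeated fsdp_config.get(...) assignments become a loop over parameter names; the same pattern on a dummy object, where the current attribute value is the fallback whenever the config omits a key (class and values are made up):

# Illustrative only: override attributes from a config dict, keeping the existing value as fallback.
class DummyFsdpPlugin:
    limit_all_gathers = True
    activation_checkpointing = False

plugin = DummyFsdpPlugin()
fsdp_config = {"activation_checkpointing": True}  # "limit_all_gathers" deliberately absent

for param in ["limit_all_gathers", "activation_checkpointing"]:
    setattr(plugin, param, fsdp_config.get(param, getattr(plugin, param)))

print(plugin.limit_all_gathers, plugin.activation_checkpointing)  # True True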
@@ -5186,3 +5089,63 @@ class Trainer:
            num_items_in_batch = num_items_in_batch.item()

        return batch_samples, num_items_in_batch

    def set_initial_training_values(
        self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int
    ):
        """
        Calculates and returns the following values:
        - `num_train_epochs`
        - `num_update_steps_per_epoch`
        - `num_examples`
        - `num_train_samples`
        - `epoch_based`
        - `len_dataloader`
        - `max_steps`
        """
        # Case 1: we rely on `args.max_steps` first
        max_steps = args.max_steps
        # If max_steps is negative, we use the number of epochs to determine the number of total steps later
        epoch_based = max_steps < 0
        len_dataloader = len(dataloader) if has_length(dataloader) else None

        # Case 2: We have a dataloader length and can extrapolate
        if len_dataloader is not None:
            num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
            # Case 3: We have a length but are using epochs, we can extrapolate the number of steps
            if epoch_based:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)

        # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples`
        if len_dataloader:
            num_examples = self.num_examples(dataloader)
            if args.max_steps > 0:
                num_train_epochs = max_steps // num_update_steps_per_epoch + int(
                    max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = max_steps * total_train_batch_size
            else:
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )
        return (
            num_train_epochs,
            num_update_steps_per_epoch,
            num_examples,
            num_train_samples,
            epoch_based,
            len_dataloader,
            max_steps,
        )
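A quick numeric walk-through of the epoch-based branch above, with made-up values (100-batch dataloader, accumulation of 4, 3 epochs, hypothetical dataset of 800 examples); this only re-runs the arithmetic, it does not call the method:

import math

# Illustrative only: the epoch-based case (max_steps < 0, dataloader has a length).
len_dataloader = 100
gradient_accumulation_steps = 4
num_train_epochs = 3
num_examples = 800  # hypothetical dataset size

num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 25
max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)                # 75
num_train_samples = num_examples * num_train_epochs                                 # 2400
print(num_update_steps_per_epoch, max_steps, num_train_samples)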
src/transformers/trainer_callback.py
@@ -18,13 +18,14 @@ Callbacks to use with the Trainer class and customize the training loop.

import dataclasses
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm

from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
from .trainer_utils import HPSearchBackend, IntervalStrategy, SaveStrategy, has_length
from .training_args import TrainingArguments
from .utils import logging

@@ -150,6 +151,43 @@ class TrainerState:
            text = f.read()
        return cls(**json.loads(text))

    def compute_steps(self, args, max_steps):
        """
        Calculates and stores the absolute value for logging,
        eval, and save steps based on if it was a proportion
        or not.
        """
        for step_kind in ("logging", "eval", "save"):
            num_steps = getattr(args, f"{step_kind}_steps")
            if num_steps is not None:
                if num_steps < 1:
                    num_steps = math.ceil(max_steps * num_steps)
                setattr(self, f"{step_kind}_steps", num_steps)

    def init_training_references(self, trainer, train_dataloader, max_steps, num_train_epochs, trial):
        """
        Stores the initial training references needed in `self`
        """
        for attr in ("model", "optimizer", "lr_scheduler"):
            setattr(self, attr, getattr(trainer, attr))

        self.train_dataloader = train_dataloader
        if trainer.hp_name is not None and trainer._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.trial_name = trainer.hp_name(trainer._trial)
        self.trial_params = None
        if trial is not None:
            from transformers.integrations import hp_params

            assignments = trial.assignments if trainer.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.trial_params = hp_params(assignments)

        self.max_steps = max_steps
        self.num_train_epochs = num_train_epochs
        self.is_local_process_zero = trainer.is_local_process_zero()
        self.is_world_process_zero = trainer.is_world_process_zero()


class ExportableState:
    """
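A hedged usage sketch of the compute_steps logic above: fractional step arguments are resolved against max_steps, integers are stored unchanged. The SimpleNamespace stand-ins are illustrative, not the real TrainingArguments or TrainerState.

import math
from types import SimpleNamespace

# Illustrative only: resolve ratio-style step arguments against max_steps, as compute_steps does.
args = SimpleNamespace(logging_steps=0.1, eval_steps=500, save_steps=None)
state = SimpleNamespace(logging_steps=None, eval_steps=None, save_steps=None)
max_steps = 2000

for step_kind in ("logging", "eval", "save"):
    num_steps = getattr(args, f"{step_kind}_steps")
    if num_steps is not None:
        if num_steps < 1:
            num_steps = math.ceil(max_steps * num_steps)
        setattr(state, f"{step_kind}_steps", num_steps)

print(state.logging_steps, state.eval_steps, state.save_steps)  # 200 500 None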
src/transformers/trainer_pt_utils.py
@@ -1390,3 +1390,17 @@ class LayerWiseDummyScheduler(LRScheduler):

    def _get_closed_form_lr(self):
        return self.base_lrs


def set_rng_state_for_device(device_name, device_module, checkpoint_rng_state, is_distributed):
    """Helper to set RNG state for a specific device type (CUDA, NPU, MLU, MUSA)"""
    device_state_key = device_name.lower()
    err_template = "Didn't manage to set back the RNG states of the {backend} because of the following error:\n {exception}\nThis won't yield the same results as if the training had not been interrupted."
    try:
        if is_distributed:
            device_module.random.set_rng_state_all(checkpoint_rng_state[device_state_key])
        else:
            device_module.random.set_rng_state(checkpoint_rng_state[device_state_key])
    except Exception as e:
        # Log error if setting RNG state fails
        logger.error(err_template.format(backend=device_name, exception=e))
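Assuming the set_rng_state_for_device definition above is in scope, a small sketch of how it can be exercised with a stub device module; the stub and the stored state are made up, while the call shape matches the Trainer hunks earlier.

from types import SimpleNamespace

# Illustrative only: a stub exposing the `random.set_rng_state(_all)` API the helper expects.
calls = []
stub_device_module = SimpleNamespace(
    random=SimpleNamespace(
        set_rng_state_all=lambda states: calls.append(("all", states)),
        set_rng_state=lambda state: calls.append(("one", state)),
    )
)
checkpoint_rng_state = {"gpu": [0, 1, 2]}  # key is the lower-cased device name

set_rng_state_for_device("GPU", stub_device_module, checkpoint_rng_state, is_distributed=True)
set_rng_state_for_device("GPU", stub_device_module, checkpoint_rng_state, is_distributed=False)
print(calls)  # [('all', [0, 1, 2]), ('one', [0, 1, 2])]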