mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Updated checkpoint support for Sagemaker Model Parallel (#17219)
* adding partial checkpoint support for optimizer state * formatted trainer.py * Refactoring based on comments * reformatting * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
71d18d0831
commit
518dd1277e
1 changed files with 59 additions and 30 deletions
|
|
@ -18,6 +18,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
|||
|
||||
import contextlib
|
||||
import functools
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
|
|
@ -1305,7 +1306,7 @@ class Trainer:
|
|||
if resume_from_checkpoint is None:
|
||||
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
|
||||
|
||||
if resume_from_checkpoint is not None:
|
||||
if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
|
||||
self._load_from_checkpoint(resume_from_checkpoint)
|
||||
|
||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
|
|
@ -1406,6 +1407,9 @@ class Trainer:
|
|||
|
||||
model = self._wrap_model(self.model_wrapped)
|
||||
|
||||
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
|
||||
self._load_from_checkpoint(resume_from_checkpoint, model)
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
|
|
@ -1671,6 +1675,8 @@ class Trainer:
|
|||
xm.rendezvous("load_best_model_at_end")
|
||||
elif args.local_rank != -1:
|
||||
dist.barrier()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
smp.barrier()
|
||||
|
||||
self._load_best_model()
|
||||
|
||||
|
|
@ -1693,7 +1699,12 @@ class Trainer:
|
|||
|
||||
return TrainOutput(self.state.global_step, train_loss, metrics)
|
||||
|
||||
def _load_from_checkpoint(self, resume_from_checkpoint):
|
||||
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
|
||||
|
||||
if model is None:
|
||||
model = self.model
|
||||
strict_load = is_sagemaker_mp_enabled()
|
||||
|
||||
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
|
|
@ -1718,20 +1729,22 @@ class Trainer:
|
|||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
load_result = self.model.load_state_dict(state_dict, strict=False)
|
||||
self._issue_warnings_after_load(load_result)
|
||||
|
||||
load_result = model.load_state_dict(state_dict, strict=strict_load)
|
||||
if not strict_load:
|
||||
self._issue_warnings_after_load(load_result)
|
||||
# release memory
|
||||
del state_dict
|
||||
else:
|
||||
# We load the sharded checkpoint
|
||||
load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
|
||||
self._issue_warnings_after_load(load_result)
|
||||
load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
|
||||
if not strict_load:
|
||||
self._issue_warnings_after_load(load_result)
|
||||
|
||||
def _load_best_model(self):
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
|
||||
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
|
||||
strict_load = is_sagemaker_mp_enabled()
|
||||
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if os.path.exists(best_model_path):
|
||||
if self.deepspeed:
|
||||
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
|
||||
|
|
@ -1748,12 +1761,13 @@ class Trainer:
|
|||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(best_model_path, map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
load_result = self.model.load_state_dict(state_dict, strict=False)
|
||||
self._issue_warnings_after_load(load_result)
|
||||
load_result = model.load_state_dict(state_dict, strict=strict_load)
|
||||
if not strict_load:
|
||||
self._issue_warnings_after_load(load_result)
|
||||
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
|
||||
# Best model is a sharded checkpoint
|
||||
load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
|
||||
self._issue_warnings_after_load(load_result)
|
||||
load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
|
||||
if not strict_load:
|
||||
self._issue_warnings_after_load(load_result)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
|
||||
|
|
@ -1891,17 +1905,21 @@ class Trainer:
|
|||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
if smp.rdp_rank() == 0:
|
||||
# Consolidate the state dict on all processed of rdp_rank 0
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
# Save it and the scheduler on the main process
|
||||
if self.args.should_save:
|
||||
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
|
||||
smp.barrier()
|
||||
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
|
||||
smp.save(
|
||||
opt_state_dict,
|
||||
os.path.join(output_dir, OPTIMIZER_NAME),
|
||||
partial=True,
|
||||
v3=smp.state.cfg.shard_optimizer_state,
|
||||
)
|
||||
if self.args.should_save:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.deepspeed:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
|
|
@ -1950,6 +1968,7 @@ class Trainer:
|
|||
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
|
||||
# not yet exist.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||
if local_rank == -1:
|
||||
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||
|
|
@ -1972,9 +1991,12 @@ class Trainer:
|
|||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
return
|
||||
|
||||
if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
|
||||
os.path.join(checkpoint, SCHEDULER_NAME)
|
||||
):
|
||||
checkpoint_file_exists = (
|
||||
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
|
||||
if is_sagemaker_mp_enabled()
|
||||
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
|
||||
)
|
||||
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
|
|
@ -1990,9 +2012,16 @@ class Trainer:
|
|||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||
else:
|
||||
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
)
|
||||
if is_sagemaker_mp_enabled():
|
||||
|
||||
def opt_load_hook(mod, opt):
|
||||
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
|
||||
|
||||
self.model_wrapped.register_post_step_hook(opt_load_hook)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
|
|
|||
Loading…
Reference in a new issue