[DeepSpeed] improve checkpoint loading code plus tests (#10760)

* deepspeed checkpoint loading code plus tests
* style
* style

This commit is contained in:
parent 01c7fb04be
commit cd8c93f701

4 changed files with 169 additions and 34 deletions
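What the change buys users: `Trainer.train(resume_from_checkpoint=...)` now restores DeepSpeed runs too — model weights plus the DeepSpeed-managed optimizer and lr-scheduler states. A minimal sketch of the intended flow before walking through the hunks; the model, dataset, and config path are hypothetical placeholders, not part of this diff:

```python
# Sketch only: model, train_dataset and ds_config.json are placeholders.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output_dir",
    save_steps=5,
    deepspeed="ds_config.json",  # any valid DeepSpeed config file
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # writes output_dir/checkpoint-5/, each with a global_step5/ sub-dir

# A later run can pick up the weights *and* DeepSpeed optimizer/scheduler states:
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train(resume_from_checkpoint="output_dir/checkpoint-5")
```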
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 import io
 import json
 import os
```
```diff
@@ -19,6 +20,8 @@ import sys
 import unittest
+from copy import deepcopy
 
 from transformers import TrainingArguments
+from transformers.file_utils import WEIGHTS_NAME
 from transformers.integrations import is_deepspeed_available
 from transformers.testing_utils import (
     CaptureStd,
```
```diff
@@ -35,7 +38,7 @@ from transformers.trainer_utils import set_seed
 
 bindir = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(f"{bindir}/../../../tests")
-from test_trainer import get_regression_trainer  # noqa
+from test_trainer import TrainerIntegrationCommon, get_regression_trainer  # noqa
 
 
 set_seed(42)
```
```diff
@@ -60,11 +63,21 @@ def require_deepspeed(test_case):
 
 @require_deepspeed
 @require_torch_gpu
-class TrainerIntegrationDeepSpeed(TestCasePlus):
-    """ This class is for testing directly via get_regression_trainer """
+class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
+    """
+
+    This class is for testing directly via get_regression_trainer
+
+    It mixes in `TrainerIntegrationCommon`, which already provides a number of helper validation methods that we can re-use here.
+    """
+
     def setUp(self):
         super().setUp()
+
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
 
         self.dist_env_1_gpu = dict(
             MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
         )
```
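`dist_env_1_gpu` exists so the tests can feed DeepSpeed a fake one-process distributed environment: the variables get overlaid onto `os.environ` via `mockenv_context` (presumably imported with the other `transformers.testing_utils` helpers above). A minimal sketch of what such a context manager amounts to, assuming it behaves like `unittest.mock.patch.dict`:

```python
import os
from contextlib import contextmanager
from unittest import mock


@contextmanager
def mockenv_context(**kwargs):
    # Overlay the given variables onto os.environ; everything is restored on exit.
    with mock.patch.dict(os.environ, kwargs):
        yield


# Mirrors the tests below: a single-process "distributed" setup on one GPU.
env = dict(MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1")
with mockenv_context(**env):
    assert os.environ["WORLD_SIZE"] == "1"
```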
```diff
@@ -222,6 +235,101 @@ class TrainerIntegrationDeepSpeed(TestCasePlus):
         # see the note above how to get identical loss on a small bs
         self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
 
+    def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, is_pretrained=True):
+        # adapted from TrainerIntegrationCommon.check_saved_checkpoints
+
+        file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
+        ds_file_list = ["mp_rank_00_model_states.pt", "zero_pp_rank_0_mp_rank_00optim_states.pt"]
+
+        for step in range(freq, total, freq):
+            checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
+            self.assertTrue(os.path.isdir(checkpoint))
+
+            # common files
+            for filename in file_list:
+                self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
+
+            # ds files
+            ds_path = os.path.join(checkpoint, f"global_step{step}")
+            for filename in ds_file_list:
+                self.assertTrue(os.path.isfile(os.path.join(ds_path, filename)))
+
+    def test_save_checkpoints(self):
+        # adapted from TrainerIntegrationTest.test_save_checkpoints
+
+        output_dir = self.get_auto_remove_tmp_dir()
+        ds_config_dict = deepcopy(self.ds_config_dict)
+        ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
+        freq = 5
+
+        # save checkpoints
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(
+                output_dir=output_dir,
+                save_steps=freq,
+                deepspeed=ds_config_dict,
+            )
+            trainer.train()
+
+        total = int(self.n_epochs * 64 / self.batch_size)
+        self.check_saved_checkpoints_deepspeed(output_dir, freq, total)
+
+    def test_can_resume_training(self):
+        # adapted from TrainerIntegrationTest.test_can_resume_training
+
+        output_dir = self.get_auto_remove_tmp_dir()
+        ds_config_dict = deepcopy(self.ds_config_dict)
+        ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
+        kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(output_dir, "checkpoint-5")
+
+            # Reinitialize trainer
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(output_dir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            trainer = get_regression_trainer(**kwargs)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
+            # Now check failures
+
+            # 1. fail to find a bogus checkpoint
+            trainer = get_regression_trainer(**kwargs)
+            with self.assertRaises(Exception) as context:
+                trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
+            self.assertTrue("failed to resume from checkpoint" in str(context.exception))
+
+            # 2. fail to find any checkpoint - due to a fresh output_dir
+            output_dir2 = self.get_auto_remove_tmp_dir()
+            trainer = get_regression_trainer(output_dir=output_dir2, deepspeed=ds_config_dict)
+            with self.assertRaises(Exception) as context:
+                trainer.train(resume_from_checkpoint=True)
+            self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
+
 
 @slow
 @require_deepspeed
```
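That closes the first file's hunks (the DeepSpeed test module, judging by the imports). For reference, the on-disk layout that `check_saved_checkpoints_deepspeed` asserts, collected in one place — a sketch assuming `WEIGHTS_NAME` resolves to `pytorch_model.bin`, as it did at the time:

```python
import os


def expected_checkpoint_files(output_dir: str, step: int) -> list:
    # Trainer-level files live directly in checkpoint-{step}/ ...
    checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
    common = ["pytorch_model.bin", "training_args.bin", "trainer_state.json", "config.json"]
    # ... while DeepSpeed's own state files live one level down, in global_step{step}/
    ds = ["mp_rank_00_model_states.pt", "zero_pp_rank_0_mp_rank_00optim_states.pt"]
    return [os.path.join(checkpoint, f) for f in common] + [
        os.path.join(checkpoint, f"global_step{step}", f) for f in ds
    ]


print(expected_checkpoint_files("output_dir", 5))
```

The next hunks touch the integrations module, where `init_deepspeed` lives.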
```diff
@@ -21,6 +21,7 @@ import numbers
 import os
 import re
 import tempfile
+from copy import deepcopy
 from pathlib import Path
 
 from .utils import logging
```
```diff
@@ -268,15 +269,19 @@ def rewrite_logs(d):
     return new_d
 
 
-def init_deepspeed(trainer, num_training_steps):
+def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
     """
-    Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration
+    Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
+
+    If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.
 
     Args:
         trainer: Trainer object
         num_training_steps: per single gpu
+        resume_from_checkpoint: path to a checkpoint to resume from after the normal DeepSpeedEngine load
 
     Returns: model, optimizer, lr_scheduler
 
     """
     import deepspeed
```
```diff
@@ -287,7 +292,9 @@ def init_deepspeed(trainer, num_training_steps):
     model = trainer.model
 
     if isinstance(args.deepspeed, dict):
-        config = args.deepspeed
+        # Don't modify the user's data should they want to reuse it (e.g. in tests), because once
+        # modified it will not be accepted here again, since some config params must not be set by users
+        config = deepcopy(args.deepspeed)
     elif isinstance(args.deepspeed, str):
         with io.open(ds_config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
```
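The `deepcopy` above is what lets the new tests reuse a single `ds_config_dict` across runs: `init_deepspeed` fills in config values that must *not* already be set on the next call, so mutating the caller's dict would poison it. A toy illustration of the aliasing pitfall, with a made-up key:

```python
from copy import deepcopy


def init_bad(config):
    config["train_batch_size"] = 8  # mutates the caller's dict in place
    return config


def init_good(config):
    config = deepcopy(config)       # private copy; the caller's dict stays clean
    config["train_batch_size"] = 8
    return config


user_config = {"fp16": {"enabled": True}}    # made-up minimal config
init_bad(user_config)
assert "train_batch_size" in user_config     # polluted: a second call would now reject it

user_config = {"fp16": {"enabled": True}}
init_good(user_config)
assert "train_batch_size" not in user_config
```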
```diff
@@ -442,6 +449,15 @@ def init_deepspeed(trainer, num_training_steps):
         lr_scheduler=lr_scheduler,
     )
 
+    if resume_from_checkpoint is not None:  # and os.path.isdir(resume_from_checkpoint):
+        logger.info(f"Attempting to resume from {resume_from_checkpoint}")
+        # this magically updates self.optimizer and self.lr_scheduler
+        load_path, _ = model.load_checkpoint(
+            resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
+        )
+        if load_path is None:
+            raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
+
     return model, optimizer, lr_scheduler
```
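Worth noting about the hunk above: `DeepSpeedEngine.load_checkpoint` signals failure by returning `None` as the `load_path` element of its `(load_path, client_state)` result rather than raising, which is why the explicit `ValueError` is needed. The resume step in isolation, as a hedged sketch (`engine` stands for an already-built `DeepSpeedEngine`):

```python
def resume_deepspeed(engine, checkpoint_dir):
    # Restores module, optimizer and lr-scheduler states in place; a None
    # load_path means "no checkpoint found" rather than an exception.
    load_path, _ = engine.load_checkpoint(
        checkpoint_dir, load_optimizer_states=True, load_lr_scheduler_states=True
    )
    if load_path is None:
        raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_dir}")
    return load_path
```

This is also exactly the error string the failure case in `test_can_resume_training` greps for. The remaining hunks move into the Trainer itself.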
```diff
@@ -878,7 +878,11 @@ class Trainer:
 
         if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
             logger.info(f"Loading model from {resume_from_checkpoint}).")
-            if isinstance(self.model, PreTrainedModel):
+
+            if self.deepspeed:
+                # will be resumed in init_deepspeed
+                pass
+            elif isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(resume_from_checkpoint)
                 model_reloaded = True
             else:
```
```diff
@@ -920,7 +924,9 @@ class Trainer:
 
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if self.args.deepspeed:
-            model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
+            model, optimizer, lr_scheduler = init_deepspeed(
+                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
+            )
             self.model = model.module
             self.model_wrapped = model
             self.deepspeed = model  # DeepSpeedEngine object
```
```diff
@@ -1294,6 +1300,10 @@ class Trainer:
         if checkpoint is None:
             return
 
+        if self.deepspeed:
+            # deepspeed loads optimizer/lr_scheduler together with the model in init_deepspeed
+            return
+
         if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
             os.path.join(checkpoint, "scheduler.pt")
         ):
```
```diff
@@ -1318,10 +1328,6 @@ class Trainer:
             self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
             reissue_pt_warnings(caught_warnings)
 
-        if self.deepspeed:
-            # Not sure how to check if there is a saved deepspeed checkpoint, but since it just returns None if it fails to find a deepspeed checkpoint, this is sort of a check-n-load function
-            self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)
-
     def hyperparameter_search(
         self,
         hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
```
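That closes out the trainer changes. The division of labor they set up, reduced to a runnable toy (the class and strings are illustrative only, not the real `Trainer` API):

```python
class ToyTrainer:
    # Toy model of the dispatch: who restores what when resuming.
    def __init__(self, deepspeed=False):
        self.deepspeed = deepspeed
        self.restored_by = None

    def train(self, resume_from_checkpoint=None):
        if resume_from_checkpoint is None:
            self.restored_by = "nobody (fresh run)"
        elif self.deepspeed:
            # weights + optimizer + scheduler all come back via init_deepspeed -> load_checkpoint
            self.restored_by = "init_deepspeed"
        else:
            # plain path: weights here, optimizer/scheduler in _load_optimizer_and_scheduler
            self.restored_by = "from_pretrained + _load_optimizer_and_scheduler"


t = ToyTrainer(deepspeed=True)
t.train(resume_from_checkpoint="checkpoint-5")
assert t.restored_by == "init_deepspeed"
```

The last file updates the trainer tests to match.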
```diff
@@ -24,6 +24,7 @@ import numpy as np
 from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
+from transformers.file_utils import WEIGHTS_NAME
 from transformers.testing_utils import (
     TestCasePlus,
     get_tests_dir,
     require_datasets,
     require_optuna,
```
```diff
@@ -235,28 +236,7 @@ if is_torch_available():
         )
 
 
-@require_torch
-@require_sentencepiece
-@require_tokenizers
-class TrainerIntegrationTest(unittest.TestCase):
-    def setUp(self):
-        args = TrainingArguments(".")
-        self.n_epochs = args.num_train_epochs
-        self.batch_size = args.train_batch_size
-        trainer = get_regression_trainer(learning_rate=0.1)
-        trainer.train()
-        self.default_trained_model = (trainer.model.a, trainer.model.b)
-
-        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
-        trainer.train()
-        self.alternate_trained_model = (trainer.model.a, trainer.model.b)
-
-    def check_trained_model(self, model, alternate_seed=False):
-        # Checks a training seeded with learning_rate = 0.1
-        (a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
-        self.assertTrue(torch.allclose(model.a, a))
-        self.assertTrue(torch.allclose(model.b, b))
-
+class TrainerIntegrationCommon:
     def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
         file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
        if is_pretrained:
```
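The point of the extraction above: `TrainerIntegrationCommon` deliberately has no `unittest.TestCase` base, so its `check_*` helpers are never collected as tests on their own, yet any test class (including the DeepSpeed one in the first file) can mix them in alongside a real test base. The pattern in miniature:

```python
import unittest


class CheckersMixin:
    # Assertion helpers only; with no TestCase base, nothing here runs standalone.
    def check_positive(self, x):
        self.assertTrue(x > 0)


class SomeTest(unittest.TestCase, CheckersMixin):
    def test_value(self):
        self.check_positive(42)


if __name__ == "__main__":
    unittest.main()
```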
```diff
@@ -306,6 +286,30 @@ class TrainerIntegrationTest(unittest.TestCase):
         _ = log1.pop("train_samples_per_second", None)
         self.assertEqual(log, log1)
 
 
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
+    def setUp(self):
+        super().setUp()
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+        trainer = get_regression_trainer(learning_rate=0.1)
+        trainer.train()
+        self.default_trained_model = (trainer.model.a, trainer.model.b)
+
+        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
+        trainer.train()
+        self.alternate_trained_model = (trainer.model.a, trainer.model.b)
+
+    def check_trained_model(self, model, alternate_seed=False):
+        # Checks a training seeded with learning_rate = 0.1
+        (a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
+        self.assertTrue(torch.allclose(model.a, a))
+        self.assertTrue(torch.allclose(model.b, b))
+
     def test_trainer_works_with_dict(self):
         # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
         # anything.
```
```diff
@@ -607,6 +611,7 @@ class TrainerIntegrationTest(unittest.TestCase):
             # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
             # won't be the same since the training dataloader is shuffled).
             return
+
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
             trainer.train()
```