Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00.
[deepspeed] saving checkpoint fallback when fp16 weights aren't saved (#14948)
* [deepspeed] saving checkpoint fallback when fp16 weights aren't saved
* Bump required deepspeed version to match usage when saving checkpoints
* Update version

Co-authored-by: Mihai Balint <balint.mihai@gmail.com>
This commit is contained in:
parent
d25e25ee2b
commit
297602c7f4
3 changed files with 8 additions and 3 deletions
Changed files include setup.py (2 changed lines).
```diff
@@ -98,7 +98,7 @@ _deps = [
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.5.7",
+    "deepspeed>=0.5.9",
     "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
```
```diff
@@ -8,7 +8,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.5.7",
+    "deepspeed": "deepspeed>=0.5.9",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
```
```diff
@@ -2054,7 +2054,12 @@ class Trainer:
             # now save the real model if stage3_gather_fp16_weights_on_model_save=True
             # if false it will not be saved.
             # This must be called on all ranks
-            self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
+            if not self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME):
+                logger.warning(
+                    "deepspeed.save_fp16_model didn't save the model, since stage3_gather_fp16_weights_on_model_save=false. "
+                    "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
+                )
+                self.deepspeed.save_checkpoint(output_dir)

         elif self.args.should_save:
             self._save(output_dir)
```
Loading…
Reference in a new issue