Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)
* fix resuming from ckpt when using FSDP with FULL_STATE_DICT
* update tests
* fix tests
This commit is contained in:
parent ebfdb9ca62
commit 238d2e3c44
2 changed files with 23 additions and 4 deletions
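Why the predicate change matters: as the diff's comments note, a `SHARDED_STATE_DICT` checkpoint contains a subdirectory whose name includes `FSDP_MODEL_NAME`, while `FULL_STATE_DICT` consolidates the weights into a single `{FSDP_MODEL_NAME}.bin` file. The old directory-only scan therefore never classified a full-state-dict checkpoint as FSDP, and resuming fell through to the non-FSDP load path. A minimal before/after sketch, assuming `FSDP_MODEL_NAME` is `"pytorch_model_fsdp"` as defined in `transformers.trainer` (`old_check` and `new_check` are illustrative names, not from the PR):

import os
import tempfile

# Assumed value; in transformers this constant is defined in transformers.trainer.
FSDP_MODEL_NAME = "pytorch_model_fsdp"


def old_check(ckpt: str) -> bool:
    # Pre-fix predicate: only recognizes SHARDED_STATE_DICT checkpoints,
    # i.e. subdirectories such as "pytorch_model_fsdp_0/".
    return os.path.isdir(ckpt) and any(
        FSDP_MODEL_NAME in name
        for name in os.listdir(ckpt)
        if os.path.isdir(os.path.join(ckpt, name))
    )


def new_check(ckpt: str) -> bool:
    # Post-fix predicate: additionally accepts the single consolidated
    # "pytorch_model_fsdp.bin" file that FULL_STATE_DICT produces.
    return old_check(ckpt) or (
        os.path.isdir(ckpt)
        and os.path.isfile(os.path.join(ckpt, f"{FSDP_MODEL_NAME}.bin"))
    )


with tempfile.TemporaryDirectory() as ckpt:
    # Emulate a FULL_STATE_DICT checkpoint: one consolidated weights file.
    open(os.path.join(ckpt, f"{FSDP_MODEL_NAME}.bin"), "wb").close()
    print(old_check(ckpt))  # False: Trainer used to fall through to the non-FSDP load path
    print(new_check(ckpt))  # True:  the checkpoint is now routed to the FSDP loader

A sharded checkpoint (e.g. a `pytorch_model_fsdp_0/` subdirectory) satisfies both predicates, which is why the bug only surfaced with `FULL_STATE_DICT`.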
src/transformers/trainer.py
@@ -2033,10 +2033,15 @@ class Trainer:
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
-        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
-            FSDP_MODEL_NAME in folder_name
-            for folder_name in os.listdir(resume_from_checkpoint)
-            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )
 
         if is_fsdp_ckpt and not self.is_fsdp_enabled:
tests/fsdp/test_fsdp.py
@@ -41,6 +41,7 @@ from transformers.utils import is_accelerate_available, is_torch_bf16_available_
 
 if is_torch_available():
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
+    from transformers.trainer import FSDP_MODEL_NAME
 else:
     is_torch_greater_or_equal_than_2_1 = False
 
@@ -211,6 +212,19 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         # resume from ckpt
         checkpoint = os.path.join(output_dir, "checkpoint-115")
         resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
         logs_resume = self.run_cmd_and_get_logs(
             use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
         )
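One design observation: the test reproduces the trainer-side predicate verbatim rather than importing a shared helper. Purely as a sketch (this helper is hypothetical and does not exist in transformers), the duplicated check could be factored out:

import os

# Grounded import: the test diff above imports this constant the same way.
from transformers.trainer import FSDP_MODEL_NAME


def looks_like_fsdp_checkpoint(ckpt_dir: str) -> bool:
    # Hypothetical shared helper; not part of the PR or of transformers.
    if not os.path.isdir(ckpt_dir):
        return False
    # SHARDED_STATE_DICT layout: a subdirectory such as "pytorch_model_fsdp_0/"
    sharded = any(
        FSDP_MODEL_NAME in name
        for name in os.listdir(ckpt_dir)
        if os.path.isdir(os.path.join(ckpt_dir, name))
    )
    # FULL_STATE_DICT layout: a single consolidated "pytorch_model_fsdp.bin" file
    full = os.path.isfile(os.path.join(ckpt_dir, f"{FSDP_MODEL_NAME}.bin"))
    return sharded or full

Both call sites would then reduce to `is_fsdp_ckpt = looks_like_fsdp_checkpoint(checkpoint)`; the PR keeps the inline copies, which also keeps the test independent of trainer internals beyond the `FSDP_MODEL_NAME` constant.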