diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4119e547a..70f4745f9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -22,6 +22,7 @@ import functools import glob import importlib.metadata import inspect +import json import math import os import random @@ -4215,6 +4216,15 @@ class Trainer: output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + # Add sharded checkpoints if we have an index + for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + index_path = os.path.join(checkpoint_folder, index_file) + if os.path.isfile(index_path): + modeling_files.append(index_file) + with open(index_path) as f: + index = json.loads(f.read()) + shard_files = list(set(index["weight_map"].values())) + modeling_files.extend(shard_files) if is_peft_available(): modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: