Push sharded checkpoint to hub when push_to_hub=True in TrainingArguments (#31808)

Save sharded checkpoint in Trainer
Marc Sun, 2024-07-10 15:14:20 +02:00 (committed by GitHub)
commit 8df28bb308
parent da79b18087


src/transformers/trainer.py
@@ -22,6 +22,7 @@ import functools
 import glob
 import importlib.metadata
 import inspect
+import json
 import math
 import os
 import random
@@ -4215,6 +4216,15 @@ class Trainer:
         output_dir = self.args.output_dir
         # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
         modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        # Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
         if is_peft_available():
             modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
         for modeling_file in modeling_files:
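
The shard index that the new loop parses is a plain JSON file saved next to the weights. Below is a minimal standalone sketch of what the loop resolves, assuming a hypothetical checkpoint folder; the index filename "model.safetensors.index.json" is the standard value SAFE_WEIGHTS_INDEX_NAME points to.

import json
import os

checkpoint_folder = "output/checkpoint-500"  # hypothetical path for illustration
index_path = os.path.join(checkpoint_folder, "model.safetensors.index.json")

if os.path.isfile(index_path):
    with open(index_path) as f:
        index = json.load(f)
    # "weight_map" maps each parameter name to the shard file that stores it, e.g.
    # {"model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...}
    shard_files = sorted(set(index["weight_map"].values()))
    print(shard_files)  # every unique shard file that must be copied and pushed

Deduplicating through set() matters here: many parameters map to the same shard, and without it the same file would be queued for copy and upload multiple times.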