Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Push sharded checkpoint to hub when push_to_hub=True in TrainingArguments (#31808)
Save sharded checkpoint in Trainer
parent da79b18087
commit 8df28bb308
1 changed file with 10 additions and 0 deletions
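For reference, the code path changed below runs when the Trainer is configured to push to the Hub. A minimal, illustrative setup (the directory name and save cadence are placeholders, not taken from the commit):

from transformers import TrainingArguments

# Illustrative values; with the default hub_strategy="every_save",
# each saved checkpoint is also uploaded to the Hub.
args = TrainingArguments(
    output_dir="my-finetune",
    push_to_hub=True,
    save_strategy="steps",
    save_steps=500,
)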
@@ -22,6 +22,7 @@ import functools
+import glob
 import importlib.metadata
 import inspect
 import json
 import math
 import os
 import random
@@ -4215,6 +4216,15 @@ class Trainer:
         output_dir = self.args.output_dir
         # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
         modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        # Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
         if is_peft_available():
             modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
         for modeling_file in modeling_files:
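The loop added above works because a sharded checkpoint ships with a small JSON index (WEIGHTS_INDEX_NAME / SAFE_WEIGHTS_INDEX_NAME, e.g. model.safetensors.index.json) whose "weight_map" maps each parameter name to the shard file that stores it. A minimal standalone sketch of the same collection logic, using a made-up two-shard index purely for illustration:

import json

# Hypothetical index payload; real index files are produced by
# save_pretrained() when a model is saved in shards.
example_index = json.loads("""
{
  "metadata": {"total_size": 28000000000},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}
""")

# Many parameters share a shard, so deduplicate before collecting files.
shard_files = sorted(set(example_index["weight_map"].values()))
print(shard_files)
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']

This mirrors what the commit does: the index file itself plus every unique shard it references are copied from the checkpoint folder and pushed, rather than re-serializing the full model.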