From 8df28bb308e5676ad92eebafec2c4f2c3ebe5f31 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Wed, 10 Jul 2024 15:14:20 +0200
Subject: [PATCH] Push sharded checkpoint to hub when `push_to_hub=True` in
 `TrainingArguments` (#31808)

Save sharded checkpoint in Trainer
---
 src/transformers/trainer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4119e547a..70f4745f9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -22,6 +22,7 @@ import functools
 import glob
 import importlib.metadata
 import inspect
+import json
 import math
 import os
 import random
@@ -4215,6 +4216,15 @@ class Trainer:
         output_dir = self.args.output_dir
         # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
         modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        # Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
         if is_peft_available():
             modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
         for modeling_file in modeling_files:
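
For reference, a minimal sketch of the shard lookup the second hunk performs. The index contents below are hypothetical (a real model.safetensors.index.json has one entry per tensor, so far more lines); the point is that "weight_map" maps each tensor name to the shard file storing it, many tensors share a shard, and deduplicating the values with set() yields each shard file exactly once:

    # Hypothetical index contents; tensor names and sizes are illustrative only.
    index = {
        "metadata": {"total_size": 28_000_000_000},
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
            "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
            "model.layers.40.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
            "lm_head.weight": "model-00003-of-00003.safetensors",
        },
    }

    # Same dedup as the patch: set() collapses repeated shard names to one each.
    shard_files = sorted(set(index["weight_map"].values()))
    print(shard_files)
    # ['model-00001-of-00003.safetensors',
    #  'model-00002-of-00003.safetensors',
    #  'model-00003-of-00003.safetensors']

The recovered shard names are appended to modeling_files alongside the index file itself, so the Hub push copies every shard from the checkpoint folder instead of only the single-file WEIGHTS_NAME / SAFE_WEIGHTS_NAME, which do not exist for sharded checkpoints.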