finalize sound renaming logic

This commit is contained in:
Cyril Vallez 2025-02-10 13:00:36 +01:00
parent f11085d6ea
commit bd3b5fa50f
No known key found for this signature in database

View file

@ -4688,6 +4688,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
renamed_loaded_keys = list(
cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
)
# The device_map and weight_map keys must go through the same renaming, so that they match the checkpoint keys once those are loaded and renamed
device_map = cls._fix_state_dict_keys_on_load(device_map)
if sharded_metadata is not None and "weight_map" in sharded_metadata.keys():
sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"])
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
@ -4780,10 +4784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
" offers the weights in this format."
)
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix_to_remove)
param_device_map = expand_device_map(device_map, renamed_loaded_keys, start_prefix_to_remove)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
weight_map = {p: checkpoint_files[0] for p in renamed_loaded_keys}
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
@ -4793,10 +4797,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
disk_offload_index = {
p[len(start_prefix_to_remove) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix_to_remove)
and param_device_map[p[len(start_prefix_to_remove) :]] == "disk"
name[len(start_prefix_to_remove) :]: {"safetensors_file": file, "weight_name": name, "dtype": str_dtype}
for name, file in weight_map.items()
if name.startswith(start_prefix_to_remove)
and param_device_map[name[len(start_prefix_to_remove) :]] == "disk"
}
else:
disk_offload_index = {}
@ -4899,7 +4903,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0:
if model != model_to_load:
if loading_task_model_from_base_state_dict:
# We need to add the prefix of the base model
prefix = cls.base_model_prefix
if not is_offloaded_safetensors: