Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
finalize sound renaming logic
This commit is contained in:
parent f11085d6ea
commit bd3b5fa50f

1 changed file with 11 additions and 7 deletions
@@ -4688,6 +4688,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
+        renamed_loaded_keys = list(
+            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
+        )
+        # The device_map and weight_map have to be changed accordingly to match the keys that will be loaded then renamed
         device_map = cls._fix_state_dict_keys_on_load(device_map)
         if sharded_metadata is not None and "weight_map" in sharded_metadata.keys():
             sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"])
 
         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
         prefix = model.base_model_prefix
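Note on the `{key: "" for key in loaded_state_dict_keys}` trick above: `_fix_state_dict_keys_on_load` takes a state-dict-shaped mapping, so feeding it dummy values is a cheap way to recover the post-rename key list without materializing any tensors. A minimal runnable sketch of the idea; the gamma/beta rule is a stand-in assumption, not necessarily what the real method does:

import re

def fix_state_dict_keys_on_load(state_dict):
    # Stand-in for cls._fix_state_dict_keys_on_load: rename legacy
    # "gamma"/"beta" parameter names to "weight"/"bias" (assumed rule).
    renamed = {}
    for key, value in state_dict.items():
        new_key = re.sub(r"\bgamma\b", "weight", key)
        new_key = re.sub(r"\bbeta\b", "bias", new_key)
        renamed[new_key] = value
    return renamed

loaded_state_dict_keys = ["encoder.layer.0.gamma", "encoder.layer.0.beta"]

# Same trick as in the hunk: dummy values, because only the keys matter here.
renamed_loaded_keys = list(
    fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
)
print(renamed_loaded_keys)  # ['encoder.layer.0.weight', 'encoder.layer.0.bias']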
@@ -4780,10 +4784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 " offers the weights in this format."
             )
         if is_offloaded_safetensors:
-            param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix_to_remove)
+            param_device_map = expand_device_map(device_map, renamed_loaded_keys, start_prefix_to_remove)
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
             if sharded_metadata is None:
-                weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
+                weight_map = {p: checkpoint_files[0] for p in renamed_loaded_keys}
             else:
                 folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
                 weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
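These two pairs are the payoff of the first hunk: the per-parameter device map and the weight map must be keyed by the renamed names, otherwise the `param_device_map[...]` lookups in the next hunk can miss for any checkpoint whose keys get rewritten on load. A rough sketch of what an `expand_device_map`-style helper does; the signature and the longest-prefix rule are assumptions for illustration, not the library's implementation:

def expand_device_map(device_map, param_names, start_prefix_to_remove=""):
    # Assumed behavior: map every parameter to the device of its longest
    # matching module prefix; "" in device_map acts as the catch-all root.
    expanded = {}
    for name in param_names:
        if start_prefix_to_remove and name.startswith(start_prefix_to_remove):
            name = name[len(start_prefix_to_remove):]
        matches = [m for m in device_map if m == "" or name == m or name.startswith(m + ".")]
        if matches:
            expanded[name] = device_map[max(matches, key=len)]
    return expanded

device_map = {"": "cpu", "lm_head": "disk"}
print(expand_device_map(device_map, ["lm_head.weight", "transformer.h.0.attn.c_attn.weight"]))
# {'lm_head.weight': 'disk', 'transformer.h.0.attn.c_attn.weight': 'cpu'}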
@@ -4793,10 +4797,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             )
             disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
             disk_offload_index = {
-                p[len(start_prefix_to_remove) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
-                for p, f in weight_map.items()
-                if p.startswith(start_prefix_to_remove)
-                and param_device_map[p[len(start_prefix_to_remove) :]] == "disk"
+                name[len(start_prefix_to_remove) :]: {"safetensors_file": file, "weight_name": name, "dtype": str_dtype}
+                for name, file in weight_map.items()
+                if name.startswith(start_prefix_to_remove)
+                and param_device_map[name[len(start_prefix_to_remove) :]] == "disk"
             }
         else:
             disk_offload_index = {}
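This hunk only renames the comprehension variables `p`/`f` to `name`/`file`; the behavior is unchanged. For readers, here is the same comprehension run on hypothetical inputs (paths and parameter names are made up) to show the structure of the resulting offload index:

start_prefix_to_remove = "bert."
str_dtype = "float32"
param_device_map = {"encoder.layer.0.attention.self.query.weight": "disk"}
weight_map = {
    "bert.encoder.layer.0.attention.self.query.weight": "/tmp/ckpt/model-00001-of-00002.safetensors",
    "cls.predictions.bias": "/tmp/ckpt/model-00002-of-00002.safetensors",
}

# The comprehension from the hunk, unchanged apart from the surrounding setup.
# The prefix check short-circuits before the param_device_map lookup, so keys
# outside the base model ("cls.predictions.bias") are skipped safely.
disk_offload_index = {
    name[len(start_prefix_to_remove):]: {"safetensors_file": file, "weight_name": name, "dtype": str_dtype}
    for name, file in weight_map.items()
    if name.startswith(start_prefix_to_remove)
    and param_device_map[name[len(start_prefix_to_remove):]] == "disk"
}
print(list(disk_offload_index))  # ['encoder.layer.0.attention.self.query.weight']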
@@ -4899,7 +4903,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
 
         # Adjust offloaded weights name and save if needed
         if disk_offload_index is not None and len(disk_offload_index) > 0:
-            if model != model_to_load:
+            if loading_task_model_from_base_state_dict:
                 # We need to add the prefix of the base model
                 prefix = cls.base_model_prefix
                 if not is_offloaded_safetensors:
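The guard now keys off the explicit `loading_task_model_from_base_state_dict` flag instead of comparing model objects: when a task model is loaded from a base-model checkpoint, the offload-index keys must get the base-model prefix back so they line up with the task model's parameter names. A hedged sketch of that adjustment; the dictionary rewrite below is an assumption about the code following this hunk, which the diff does not show:

# Hypothetical index produced by the comprehension two hunks above:
disk_offload_index = {
    "encoder.layer.0.attention.self.query.weight": {
        "safetensors_file": "/tmp/ckpt/model-00001-of-00002.safetensors",
        "weight_name": "bert.encoder.layer.0.attention.self.query.weight",
        "dtype": "float32",
    }
}
prefix = "bert"  # cls.base_model_prefix for BERT-family models

# Assumed shape of the adjustment: prepend the base-model prefix so the keys
# match the task model's parameter names (e.g. BertForSequenceClassification).
disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
print(list(disk_offload_index))  # ['bert.encoder.layer.0.attention.self.query.weight']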