From bd3b5fa50f28f78c9f7fbbd271ef375666565b6f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 10 Feb 2025 13:00:36 +0100 Subject: [PATCH] finalize sound renaming logic --- src/transformers/modeling_utils.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d9634c93d..5ad7676ac 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4688,6 +4688,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix renamed_loaded_keys = list( cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys() ) + # The device_map and weight_map have to be changed accordingly to match the keys that will be loaded then renamed + device_map = cls._fix_state_dict_keys_on_load(device_map) + if sharded_metadata is not None and "weight_map" in sharded_metadata.keys(): + sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"]) # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture prefix = model.base_model_prefix @@ -4780,10 +4784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix " offers the weights in this format." ) if is_offloaded_safetensors: - param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix_to_remove) + param_device_map = expand_device_map(device_map, renamed_loaded_keys, start_prefix_to_remove) str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" if sharded_metadata is None: - weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys} + weight_map = {p: checkpoint_files[0] for p in renamed_loaded_keys} else: folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} @@ -4793,10 +4797,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] disk_offload_index = { - p[len(start_prefix_to_remove) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} - for p, f in weight_map.items() - if p.startswith(start_prefix_to_remove) - and param_device_map[p[len(start_prefix_to_remove) :]] == "disk" + name[len(start_prefix_to_remove) :]: {"safetensors_file": file, "weight_name": name, "dtype": str_dtype} + for name, file in weight_map.items() + if name.startswith(start_prefix_to_remove) + and param_device_map[name[len(start_prefix_to_remove) :]] == "disk" } else: disk_offload_index = {} @@ -4899,7 +4903,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Adjust offloaded weights name and save if needed if disk_offload_index is not None and len(disk_offload_index) > 0: - if model != model_to_load: + if loading_task_model_from_base_state_dict: # We need to add the prefix of the base model prefix = cls.base_model_prefix if not is_offloaded_safetensors: