diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 30e195e05..61f5b4d2b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1218,7 +1218,11 @@ def get_torch_dtype( weights_only: bool, ) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]: """Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the - infered dtype. + infered dtype. We do the following: + 1. If torch_dtype is not None, we use that dtype + 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + we also may have config.torch_dtype available, but we won't rely on it till v5 """ dtype_orig = None is_sharded = sharded_metadata is not None @@ -1355,7 +1359,6 @@ def find_missing_and_unexpected_keys( remove_prefix_from_model: bool, add_prefix_to_model: bool, hf_quantizer: Optional[HfQuantizer], - device_map: Optional[Dict], ) -> Tuple["PreTrainedModel", List[str], List[str]]: """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters) @@ -1381,17 +1384,7 @@ def find_missing_and_unexpected_keys( unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] model.tie_weights() - if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): - ptrs = collections.defaultdict(list) - for name, tensor in model.state_dict().items(): - id_tensor = id_tensor_storage(tensor) - ptrs[id_tensor].append(name) - - # These are all the pointers of shared tensors. - tied_params = [names for _, names in ptrs.items() if len(names) > 1] - else: - # id function doesn't work for meta tensor so we need this function - tied_params = find_tied_parameters(model) + tied_params = find_tied_parameters(model) for group in tied_params: if remove_prefix_from_model: @@ -4384,11 +4377,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "tensors" ] - # set dtype to instantiate the model under: - # 1. If torch_dtype is not None, we use that dtype - # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first - # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype - # we also may have config.torch_dtype available, but we won't rely on it till v5 + # Find the correct dtype based on current state config, torch_dtype, dtype_orig = get_torch_dtype( cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only ) @@ -4683,39 +4672,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Some useful flags is_quantized = hf_quantizer is not None is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None - is_offloaded_safetensors = False - - if device_map is not None and "disk" in device_map.values(): - is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors") - if disk_offload_folder is None and not is_offloaded_safetensors: - raise ValueError( - "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" - " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" - " offers the weights in this format." - ) - if disk_offload_folder is not None: - os.makedirs(disk_offload_folder, exist_ok=True) - if offload_state_dict is None: - offload_state_dict = True - - is_sharded_offloaded_safetensors = is_offloaded_safetensors and sharded_metadata is not None # tie the model weights before retrieving the state_dict model.tie_weights() # Update model keys (expected keys) and loaded keys based on prefix, quantization, etc... + prefix = model.base_model_prefix model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) - prefix = model.base_model_prefix - if hf_quantizer is not None: expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys) - renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys] + # Check if we need to modify prefix has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False - remove_prefix_from_model = not has_prefix_module and expects_prefix_module add_prefix_to_model = has_prefix_module and not expects_prefix_module @@ -4726,7 +4697,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif add_prefix_to_model: expected_keys = [".".join([prefix, s]) for s in expected_keys] - # Find missing and unexpected keys from the state dict, to later log them + # Find missing and unexpected keys from the state dict model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys( cls, model, @@ -4735,7 +4706,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix remove_prefix_from_model, add_prefix_to_model, hf_quantizer, - device_map, ) # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as @@ -4744,6 +4714,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model = move_missing_keys_back_to_cpu( model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer ) + # In this case we also need to move everything back + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + for key, param in model.state_dict().items(): + if param.device == torch.device("meta"): + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) # correctly initialize the missing keys model = initialize_missing_keys( @@ -4776,28 +4751,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if device_map is not None: device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} - # Find the folder where the checkpoints reside - if checkpoint_files is not None: - folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) - else: - folder = None - + is_offloaded_safetensors = False # This offload index if for params explicitly on the "disk" in the device_map disk_offload_index = None - if is_offloaded_safetensors: - param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix) - str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" - if sharded_metadata is None: - weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys} + # Find checkpoint files containing only weights offloaded to disk if any (allows to be faster) + disk_only_shard_files = [] + if device_map is not None and "disk" in device_map.values(): + is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors") + if disk_offload_folder is None and not is_offloaded_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." + ) + if disk_offload_folder is not None: + os.makedirs(disk_offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + if is_offloaded_safetensors: + param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys} + else: + folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) + weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} + # Find potential checkpoints containing only offloaded weights + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) + disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] + disk_offload_index = { + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + for p, f in weight_map.items() + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" + } else: - weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} - disk_offload_index = { - p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} - for p, f in weight_map.items() - if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" - } - elif device_map is not None and "disk" in device_map.values(): - disk_offload_index = {} + disk_offload_index = {} # This offload index if for params that are supposed to be on the "cpu", either with or without a device_map # It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size, @@ -4808,14 +4798,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix cpu_offload_folder = tempfile.mkdtemp() cpu_offload_index = {} - # Find checkpoint files containing only weights offloaded to disk if any (allows to be faster) - disk_only_shard_files = [] - if is_sharded_offloaded_safetensors: - disk_only_shard_files = get_disk_only_shard_files( - device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix - ) - disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] - map_location = None if ( device_map is not None @@ -4829,24 +4811,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if checkpoint_files is not None and len(checkpoint_files) > 1: checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") - # To be able to iterate, even if we don't use it + # To be able to iterate, even if we don't use it if the state_dict is already provided checkpoint_files = checkpoint_files if state_dict is None else [""] - assign_to_params_buffers = None - error_msgs = [] mismatched_keys = [] # Iterate on all the shards to load the weights for shard_file in checkpoint_files: - # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. + # Skip the load for shards that only contain disk-offloaded weights if shard_file in disk_only_shard_files: continue - # If shard_file == "", we use the existing state_dict + # If shard_file == "", we use the existing state_dict instead of loading it if shard_file != "": state_dict = load_state_dict( shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only ) + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. @@ -4862,13 +4843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) if low_cpu_mem_usage or gguf_file is not None: if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: - for key, param in model_to_load.state_dict().items(): - if param.device == torch.device("meta"): - set_module_tensor_to_device( - model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) - ) + pass else: - fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( model_to_load, fixed_state_dict, @@ -4888,11 +4864,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix error_msgs += new_error_msgs else: # Sharded checkpoint or whole but low_cpu_mem_usage==True - if assign_to_params_buffers is None: - assign_to_params_buffers = check_support_param_buffer_assignment( - model_to_load, state_dict, start_prefix - ) - fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) error_msgs += _load_state_dict_into_model( model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers )