keep improving

This commit is contained in:
Cyril Vallez 2025-02-05 15:14:50 +01:00
parent 574e3f76c9
commit a3401c3e23
No known key found for this signature in database

View file

@ -1218,7 +1218,11 @@ def get_torch_dtype(
weights_only: bool, weights_only: bool,
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]: ) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the """Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
inferred dtype. inferred dtype. We do the following:
1. If torch_dtype is not None, we use that dtype
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
we also may have config.torch_dtype available, but we won't rely on it till v5
""" """
dtype_orig = None dtype_orig = None
is_sharded = sharded_metadata is not None is_sharded = sharded_metadata is not None
@ -1355,7 +1359,6 @@ def find_missing_and_unexpected_keys(
remove_prefix_from_model: bool, remove_prefix_from_model: bool,
add_prefix_to_model: bool, add_prefix_to_model: bool,
hf_quantizer: Optional[HfQuantizer], hf_quantizer: Optional[HfQuantizer],
device_map: Optional[Dict],
) -> Tuple["PreTrainedModel", List[str], List[str]]: ) -> Tuple["PreTrainedModel", List[str], List[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters) (keys found in the loaded state dict keys, but that are NOT part of the model parameters)
@ -1381,17 +1384,7 @@ def find_missing_and_unexpected_keys(
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
model.tie_weights() model.tie_weights()
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): tied_params = find_tied_parameters(model)
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
for group in tied_params: for group in tied_params:
if remove_prefix_from_model: if remove_prefix_from_model:
@ -4384,11 +4377,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"tensors" "tensors"
] ]
# set dtype to instantiate the model under: # Find the correct dtype based on current state
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
config, torch_dtype, dtype_orig = get_torch_dtype( config, torch_dtype, dtype_orig = get_torch_dtype(
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
) )
@ -4683,39 +4672,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Some useful flags # Some useful flags
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
is_offloaded_safetensors = False
if device_map is not None and "disk" in device_map.values():
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
is_sharded_offloaded_safetensors = is_offloaded_safetensors and sharded_metadata is not None
# tie the model weights before retrieving the state_dict # tie the model weights before retrieving the state_dict
model.tie_weights() model.tie_weights()
# Update model keys (expected keys) and loaded keys based on prefix, quantization, etc... # Update model keys (expected keys) and loaded keys based on prefix, quantization, etc...
prefix = model.base_model_prefix
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys()) expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
if hf_quantizer is not None: if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys) expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys] renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
# Check if we need to modify prefix
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
remove_prefix_from_model = not has_prefix_module and expects_prefix_module remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module add_prefix_to_model = has_prefix_module and not expects_prefix_module
@ -4726,7 +4697,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif add_prefix_to_model: elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys] expected_keys = [".".join([prefix, s]) for s in expected_keys]
# Find missing and unexpected keys from the state dict, to later log them # Find missing and unexpected keys from the state dict
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys( model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
cls, cls,
model, model,
@ -4735,7 +4706,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model, remove_prefix_from_model,
add_prefix_to_model, add_prefix_to_model,
hf_quantizer, hf_quantizer,
device_map,
) )
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
@ -4744,6 +4714,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = move_missing_keys_back_to_cpu( model = move_missing_keys_back_to_cpu(
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
) )
# In this case we also need to move everything back
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# correctly initialize the missing keys # correctly initialize the missing keys
model = initialize_missing_keys( model = initialize_missing_keys(
@ -4776,28 +4751,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if device_map is not None: if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
# Find the folder where the checkpoints reside is_offloaded_safetensors = False
if checkpoint_files is not None:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
else:
folder = None
# This offload index is for params explicitly on the "disk" in the device_map # This offload index is for params explicitly on the "disk" in the device_map
disk_offload_index = None disk_offload_index = None
if is_offloaded_safetensors: # Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix) disk_only_shard_files = []
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" if device_map is not None and "disk" in device_map.values():
if sharded_metadata is None: is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys} if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
disk_offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
else: else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} disk_offload_index = {}
disk_offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
elif device_map is not None and "disk" in device_map.values():
disk_offload_index = {}
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size, # It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
@ -4808,14 +4798,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cpu_offload_folder = tempfile.mkdtemp() cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {} cpu_offload_index = {}
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
disk_only_shard_files = []
if is_sharded_offloaded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
map_location = None map_location = None
if ( if (
device_map is not None device_map is not None
@ -4829,24 +4811,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if checkpoint_files is not None and len(checkpoint_files) > 1: if checkpoint_files is not None and len(checkpoint_files) > 1:
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
# To be able to iterate, even if we don't use it # To be able to iterate, even if we don't use it if the state_dict is already provided
checkpoint_files = checkpoint_files if state_dict is None else [""] checkpoint_files = checkpoint_files if state_dict is None else [""]
assign_to_params_buffers = None
error_msgs = [] error_msgs = []
mismatched_keys = [] mismatched_keys = []
# Iterate on all the shards to load the weights # Iterate on all the shards to load the weights
for shard_file in checkpoint_files: for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. # Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files: if shard_file in disk_only_shard_files:
continue continue
# If shard_file == "", we use the existing state_dict # If shard_file == "", we use the existing state_dict instead of loading it
if shard_file != "": if shard_file != "":
state_dict = load_state_dict( state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
) )
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model. # matching the weights in the model.
@ -4862,13 +4843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
if low_cpu_mem_usage or gguf_file is not None: if low_cpu_mem_usage or gguf_file is not None:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items(): pass
if param.device == torch.device("meta"):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else: else:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load, model_to_load,
fixed_state_dict, fixed_state_dict,
@ -4888,11 +4864,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs += new_error_msgs error_msgs += new_error_msgs
else: else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True # Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None: assign_to_params_buffers = check_support_param_buffer_assignment(
assign_to_params_buffers = check_support_param_buffer_assignment( model_to_load, state_dict, start_prefix
model_to_load, state_dict, start_prefix )
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs += _load_state_dict_into_model( error_msgs += _load_state_dict_into_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
) )