keep improving

This commit is contained in:
Cyril Vallez 2025-02-05 15:14:50 +01:00
parent 574e3f76c9
commit a3401c3e23
No known key found for this signature in database

View file

@ -1218,7 +1218,11 @@ def get_torch_dtype(
weights_only: bool,
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
infered dtype.
inferred dtype. We do the following:
1. If torch_dtype is not None, we use that dtype
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
we also may have config.torch_dtype available, but we won't rely on it till v5
"""
dtype_orig = None
is_sharded = sharded_metadata is not None
@ -1355,7 +1359,6 @@ def find_missing_and_unexpected_keys(
remove_prefix_from_model: bool,
add_prefix_to_model: bool,
hf_quantizer: Optional[HfQuantizer],
device_map: Optional[Dict],
) -> Tuple["PreTrainedModel", List[str], List[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
@ -1381,17 +1384,7 @@ def find_missing_and_unexpected_keys(
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
model.tie_weights()
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
tied_params = find_tied_parameters(model)
for group in tied_params:
if remove_prefix_from_model:
@ -4384,11 +4377,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"tensors"
]
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
# Find the correct dtype based on current state
config, torch_dtype, dtype_orig = get_torch_dtype(
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
)
@ -4683,39 +4672,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Some useful flags
is_quantized = hf_quantizer is not None
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
is_offloaded_safetensors = False
if device_map is not None and "disk" in device_map.values():
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
is_sharded_offloaded_safetensors = is_offloaded_safetensors and sharded_metadata is not None
# tie the model weights before retrieving the state_dict
model.tie_weights()
# Update model keys (expected keys) and loaded keys based on prefix, quantization, etc...
prefix = model.base_model_prefix
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
# Check if we need to modify prefix
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
@ -4726,7 +4697,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
# Find missing and unexpected keys from the state dict, to later log them
# Find missing and unexpected keys from the state dict
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
cls,
model,
@ -4735,7 +4706,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
add_prefix_to_model,
hf_quantizer,
device_map,
)
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
@ -4744,6 +4714,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = move_missing_keys_back_to_cpu(
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
)
# In this case we also need to move everything back
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# correctly initialize the missing keys
model = initialize_missing_keys(
@ -4776,28 +4751,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
# Find the folder where the checkpoints reside
if checkpoint_files is not None:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
else:
folder = None
is_offloaded_safetensors = False
# This offload index is for params explicitly on the "disk" in the device_map
disk_offload_index = None
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
disk_only_shard_files = []
if device_map is not None and "disk" in device_map.values():
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
disk_offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
disk_offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
elif device_map is not None and "disk" in device_map.values():
disk_offload_index = {}
disk_offload_index = {}
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
@ -4808,14 +4798,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {}
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
disk_only_shard_files = []
if is_sharded_offloaded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
map_location = None
if (
device_map is not None
@ -4829,24 +4811,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if checkpoint_files is not None and len(checkpoint_files) > 1:
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
# To be able to iterate, even if we don't use it
# To be able to iterate, even if we don't use it if the state_dict is already provided
checkpoint_files = checkpoint_files if state_dict is None else [""]
assign_to_params_buffers = None
error_msgs = []
mismatched_keys = []
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
continue
# If shard_file == "", we use the existing state_dict
# If shard_file == "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
@ -4862,13 +4843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if low_cpu_mem_usage or gguf_file is not None:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
pass
else:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
fixed_state_dict,
@ -4888,11 +4864,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs += new_error_msgs
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs += _load_state_dict_into_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)