Update modeling_utils.py

This commit is contained in:
Cyril Vallez 2025-02-05 15:27:08 +01:00
parent a3401c3e23
commit 1bdb7bba52
No known key found for this signature in database

View file

@@ -899,7 +899,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name
def get_checkpoint_files(
def _get_checkpoint_files(
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
subfolder: str,
variant: Optional[str],
@@ -916,6 +916,7 @@ def get_checkpoint_files(
revision: str,
commit_hash: Optional[str],
) -> Tuple[Optional[List[str]], Optional[Dict]]:
"""Get the full checkpoint filenames where the weights reside, and optional metadata if the checkpoints are sharded."""
is_sharded = False
if pretrained_model_name_or_path is not None and gguf_file is None:
@@ -1208,7 +1209,7 @@ def get_checkpoint_files(
return checkpoint_files, sharded_metadata
def get_torch_dtype(
def _get_torch_dtype(
cls,
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
checkpoint_files: Optional[List[str]],
@@ -1284,14 +1285,14 @@ def get_torch_dtype(
return config, torch_dtype, dtype_orig
def get_device_map(
def _get_device_map(
model: "PreTrainedModel",
device_map: Optional[Union[str, Dict]],
max_memory: Optional[Dict],
hf_quantizer: Optional[HfQuantizer],
torch_dtype: Optional[torch.dtype],
keep_in_fp32_modules: Optional[List[str]],
) -> Tuple["PreTrainedModel", Dict]:
) -> Dict:
"""Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies."""
if isinstance(device_map, str):
special_dtypes = {}
@@ -1348,10 +1349,10 @@ def get_device_map(
# check if we don't have tied param in different devices
check_tied_parameters_on_same_device(tied_params, device_map)
return model, device_map
return device_map
def find_missing_and_unexpected_keys(
def _find_missing_and_unexpected_keys(
cls,
model: "PreTrainedModel",
expected_keys: List[str],
@@ -1359,7 +1360,7 @@ def find_missing_and_unexpected_keys(
remove_prefix_from_model: bool,
add_prefix_to_model: bool,
hf_quantizer: Optional[HfQuantizer],
) -> Tuple["PreTrainedModel", List[str], List[str]]:
) -> Tuple[List[str], List[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
"""
@@ -1407,10 +1408,10 @@ def find_missing_and_unexpected_keys(
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
return model, missing_keys, unexpected_keys
return missing_keys, unexpected_keys
def move_missing_keys_back_to_cpu(
def _move_missing_keys_back_to_cpu(
model: "PreTrainedModel",
missing_keys: List[str],
unexpected_keys: List[str],
@@ -1457,7 +1458,7 @@ def move_missing_keys_back_to_cpu(
return model
def initialize_missing_keys(
def _initialize_missing_keys(
model: "PreTrainedModel",
loaded_keys: List[str],
ignore_mismatched_sizes: bool,
@@ -1514,7 +1515,7 @@ def initialize_missing_keys(
return model
def find_mismatched_keys(
def _find_mismatched_keys(
state_dict: Dict,
model_state_dict: Dict,
renamed_loaded_keys: List[str],
@@ -4312,7 +4313,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
)
checkpoint_files, sharded_metadata = get_checkpoint_files(
checkpoint_files, sharded_metadata = _get_checkpoint_files(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
variant=variant,
@@ -4378,7 +4379,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
]
# Find the correct dtype based on current state
config, torch_dtype, dtype_orig = get_torch_dtype(
config, torch_dtype, dtype_orig = _get_torch_dtype(
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
)
@@ -4440,7 +4441,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Prepare the full device map
if device_map is not None:
model, device_map = get_device_map(
device_map = _get_device_map(
model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules
)
@@ -4698,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = [".".join([prefix, s]) for s in expected_keys]
# Find missing and unexpected keys from the state dict
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
cls,
model,
expected_keys,
@@ -4711,7 +4712,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
# they are not in the loaded state dict)
if low_cpu_mem_usage:
model = move_missing_keys_back_to_cpu(
model = _move_missing_keys_back_to_cpu(
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
)
# In this case we also need to move everything back
@@ -4721,7 +4722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# correctly initialize the missing keys
model = initialize_missing_keys(
model = _initialize_missing_keys(
model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
)
@@ -4831,7 +4832,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += find_mismatched_keys(
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
renamed_loaded_keys,