mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update modeling_utils.py
This commit is contained in:
parent
a3401c3e23
commit
1bdb7bba52
1 changed files with 19 additions and 18 deletions
|
|
@ -899,7 +899,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||
return weights_name
|
||||
|
||||
|
||||
def get_checkpoint_files(
|
||||
def _get_checkpoint_files(
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
subfolder: str,
|
||||
variant: Optional[str],
|
||||
|
|
@ -916,6 +916,7 @@ def get_checkpoint_files(
|
|||
revision: str,
|
||||
commit_hash: Optional[str],
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
"""Get the full checkpoint filenames where the weights reside, and optional metadata if the checkpoints are sharded."""
|
||||
is_sharded = False
|
||||
|
||||
if pretrained_model_name_or_path is not None and gguf_file is None:
|
||||
|
|
@ -1208,7 +1209,7 @@ def get_checkpoint_files(
|
|||
return checkpoint_files, sharded_metadata
|
||||
|
||||
|
||||
def get_torch_dtype(
|
||||
def _get_torch_dtype(
|
||||
cls,
|
||||
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
|
|
@ -1284,14 +1285,14 @@ def get_torch_dtype(
|
|||
return config, torch_dtype, dtype_orig
|
||||
|
||||
|
||||
def get_device_map(
|
||||
def _get_device_map(
|
||||
model: "PreTrainedModel",
|
||||
device_map: Optional[Union[str, Dict]],
|
||||
max_memory: Optional[Dict],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
keep_in_fp32_modules: Optional[List[str]],
|
||||
) -> Tuple["PreTrainedModel", Dict]:
|
||||
) -> Dict:
|
||||
"""Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies."""
|
||||
if isinstance(device_map, str):
|
||||
special_dtypes = {}
|
||||
|
|
@ -1348,10 +1349,10 @@ def get_device_map(
|
|||
# check if we don't have tied param in different devices
|
||||
check_tied_parameters_on_same_device(tied_params, device_map)
|
||||
|
||||
return model, device_map
|
||||
return device_map
|
||||
|
||||
|
||||
def find_missing_and_unexpected_keys(
|
||||
def _find_missing_and_unexpected_keys(
|
||||
cls,
|
||||
model: "PreTrainedModel",
|
||||
expected_keys: List[str],
|
||||
|
|
@ -1359,7 +1360,7 @@ def find_missing_and_unexpected_keys(
|
|||
remove_prefix_from_model: bool,
|
||||
add_prefix_to_model: bool,
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
) -> Tuple["PreTrainedModel", List[str], List[str]]:
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||
"""
|
||||
|
|
@ -1407,10 +1408,10 @@ def find_missing_and_unexpected_keys(
|
|||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
|
||||
|
||||
return model, missing_keys, unexpected_keys
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
|
||||
def move_missing_keys_back_to_cpu(
|
||||
def _move_missing_keys_back_to_cpu(
|
||||
model: "PreTrainedModel",
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
|
|
@ -1457,7 +1458,7 @@ def move_missing_keys_back_to_cpu(
|
|||
return model
|
||||
|
||||
|
||||
def initialize_missing_keys(
|
||||
def _initialize_missing_keys(
|
||||
model: "PreTrainedModel",
|
||||
loaded_keys: List[str],
|
||||
ignore_mismatched_sizes: bool,
|
||||
|
|
@ -1514,7 +1515,7 @@ def initialize_missing_keys(
|
|||
return model
|
||||
|
||||
|
||||
def find_mismatched_keys(
|
||||
def _find_mismatched_keys(
|
||||
state_dict: Dict,
|
||||
model_state_dict: Dict,
|
||||
renamed_loaded_keys: List[str],
|
||||
|
|
@ -4312,7 +4313,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
|
||||
)
|
||||
|
||||
checkpoint_files, sharded_metadata = get_checkpoint_files(
|
||||
checkpoint_files, sharded_metadata = _get_checkpoint_files(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
variant=variant,
|
||||
|
|
@ -4378,7 +4379,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
]
|
||||
|
||||
# Find the correct dtype based on current state
|
||||
config, torch_dtype, dtype_orig = get_torch_dtype(
|
||||
config, torch_dtype, dtype_orig = _get_torch_dtype(
|
||||
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
|
||||
)
|
||||
|
||||
|
|
@ -4440,7 +4441,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
# Prepare the full device map
|
||||
if device_map is not None:
|
||||
model, device_map = get_device_map(
|
||||
device_map = _get_device_map(
|
||||
model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules
|
||||
)
|
||||
|
||||
|
|
@ -4698,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
||||
|
||||
# Find missing and unexpected keys from the state dict
|
||||
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
|
||||
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
|
||||
cls,
|
||||
model,
|
||||
expected_keys,
|
||||
|
|
@ -4711,7 +4712,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
|
||||
# they are not in the loaded state dict)
|
||||
if low_cpu_mem_usage:
|
||||
model = move_missing_keys_back_to_cpu(
|
||||
model = _move_missing_keys_back_to_cpu(
|
||||
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
|
||||
)
|
||||
# In this case we also need to move everything back
|
||||
|
|
@ -4721,7 +4722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
|
||||
|
||||
# correctly initialize the missing keys
|
||||
model = initialize_missing_keys(
|
||||
model = _initialize_missing_keys(
|
||||
model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
||||
)
|
||||
|
||||
|
|
@ -4831,7 +4832,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
# matching the weights in the model.
|
||||
mismatched_keys += find_mismatched_keys(
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
renamed_loaded_keys,
|
||||
|
|
|
|||
Loading…
Reference in a new issue