mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add type hints/docstring
This commit is contained in:
parent
bbab9b26e0
commit
11e378024d
1 changed files with 144 additions and 104 deletions
|
|
@ -752,22 +752,21 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict_into_meta_model(
|
def _load_state_dict_into_meta_model(
|
||||||
model,
|
model: "PreTrainedModel",
|
||||||
state_dict,
|
state_dict: Dict,
|
||||||
start_prefix,
|
start_prefix: str,
|
||||||
expected_keys,
|
expected_keys: Dict,
|
||||||
device_map=None,
|
device_map: Optional[Dict] = None,
|
||||||
offload_folder=None,
|
disk_offload_folder: Optional[str] = None,
|
||||||
offload_index=None,
|
disk_offload_index: Optional[Dict] = None,
|
||||||
state_dict_folder=None,
|
cpu_offload_folder: Optional[str] = None,
|
||||||
state_dict_index=None,
|
cpu_offload_index: Optional[Dict] = None,
|
||||||
dtype=None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
hf_quantizer=None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
is_safetensors=False,
|
is_safetensors: bool = False,
|
||||||
keep_in_fp32_modules=None,
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
unexpected_keys: Optional[Dict] = None, # passing `unexpected` for cleanup from quantization items
|
||||||
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
|
) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
||||||
|
|
@ -857,9 +856,9 @@ def _load_state_dict_into_meta_model(
|
||||||
|
|
||||||
if param_device == "disk":
|
if param_device == "disk":
|
||||||
if not is_safetensors:
|
if not is_safetensors:
|
||||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
|
||||||
elif param_device == "cpu" and state_dict_index is not None:
|
elif param_device == "cpu" and cpu_offload_index is not None:
|
||||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
|
||||||
elif (
|
elif (
|
||||||
not is_quantized
|
not is_quantized
|
||||||
or (not hf_quantizer.requires_parameters_quantization)
|
or (not hf_quantizer.requires_parameters_quantization)
|
||||||
|
|
@ -892,7 +891,7 @@ def _load_state_dict_into_meta_model(
|
||||||
setattr(module, tensor_name, value)
|
setattr(module, tensor_name, value)
|
||||||
# TODO: consider removing used param_parts from state_dict before return
|
# TODO: consider removing used param_parts from state_dict before return
|
||||||
|
|
||||||
return error_msgs, offload_index, state_dict_index
|
return error_msgs, disk_offload_index, cpu_offload_index
|
||||||
|
|
||||||
|
|
||||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||||
|
|
@ -905,7 +904,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_files(
|
def get_checkpoint_files(
|
||||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||||
subfolder: str,
|
subfolder: str,
|
||||||
variant: Optional[str],
|
variant: Optional[str],
|
||||||
gguf_file: Optional[str],
|
gguf_file: Optional[str],
|
||||||
|
|
@ -920,7 +919,7 @@ def get_checkpoint_files(
|
||||||
user_agent: dict,
|
user_agent: dict,
|
||||||
revision: str,
|
revision: str,
|
||||||
commit_hash: Optional[str],
|
commit_hash: Optional[str],
|
||||||
) -> Tuple[List[str], Dict]:
|
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||||
is_sharded = False
|
is_sharded = False
|
||||||
|
|
||||||
if pretrained_model_name_or_path is not None and gguf_file is None:
|
if pretrained_model_name_or_path is not None and gguf_file is None:
|
||||||
|
|
@ -1216,12 +1215,15 @@ def get_checkpoint_files(
|
||||||
def get_torch_dtype(
|
def get_torch_dtype(
|
||||||
cls,
|
cls,
|
||||||
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
|
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
|
||||||
checkpoint_files: List[str],
|
checkpoint_files: Optional[List[str]],
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
sharded_metadata: Optional[Dict],
|
sharded_metadata: Optional[Dict],
|
||||||
state_dict: Optional[Dict],
|
state_dict: Optional[Dict],
|
||||||
weights_only: bool,
|
weights_only: bool,
|
||||||
):
|
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||||
|
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||||
|
infered dtype.
|
||||||
|
"""
|
||||||
dtype_orig = None
|
dtype_orig = None
|
||||||
is_sharded = sharded_metadata is not None
|
is_sharded = sharded_metadata is not None
|
||||||
|
|
||||||
|
|
@ -1282,7 +1284,15 @@ def get_torch_dtype(
|
||||||
return config, torch_dtype, dtype_orig
|
return config, torch_dtype, dtype_orig
|
||||||
|
|
||||||
|
|
||||||
def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules):
|
def get_device_map(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
device_map: Optional[Union[str, Dict]],
|
||||||
|
max_memory: Optional[Dict],
|
||||||
|
hf_quantizer: Optional[HfQuantizer],
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
keep_in_fp32_modules: Optional[List[str]],
|
||||||
|
) -> Tuple["PreTrainedModel", Dict]:
|
||||||
|
"""Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies."""
|
||||||
if isinstance(device_map, str):
|
if isinstance(device_map, str):
|
||||||
special_dtypes = {}
|
special_dtypes = {}
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
|
|
@ -1342,8 +1352,18 @@ def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, kee
|
||||||
|
|
||||||
|
|
||||||
def find_missing_and_unexpected_keys(
|
def find_missing_and_unexpected_keys(
|
||||||
cls, model, expected_keys, loaded_keys, remove_prefix_from_model, add_prefix_to_model, hf_quantizer, device_map
|
cls,
|
||||||
):
|
model: "PreTrainedModel",
|
||||||
|
expected_keys: List[str],
|
||||||
|
loaded_keys: List[str],
|
||||||
|
remove_prefix_from_model: bool,
|
||||||
|
add_prefix_to_model: bool,
|
||||||
|
hf_quantizer: Optional[HfQuantizer],
|
||||||
|
device_map: Optional[Dict],
|
||||||
|
) -> Tuple["PreTrainedModel", List[str], List[str]]:
|
||||||
|
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||||
|
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||||
|
"""
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
_prefix = f"{prefix}."
|
_prefix = f"{prefix}."
|
||||||
|
|
||||||
|
|
@ -1362,7 +1382,7 @@ def find_missing_and_unexpected_keys(
|
||||||
# Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
|
# Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
|
||||||
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
|
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
|
||||||
if has_inv_freq_buffers:
|
if has_inv_freq_buffers:
|
||||||
unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
|
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
||||||
|
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||||
|
|
@ -1402,14 +1422,17 @@ def find_missing_and_unexpected_keys(
|
||||||
|
|
||||||
|
|
||||||
def move_missing_keys_back_to_cpu(
|
def move_missing_keys_back_to_cpu(
|
||||||
model,
|
model: "PreTrainedModel",
|
||||||
missing_keys,
|
missing_keys: List[str],
|
||||||
unexpected_keys,
|
unexpected_keys: List[str],
|
||||||
dtype,
|
dtype: Optional[torch.dtype],
|
||||||
keep_in_fp32_modules,
|
keep_in_fp32_modules: Optional[List[str]],
|
||||||
is_quantized,
|
hf_quantizer: Optional[HfQuantizer],
|
||||||
hf_quantizer,
|
) -> "PreTrainedModel":
|
||||||
):
|
"""Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
|
||||||
|
from meta device to cpu.
|
||||||
|
"""
|
||||||
|
is_quantized = hf_quantizer is not None
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
|
|
@ -1446,8 +1469,18 @@ def move_missing_keys_back_to_cpu(
|
||||||
|
|
||||||
|
|
||||||
def initialize_missing_keys(
|
def initialize_missing_keys(
|
||||||
model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
model: "PreTrainedModel",
|
||||||
):
|
loaded_keys: List[str],
|
||||||
|
ignore_mismatched_sizes: bool,
|
||||||
|
has_prefix_module: bool,
|
||||||
|
expects_prefix_module: bool,
|
||||||
|
is_quantized: bool,
|
||||||
|
) -> "PreTrainedModel":
|
||||||
|
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
|
||||||
|
`_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to
|
||||||
|
be initialized correctly (i.e. weight initialization distribution).
|
||||||
|
Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
|
||||||
|
"""
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
||||||
|
|
@ -1493,18 +1526,20 @@ def initialize_missing_keys(
|
||||||
|
|
||||||
|
|
||||||
def find_mismatched_keys(
|
def find_mismatched_keys(
|
||||||
state_dict,
|
state_dict: Dict,
|
||||||
model_state_dict,
|
model_state_dict: Dict,
|
||||||
loaded_keys,
|
renamed_loaded_keys: List[str],
|
||||||
original_loaded_keys,
|
original_loaded_keys: List[str],
|
||||||
add_prefix_to_model,
|
add_prefix_to_model: bool,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model: bool,
|
||||||
ignore_mismatched_sizes,
|
ignore_mismatched_sizes: bool,
|
||||||
prefix,
|
prefix: str,
|
||||||
):
|
) -> List:
|
||||||
|
"""Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
|
||||||
|
problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
if ignore_mismatched_sizes:
|
if ignore_mismatched_sizes:
|
||||||
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
|
||||||
# If the checkpoint is sharded, we may not have the key here.
|
# If the checkpoint is sharded, we may not have the key here.
|
||||||
if checkpoint_key not in state_dict:
|
if checkpoint_key not in state_dict:
|
||||||
continue
|
continue
|
||||||
|
|
@ -4459,10 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
sharded_metadata=sharded_metadata,
|
sharded_metadata=sharded_metadata,
|
||||||
_fast_init=_fast_init,
|
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
offload_folder=offload_folder,
|
disk_offload_folder=offload_folder,
|
||||||
offload_state_dict=offload_state_dict,
|
offload_state_dict=offload_state_dict,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
|
|
@ -4637,22 +4671,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_pretrained_model(
|
def _load_pretrained_model(
|
||||||
cls,
|
cls,
|
||||||
model,
|
model: "PreTrainedModel",
|
||||||
state_dict,
|
state_dict: Optional[Dict],
|
||||||
checkpoint_files,
|
checkpoint_files: Optional[List[str]],
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path: Optional[str],
|
||||||
ignore_mismatched_sizes=False,
|
ignore_mismatched_sizes: bool = False,
|
||||||
sharded_metadata=None,
|
sharded_metadata: Optional[Dict] = None,
|
||||||
_fast_init=True,
|
low_cpu_mem_usage: bool = False,
|
||||||
low_cpu_mem_usage=False,
|
device_map: Optional[Dict] = None,
|
||||||
device_map=None,
|
disk_offload_folder: Optional[str] = None,
|
||||||
offload_folder=None,
|
offload_state_dict: Optional[bool] = None,
|
||||||
offload_state_dict=None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
dtype=None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
hf_quantizer=None,
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
keep_in_fp32_modules=None,
|
gguf_file: Optional[str] = None,
|
||||||
gguf_file=None,
|
weights_only: bool = True,
|
||||||
weights_only=True,
|
|
||||||
):
|
):
|
||||||
# Get all the keys of the state dicts that we have to initialize the model
|
# Get all the keys of the state dicts that we have to initialize the model
|
||||||
if sharded_metadata is not None:
|
if sharded_metadata is not None:
|
||||||
|
|
@ -4667,16 +4700,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
|
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
|
||||||
is_offloaded_safetensors = False
|
is_offloaded_safetensors = False
|
||||||
|
|
||||||
if is_from_file and device_map is not None and "disk" in device_map.values():
|
if device_map is not None and "disk" in device_map.values():
|
||||||
is_offloaded_safetensors = checkpoint_files[0].endswith(".safetensors")
|
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
|
||||||
if offload_folder is None and not is_offloaded_safetensors:
|
if disk_offload_folder is None and not is_offloaded_safetensors:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
||||||
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
||||||
" offers the weights in this format."
|
" offers the weights in this format."
|
||||||
)
|
)
|
||||||
if offload_folder is not None:
|
if disk_offload_folder is not None:
|
||||||
os.makedirs(offload_folder, exist_ok=True)
|
os.makedirs(disk_offload_folder, exist_ok=True)
|
||||||
if offload_state_dict is None:
|
if offload_state_dict is None:
|
||||||
offload_state_dict = True
|
offload_state_dict = True
|
||||||
|
|
||||||
|
|
@ -4693,9 +4726,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
|
||||||
|
|
||||||
loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
||||||
|
|
||||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) if len(prefix) > 0 else False
|
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
|
||||||
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
|
||||||
|
|
||||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
||||||
|
|
@ -4713,7 +4746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
cls,
|
cls,
|
||||||
model,
|
model,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
loaded_keys,
|
renamed_loaded_keys,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
hf_quantizer,
|
hf_quantizer,
|
||||||
|
|
@ -4728,10 +4761,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
)
|
)
|
||||||
|
|
||||||
# correctly initialize the missing keys
|
# correctly initialize the missing keys
|
||||||
if _fast_init:
|
model = initialize_missing_keys(
|
||||||
model = initialize_missing_keys(
|
model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
||||||
model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Set some modules to fp32 if needed
|
# Set some modules to fp32 if needed
|
||||||
if keep_in_fp32_modules is not None:
|
if keep_in_fp32_modules is not None:
|
||||||
|
|
@ -4748,7 +4780,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||||
model_to_load = getattr(model, cls.base_model_prefix)
|
model_to_load = getattr(model, cls.base_model_prefix)
|
||||||
base_model_expected_keys = list(model_to_load.state_dict().keys())
|
base_model_expected_keys = list(model_to_load.state_dict().keys())
|
||||||
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
|
if any(
|
||||||
|
key in expected_keys_not_prefixed and key not in base_model_expected_keys
|
||||||
|
for key in renamed_loaded_keys
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
||||||
"properly saved?"
|
"properly saved?"
|
||||||
|
|
@ -4762,8 +4797,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
else:
|
else:
|
||||||
folder = None
|
folder = None
|
||||||
|
|
||||||
# In case we need to offload to disk, compute the index
|
# This offload index if for params explicitly on the "disk" in the device_map
|
||||||
offload_index = None
|
disk_offload_index = None
|
||||||
if is_offloaded_safetensors:
|
if is_offloaded_safetensors:
|
||||||
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
|
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
|
||||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||||
|
|
@ -4771,21 +4806,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
|
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
|
||||||
else:
|
else:
|
||||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||||
offload_index = {
|
disk_offload_index = {
|
||||||
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||||
for p, f in weight_map.items()
|
for p, f in weight_map.items()
|
||||||
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||||
}
|
}
|
||||||
elif device_map is not None and "disk" in device_map.values():
|
elif device_map is not None and "disk" in device_map.values():
|
||||||
offload_index = {}
|
disk_offload_index = {}
|
||||||
|
|
||||||
state_dict_folder = None
|
# This offload index if for params that are supposed to be on the "cpu", either with or without a device_map
|
||||||
state_dict_index = None
|
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
|
||||||
|
# i.e. 1x to load it, and 1x to copy it to model
|
||||||
|
cpu_offload_folder = None
|
||||||
|
cpu_offload_index = None
|
||||||
if offload_state_dict:
|
if offload_state_dict:
|
||||||
state_dict_folder = tempfile.mkdtemp()
|
cpu_offload_folder = tempfile.mkdtemp()
|
||||||
state_dict_index = {}
|
cpu_offload_index = {}
|
||||||
|
|
||||||
# Find checkpoint files containing only weights offloaded to disk if any
|
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
|
||||||
disk_only_shard_files = []
|
disk_only_shard_files = []
|
||||||
if is_sharded_offloaded_safetensors:
|
if is_sharded_offloaded_safetensors:
|
||||||
disk_only_shard_files = get_disk_only_shard_files(
|
disk_only_shard_files = get_disk_only_shard_files(
|
||||||
|
|
@ -4830,7 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
mismatched_keys += find_mismatched_keys(
|
mismatched_keys += find_mismatched_keys(
|
||||||
state_dict,
|
state_dict,
|
||||||
model_state_dict,
|
model_state_dict,
|
||||||
loaded_keys,
|
renamed_loaded_keys,
|
||||||
loaded_state_dict_keys,
|
loaded_state_dict_keys,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
|
|
@ -4846,16 +4884,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
fixed_state_dict,
|
fixed_state_dict,
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
offload_folder=offload_folder,
|
disk_offload_folder=disk_offload_folder,
|
||||||
offload_index=offload_index,
|
disk_offload_index=disk_offload_index,
|
||||||
state_dict_folder=state_dict_folder,
|
cpu_offload_folder=cpu_offload_folder,
|
||||||
state_dict_index=state_dict_index,
|
cpu_offload_index=cpu_offload_index,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
is_safetensors=is_offloaded_safetensors,
|
is_safetensors=is_offloaded_safetensors,
|
||||||
|
|
@ -4878,25 +4916,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
del state_dict
|
del state_dict
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if offload_index is not None and len(offload_index) > 0:
|
# Adjust offloaded weights name and save if needed
|
||||||
|
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||||
if model != model_to_load:
|
if model != model_to_load:
|
||||||
# We need to add the prefix of the base model
|
# We need to add the prefix of the base model
|
||||||
prefix = cls.base_model_prefix
|
prefix = cls.base_model_prefix
|
||||||
if not is_offloaded_safetensors:
|
if not is_offloaded_safetensors:
|
||||||
for weight_name in offload_index:
|
for weight_name in disk_offload_index:
|
||||||
shutil.move(
|
shutil.move(
|
||||||
os.path.join(offload_folder, f"{weight_name}.dat"),
|
os.path.join(disk_offload_folder, f"{weight_name}.dat"),
|
||||||
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
|
os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"),
|
||||||
)
|
)
|
||||||
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
|
disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
|
||||||
if not is_offloaded_safetensors:
|
if not is_offloaded_safetensors:
|
||||||
save_offload_index(offload_index, offload_folder)
|
save_offload_index(disk_offload_index, disk_offload_folder)
|
||||||
offload_index = None
|
disk_offload_index = None
|
||||||
|
|
||||||
|
# 1-by-1 param loading for the cpu params
|
||||||
if offload_state_dict:
|
if offload_state_dict:
|
||||||
# Load back temporarily offloaded state dict
|
# Load back temporarily offloaded state dict
|
||||||
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
|
load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
|
||||||
shutil.rmtree(state_dict_folder)
|
shutil.rmtree(cpu_offload_folder)
|
||||||
|
|
||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
error_msg = "\n\t".join(error_msgs)
|
error_msg = "\n\t".join(error_msgs)
|
||||||
|
|
@ -4947,7 +4987,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
" to use it for predictions and inference."
|
" to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
|
|
||||||
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
|
return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_from_tf(cls, model, config, checkpoint_files):
|
def _load_from_tf(cls, model, config, checkpoint_files):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue