mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add type hints/docstring
This commit is contained in:
parent
bbab9b26e0
commit
11e378024d
1 changed files with 144 additions and 104 deletions
|
|
@ -752,22 +752,21 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
|||
|
||||
|
||||
def _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=None,
|
||||
offload_folder=None,
|
||||
offload_index=None,
|
||||
state_dict_folder=None,
|
||||
state_dict_index=None,
|
||||
dtype=None,
|
||||
hf_quantizer=None,
|
||||
is_safetensors=False,
|
||||
keep_in_fp32_modules=None,
|
||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
|
||||
):
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Dict,
|
||||
start_prefix: str,
|
||||
expected_keys: Dict,
|
||||
device_map: Optional[Dict] = None,
|
||||
disk_offload_folder: Optional[str] = None,
|
||||
disk_offload_index: Optional[Dict] = None,
|
||||
cpu_offload_folder: Optional[str] = None,
|
||||
cpu_offload_index: Optional[Dict] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
is_safetensors: bool = False,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
unexpected_keys: Optional[Dict] = None, # passing `unexpected` for cleanup from quantization items
|
||||
) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
||||
|
|
@ -857,9 +856,9 @@ def _load_state_dict_into_meta_model(
|
|||
|
||||
if param_device == "disk":
|
||||
if not is_safetensors:
|
||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
||||
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
|
||||
elif param_device == "cpu" and cpu_offload_index is not None:
|
||||
cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
|
||||
elif (
|
||||
not is_quantized
|
||||
or (not hf_quantizer.requires_parameters_quantization)
|
||||
|
|
@ -892,7 +891,7 @@ def _load_state_dict_into_meta_model(
|
|||
setattr(module, tensor_name, value)
|
||||
# TODO: consider removing used param_parts from state_dict before return
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
return error_msgs, disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
|
|
@ -905,7 +904,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||
|
||||
|
||||
def get_checkpoint_files(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
subfolder: str,
|
||||
variant: Optional[str],
|
||||
gguf_file: Optional[str],
|
||||
|
|
@ -920,7 +919,7 @@ def get_checkpoint_files(
|
|||
user_agent: dict,
|
||||
revision: str,
|
||||
commit_hash: Optional[str],
|
||||
) -> Tuple[List[str], Dict]:
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
is_sharded = False
|
||||
|
||||
if pretrained_model_name_or_path is not None and gguf_file is None:
|
||||
|
|
@ -1216,12 +1215,15 @@ def get_checkpoint_files(
|
|||
def get_torch_dtype(
|
||||
cls,
|
||||
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
|
||||
checkpoint_files: List[str],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
config: PretrainedConfig,
|
||||
sharded_metadata: Optional[Dict],
|
||||
state_dict: Optional[Dict],
|
||||
weights_only: bool,
|
||||
):
|
||||
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||
infered dtype.
|
||||
"""
|
||||
dtype_orig = None
|
||||
is_sharded = sharded_metadata is not None
|
||||
|
||||
|
|
@ -1282,7 +1284,15 @@ def get_torch_dtype(
|
|||
return config, torch_dtype, dtype_orig
|
||||
|
||||
|
||||
def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules):
|
||||
def get_device_map(
|
||||
model: "PreTrainedModel",
|
||||
device_map: Optional[Union[str, Dict]],
|
||||
max_memory: Optional[Dict],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
keep_in_fp32_modules: Optional[List[str]],
|
||||
) -> Tuple["PreTrainedModel", Dict]:
|
||||
"""Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies."""
|
||||
if isinstance(device_map, str):
|
||||
special_dtypes = {}
|
||||
if hf_quantizer is not None:
|
||||
|
|
@ -1342,8 +1352,18 @@ def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, kee
|
|||
|
||||
|
||||
def find_missing_and_unexpected_keys(
|
||||
cls, model, expected_keys, loaded_keys, remove_prefix_from_model, add_prefix_to_model, hf_quantizer, device_map
|
||||
):
|
||||
cls,
|
||||
model: "PreTrainedModel",
|
||||
expected_keys: List[str],
|
||||
loaded_keys: List[str],
|
||||
remove_prefix_from_model: bool,
|
||||
add_prefix_to_model: bool,
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
device_map: Optional[Dict],
|
||||
) -> Tuple["PreTrainedModel", List[str], List[str]]:
|
||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||
"""
|
||||
prefix = model.base_model_prefix
|
||||
_prefix = f"{prefix}."
|
||||
|
||||
|
|
@ -1362,7 +1382,7 @@ def find_missing_and_unexpected_keys(
|
|||
# Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
|
||||
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
|
||||
if has_inv_freq_buffers:
|
||||
unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
|
||||
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
||||
|
||||
model.tie_weights()
|
||||
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||
|
|
@ -1402,14 +1422,17 @@ def find_missing_and_unexpected_keys(
|
|||
|
||||
|
||||
def move_missing_keys_back_to_cpu(
|
||||
model,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
dtype,
|
||||
keep_in_fp32_modules,
|
||||
is_quantized,
|
||||
hf_quantizer,
|
||||
):
|
||||
model: "PreTrainedModel",
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
dtype: Optional[torch.dtype],
|
||||
keep_in_fp32_modules: Optional[List[str]],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
) -> "PreTrainedModel":
|
||||
"""Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
|
||||
from meta device to cpu.
|
||||
"""
|
||||
is_quantized = hf_quantizer is not None
|
||||
prefix = model.base_model_prefix
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
|
|
@ -1446,8 +1469,18 @@ def move_missing_keys_back_to_cpu(
|
|||
|
||||
|
||||
def initialize_missing_keys(
|
||||
model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
||||
):
|
||||
model: "PreTrainedModel",
|
||||
loaded_keys: List[str],
|
||||
ignore_mismatched_sizes: bool,
|
||||
has_prefix_module: bool,
|
||||
expects_prefix_module: bool,
|
||||
is_quantized: bool,
|
||||
) -> "PreTrainedModel":
|
||||
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
|
||||
`_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to
|
||||
be initialized correctly (i.e. weight initialization distribution).
|
||||
Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
|
||||
"""
|
||||
prefix = model.base_model_prefix
|
||||
|
||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
||||
|
|
@ -1493,18 +1526,20 @@ def initialize_missing_keys(
|
|||
|
||||
|
||||
def find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
prefix,
|
||||
):
|
||||
state_dict: Dict,
|
||||
model_state_dict: Dict,
|
||||
renamed_loaded_keys: List[str],
|
||||
original_loaded_keys: List[str],
|
||||
add_prefix_to_model: bool,
|
||||
remove_prefix_from_model: bool,
|
||||
ignore_mismatched_sizes: bool,
|
||||
prefix: str,
|
||||
) -> List:
|
||||
"""Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
|
||||
problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
||||
for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
|
|
@ -4459,10 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
sharded_metadata=sharded_metadata,
|
||||
_fast_init=_fast_init,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
disk_offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
|
|
@ -4637,22 +4671,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
model,
|
||||
state_dict,
|
||||
checkpoint_files,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=False,
|
||||
sharded_metadata=None,
|
||||
_fast_init=True,
|
||||
low_cpu_mem_usage=False,
|
||||
device_map=None,
|
||||
offload_folder=None,
|
||||
offload_state_dict=None,
|
||||
dtype=None,
|
||||
hf_quantizer=None,
|
||||
keep_in_fp32_modules=None,
|
||||
gguf_file=None,
|
||||
weights_only=True,
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Optional[Dict],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
pretrained_model_name_or_path: Optional[str],
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
sharded_metadata: Optional[Dict] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
device_map: Optional[Dict] = None,
|
||||
disk_offload_folder: Optional[str] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
gguf_file: Optional[str] = None,
|
||||
weights_only: bool = True,
|
||||
):
|
||||
# Get all the keys of the state dicts that we have to initialize the model
|
||||
if sharded_metadata is not None:
|
||||
|
|
@ -4667,16 +4700,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
|
||||
is_offloaded_safetensors = False
|
||||
|
||||
if is_from_file and device_map is not None and "disk" in device_map.values():
|
||||
is_offloaded_safetensors = checkpoint_files[0].endswith(".safetensors")
|
||||
if offload_folder is None and not is_offloaded_safetensors:
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
|
||||
if disk_offload_folder is None and not is_offloaded_safetensors:
|
||||
raise ValueError(
|
||||
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
||||
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
||||
" offers the weights in this format."
|
||||
)
|
||||
if offload_folder is not None:
|
||||
os.makedirs(offload_folder, exist_ok=True)
|
||||
if disk_offload_folder is not None:
|
||||
os.makedirs(disk_offload_folder, exist_ok=True)
|
||||
if offload_state_dict is None:
|
||||
offload_state_dict = True
|
||||
|
||||
|
|
@ -4693,9 +4726,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
|
||||
|
||||
loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
||||
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
||||
|
||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) if len(prefix) > 0 else False
|
||||
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
|
||||
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
|
||||
|
||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
||||
|
|
@ -4713,7 +4746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
cls,
|
||||
model,
|
||||
expected_keys,
|
||||
loaded_keys,
|
||||
renamed_loaded_keys,
|
||||
remove_prefix_from_model,
|
||||
add_prefix_to_model,
|
||||
hf_quantizer,
|
||||
|
|
@ -4728,10 +4761,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
|
||||
# correctly initialize the missing keys
|
||||
if _fast_init:
|
||||
model = initialize_missing_keys(
|
||||
model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
||||
)
|
||||
model = initialize_missing_keys(
|
||||
model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
|
||||
)
|
||||
|
||||
# Set some modules to fp32 if needed
|
||||
if keep_in_fp32_modules is not None:
|
||||
|
|
@ -4748,7 +4780,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||
model_to_load = getattr(model, cls.base_model_prefix)
|
||||
base_model_expected_keys = list(model_to_load.state_dict().keys())
|
||||
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
|
||||
if any(
|
||||
key in expected_keys_not_prefixed and key not in base_model_expected_keys
|
||||
for key in renamed_loaded_keys
|
||||
):
|
||||
raise ValueError(
|
||||
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
||||
"properly saved?"
|
||||
|
|
@ -4762,8 +4797,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
else:
|
||||
folder = None
|
||||
|
||||
# In case we need to offload to disk, compute the index
|
||||
offload_index = None
|
||||
# This offload index if for params explicitly on the "disk" in the device_map
|
||||
disk_offload_index = None
|
||||
if is_offloaded_safetensors:
|
||||
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
|
||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||
|
|
@ -4771,21 +4806,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
|
||||
else:
|
||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||
offload_index = {
|
||||
disk_offload_index = {
|
||||
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||
for p, f in weight_map.items()
|
||||
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||
}
|
||||
elif device_map is not None and "disk" in device_map.values():
|
||||
offload_index = {}
|
||||
disk_offload_index = {}
|
||||
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
# This offload index if for params that are supposed to be on the "cpu", either with or without a device_map
|
||||
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
|
||||
# i.e. 1x to load it, and 1x to copy it to model
|
||||
cpu_offload_folder = None
|
||||
cpu_offload_index = None
|
||||
if offload_state_dict:
|
||||
state_dict_folder = tempfile.mkdtemp()
|
||||
state_dict_index = {}
|
||||
cpu_offload_folder = tempfile.mkdtemp()
|
||||
cpu_offload_index = {}
|
||||
|
||||
# Find checkpoint files containing only weights offloaded to disk if any
|
||||
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
|
||||
disk_only_shard_files = []
|
||||
if is_sharded_offloaded_safetensors:
|
||||
disk_only_shard_files = get_disk_only_shard_files(
|
||||
|
|
@ -4830,7 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
mismatched_keys += find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
renamed_loaded_keys,
|
||||
loaded_state_dict_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
|
|
@ -4846,16 +4884,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
else:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
fixed_state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_index=offload_index,
|
||||
state_dict_folder=state_dict_folder,
|
||||
state_dict_index=state_dict_index,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
dtype=dtype,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
|
|
@ -4878,25 +4916,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
if offload_index is not None and len(offload_index) > 0:
|
||||
# Adjust offloaded weights name and save if needed
|
||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||
if model != model_to_load:
|
||||
# We need to add the prefix of the base model
|
||||
prefix = cls.base_model_prefix
|
||||
if not is_offloaded_safetensors:
|
||||
for weight_name in offload_index:
|
||||
for weight_name in disk_offload_index:
|
||||
shutil.move(
|
||||
os.path.join(offload_folder, f"{weight_name}.dat"),
|
||||
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
|
||||
os.path.join(disk_offload_folder, f"{weight_name}.dat"),
|
||||
os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"),
|
||||
)
|
||||
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
|
||||
disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
|
||||
if not is_offloaded_safetensors:
|
||||
save_offload_index(offload_index, offload_folder)
|
||||
offload_index = None
|
||||
save_offload_index(disk_offload_index, disk_offload_folder)
|
||||
disk_offload_index = None
|
||||
|
||||
# 1-by-1 param loading for the cpu params
|
||||
if offload_state_dict:
|
||||
# Load back temporarily offloaded state dict
|
||||
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
|
||||
shutil.rmtree(state_dict_folder)
|
||||
load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
|
||||
shutil.rmtree(cpu_offload_folder)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
error_msg = "\n\t".join(error_msgs)
|
||||
|
|
@ -4947,7 +4987,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
" to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
|
||||
|
||||
@classmethod
|
||||
def _load_from_tf(cls, model, config, checkpoint_files):
|
||||
|
|
|
|||
Loading…
Reference in a new issue