add type hints/docstring

Cyril Vallez 2025-02-05 13:52:57 +01:00
parent bbab9b26e0
commit 11e378024d

@@ -752,22 +752,21 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
 def _load_state_dict_into_meta_model(
-    model,
-    state_dict,
-    start_prefix,
-    expected_keys,
-    device_map=None,
-    offload_folder=None,
-    offload_index=None,
-    state_dict_folder=None,
-    state_dict_index=None,
-    dtype=None,
-    hf_quantizer=None,
-    is_safetensors=False,
-    keep_in_fp32_modules=None,
-    unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
-    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
-):
+    model: "PreTrainedModel",
+    state_dict: Dict,
+    start_prefix: str,
+    expected_keys: Dict,
+    device_map: Optional[Dict] = None,
+    disk_offload_folder: Optional[str] = None,
+    disk_offload_index: Optional[Dict] = None,
+    cpu_offload_folder: Optional[str] = None,
+    cpu_offload_index: Optional[Dict] = None,
+    dtype: Optional[torch.dtype] = None,
+    hf_quantizer: Optional[HfQuantizer] = None,
+    is_safetensors: bool = False,
+    keep_in_fp32_modules: Optional[List[str]] = None,
+    unexpected_keys: Optional[Dict] = None,  # passing `unexpected` for cleanup from quantization items
+) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
     params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
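This hunk only renames arguments (disk vs cpu offload) and adds annotations; the mechanism is unchanged. For context, a minimal self-contained sketch of what loading into a meta-device model means, not the transformers implementation itself:

```python
import torch
from torch import nn

# Build the model under the meta device: parameters have shapes but no storage.
with torch.device("meta"):
    model = nn.Linear(4, 4)

state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
for name, tensor in state_dict.items():
    # The real function resolves dotted names and consults `device_map`;
    # here the model is flat and everything targets the cpu.
    setattr(model, name, nn.Parameter(tensor.to("cpu")))

print(model.weight.device)  # cpu -- the meta placeholder has been replaced
```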
@@ -857,9 +856,9 @@ def _load_state_dict_into_meta_model(
         if param_device == "disk":
             if not is_safetensors:
-                offload_index = offload_weight(param, param_name, offload_folder, offload_index)
-        elif param_device == "cpu" and state_dict_index is not None:
-            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+                disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
+        elif param_device == "cpu" and cpu_offload_index is not None:
+            cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
         elif (
             not is_quantized
             or (not hf_quantizer.requires_parameters_quantization)
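`offload_weight` here comes from `accelerate.utils`: it writes a tensor to `<folder>/<name>.dat` as a memmap and records its metadata in the returned index. A hedged sketch of the two paths the hunk distinguishes (folder/index names follow the new naming; the exact index layout is accelerate's):

```python
import tempfile
import torch
from accelerate.utils import offload_weight

param = torch.randn(16, 16)

# "disk" device: the weight stays on disk as a `<name>.dat` memmap for inference.
disk_offload_folder = tempfile.mkdtemp()
disk_offload_index = offload_weight(param, "layer.weight", disk_offload_folder, index={})

# "cpu" + offload_state_dict: same mechanism, but only as a temporary parking
# spot while shards are loaded; the weights are streamed back afterwards.
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = offload_weight(param, "layer.weight", cpu_offload_folder, index={})

print(disk_offload_index["layer.weight"])  # e.g. {'dtype': 'float32', 'shape': [16, 16]}
```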
@@ -892,7 +891,7 @@ def _load_state_dict_into_meta_model(
             setattr(module, tensor_name, value)
 
     # TODO: consider removing used param_parts from state_dict before return
-    return error_msgs, offload_index, state_dict_index
+    return error_msgs, disk_offload_index, cpu_offload_index
 
 
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -905,7 +904,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
 def get_checkpoint_files(
-    pretrained_model_name_or_path: Union[str, os.PathLike],
+    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
     subfolder: str,
     variant: Optional[str],
     gguf_file: Optional[str],
@@ -920,7 +919,7 @@ def get_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
-) -> Tuple[List[str], Dict]:
+) -> Tuple[Optional[List[str]], Optional[Dict]]:
     is_sharded = False
 
     if pretrained_model_name_or_path is not None and gguf_file is None:
@@ -1216,12 +1215,15 @@ def get_checkpoint_files(
 def get_torch_dtype(
     cls,
     torch_dtype: Optional[Union[str, torch.dtype, Dict]],
-    checkpoint_files: List[str],
+    checkpoint_files: Optional[List[str]],
     config: PretrainedConfig,
     sharded_metadata: Optional[Dict],
     state_dict: Optional[Dict],
     weights_only: bool,
-):
+) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+    """Find the correct `torch_dtype` to use based on the provided arguments. Also update the `config` based on the
+    inferred dtype.
+    """
     dtype_orig = None
     is_sharded = sharded_metadata is not None
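A simplified sketch of the resolution the new docstring describes; this is an illustration only, since the real function also inspects checkpoint weights and per-submodule dtype dicts (`resolve_torch_dtype` is a hypothetical helper name):

```python
import torch
from transformers import PretrainedConfig

def resolve_torch_dtype(torch_dtype, config: PretrainedConfig):
    if torch_dtype == "auto":
        # Prefer the dtype recorded in the config, fall back to float32.
        torch_dtype = getattr(config, "torch_dtype", None) or torch.float32
    elif isinstance(torch_dtype, str):
        torch_dtype = getattr(torch, torch_dtype)  # "float16" -> torch.float16
    config.torch_dtype = torch_dtype
    # Remember the previous global default so the caller can restore it later.
    dtype_orig = torch.get_default_dtype()
    torch.set_default_dtype(torch_dtype)
    return config, torch_dtype, dtype_orig
```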
@@ -1282,7 +1284,15 @@ def get_torch_dtype(
     return config, torch_dtype, dtype_orig
 
 
-def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules):
+def get_device_map(
+    model: "PreTrainedModel",
+    device_map: Optional[Union[str, Dict]],
+    max_memory: Optional[Dict],
+    hf_quantizer: Optional[HfQuantizer],
+    torch_dtype: Optional[torch.dtype],
+    keep_in_fp32_modules: Optional[List[str]],
+) -> Tuple["PreTrainedModel", Dict]:
+    """Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies."""
     if isinstance(device_map, str):
         special_dtypes = {}
         if hf_quantizer is not None:
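When `device_map` is a string like `"auto"`, the final map is computed with accelerate's planner. An illustrative sketch with a toy model (the actual assignment depends on the memory budget you pass):

```python
from torch import nn
from accelerate import infer_auto_device_map

# Each Linear(1024, 1024) is ~4 MiB in float32, so only one fits on device 0.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
device_map = infer_auto_device_map(model, max_memory={0: "5MiB", "cpu": "1GiB"})
print(device_map)  # e.g. {'0': 0, '1': 'cpu'} -- module name -> device
```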
@@ -1342,8 +1352,18 @@ def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules):
 def find_missing_and_unexpected_keys(
-    cls, model, expected_keys, loaded_keys, remove_prefix_from_model, add_prefix_to_model, hf_quantizer, device_map
-):
+    cls,
+    model: "PreTrainedModel",
+    expected_keys: List[str],
+    loaded_keys: List[str],
+    remove_prefix_from_model: bool,
+    add_prefix_to_model: bool,
+    hf_quantizer: Optional[HfQuantizer],
+    device_map: Optional[Dict],
+) -> Tuple["PreTrainedModel", List[str], List[str]]:
+    """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys)
+    and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters).
+    """
     prefix = model.base_model_prefix
     _prefix = f"{prefix}."
@@ -1362,7 +1382,7 @@ def find_missing_and_unexpected_keys(
     # Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
     has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
     if has_inv_freq_buffers:
-        unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
+        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
 
     model.tie_weights()
     if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
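The `{...}` to `[...]` change keeps `unexpected_keys` a list, matching its type everywhere else in the function. The core of the new docstring is plain set arithmetic, roughly:

```python
expected_keys = ["embed.weight", "lm_head.weight", "norm.weight"]   # model params
loaded_keys = ["embed.weight", "norm.weight", "rotary_emb.inv_freq"]  # checkpoint

missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = sorted(set(loaded_keys) - set(expected_keys))
print(missing_keys)     # ['lm_head.weight'] -> will be re-initialized
print(unexpected_keys)  # ['rotary_emb.inv_freq'] -> ignored (and cleaned up here)
```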
@@ -1402,14 +1422,17 @@ def find_missing_and_unexpected_keys(
 def move_missing_keys_back_to_cpu(
-    model,
-    missing_keys,
-    unexpected_keys,
-    dtype,
-    keep_in_fp32_modules,
-    is_quantized,
-    hf_quantizer,
-):
+    model: "PreTrainedModel",
+    missing_keys: List[str],
+    unexpected_keys: List[str],
+    dtype: Optional[torch.dtype],
+    keep_in_fp32_modules: Optional[List[str]],
+    hf_quantizer: Optional[HfQuantizer],
+) -> "PreTrainedModel":
+    """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
+    back from meta device to cpu.
+    """
+    is_quantized = hf_quantizer is not None
     prefix = model.base_model_prefix
     model_state_dict = model.state_dict()
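A minimal sketch of the meta-to-cpu move the new docstring describes (simplified: the real function also respects `dtype` and `keep_in_fp32_modules`):

```python
import torch
from torch import nn

with torch.device("meta"):
    model = nn.Linear(8, 8)

missing_keys = ["bias"]  # not found in any checkpoint shard
for key in missing_keys:
    meta_param = getattr(model, key)
    # Allocate real (uninitialized) storage on cpu with the same shape/dtype.
    setattr(model, key, nn.Parameter(torch.empty_like(meta_param, device="cpu")))

print(model.bias.device, model.weight.device)  # cpu meta
```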
@@ -1446,8 +1469,18 @@ def move_missing_keys_back_to_cpu(
 def initialize_missing_keys(
-    model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-):
+    model: "PreTrainedModel",
+    loaded_keys: List[str],
+    ignore_mismatched_sizes: bool,
+    has_prefix_module: bool,
+    expects_prefix_module: bool,
+    is_quantized: bool,
+) -> "PreTrainedModel":
+    """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state
+    dicts), according to `_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict,
+    they will not be replaced and need to be initialized correctly (i.e. sampled from the weight initialization
+    distribution).
+    Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
+    """
     prefix = model.base_model_prefix
     remove_prefix_from_model = not has_prefix_module and expects_prefix_module
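A self-contained sketch of the `_is_hf_initialized` bookkeeping mentioned in the docstring, assuming a module-level flag as in transformers; `init_weights` is a stand-in for the model's real `_initialize_weights`:

```python
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
loaded_keys = {"0.weight", "0.bias"}  # the second layer is missing from the checkpoint

# Tag modules whose params all come from the checkpoint: loading overwrites
# them anyway, so the (potentially expensive) init can skip them.
for name, module in model.named_modules():
    params = {f"{name}.{n}" if name else n for n, _ in module.named_parameters(recurse=False)}
    if params and params.issubset(loaded_keys):
        module._is_hf_initialized = True

# Stand-in for `model._initialize_weights`: only touch untagged modules.
def init_weights(module):
    if getattr(module, "_is_hf_initialized", False):
        return
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)  # hypothetical init scheme

model.apply(init_weights)
```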
@@ -1493,18 +1526,20 @@ def initialize_missing_keys(
 def find_mismatched_keys(
-    state_dict,
-    model_state_dict,
-    loaded_keys,
-    original_loaded_keys,
-    add_prefix_to_model,
-    remove_prefix_from_model,
-    ignore_mismatched_sizes,
-    prefix,
-):
+    state_dict: Dict,
+    model_state_dict: Dict,
+    renamed_loaded_keys: List[str],
+    original_loaded_keys: List[str],
+    add_prefix_to_model: bool,
+    remove_prefix_from_model: bool,
+    ignore_mismatched_sizes: bool,
+    prefix: str,
+) -> List:
+    """Find mismatches of shapes between the model parameters and the loaded state dict, and optionally remove the
+    problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
     mismatched_keys = []
     if ignore_mismatched_sizes:
-        for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
+        for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
             # If the checkpoint is sharded, we may not have the key here.
             if checkpoint_key not in state_dict:
                 continue
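The shape check itself is straightforward; a simplified sketch (prefix renaming between checkpoint and model keys omitted):

```python
import torch

state_dict = {"classifier.weight": torch.randn(10, 32)}       # checkpoint: 10 labels
model_state_dict = {"classifier.weight": torch.randn(2, 32)}  # model: 2 labels

mismatched_keys = []
for key in list(state_dict):
    if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
        mismatched_keys.append((key, state_dict[key].shape, model_state_dict[key].shape))
        del state_dict[key]  # dropped, so the model keeps its freshly initialized head
```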
@@ -4459,10 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 pretrained_model_name_or_path,
                 ignore_mismatched_sizes=ignore_mismatched_sizes,
                 sharded_metadata=sharded_metadata,
-                _fast_init=_fast_init,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 device_map=device_map,
-                offload_folder=offload_folder,
+                disk_offload_folder=offload_folder,
                 offload_state_dict=offload_state_dict,
                 dtype=torch_dtype,
                 hf_quantizer=hf_quantizer,
@@ -4637,22 +4671,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     @classmethod
     def _load_pretrained_model(
         cls,
-        model,
-        state_dict,
-        checkpoint_files,
-        pretrained_model_name_or_path,
-        ignore_mismatched_sizes=False,
-        sharded_metadata=None,
-        _fast_init=True,
-        low_cpu_mem_usage=False,
-        device_map=None,
-        offload_folder=None,
-        offload_state_dict=None,
-        dtype=None,
-        hf_quantizer=None,
-        keep_in_fp32_modules=None,
-        gguf_file=None,
-        weights_only=True,
+        model: "PreTrainedModel",
+        state_dict: Optional[Dict],
+        checkpoint_files: Optional[List[str]],
+        pretrained_model_name_or_path: Optional[str],
+        ignore_mismatched_sizes: bool = False,
+        sharded_metadata: Optional[Dict] = None,
+        low_cpu_mem_usage: bool = False,
+        device_map: Optional[Dict] = None,
+        disk_offload_folder: Optional[str] = None,
+        offload_state_dict: Optional[bool] = None,
+        dtype: Optional[torch.dtype] = None,
+        hf_quantizer: Optional[HfQuantizer] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+        gguf_file: Optional[str] = None,
+        weights_only: bool = True,
     ):
         # Get all the keys of the state dicts that we have to initialize the model
         if sharded_metadata is not None:
@@ -4667,16 +4700,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
 
         is_offloaded_safetensors = False
-        if is_from_file and device_map is not None and "disk" in device_map.values():
-            is_offloaded_safetensors = checkpoint_files[0].endswith(".safetensors")
-            if offload_folder is None and not is_offloaded_safetensors:
+        if device_map is not None and "disk" in device_map.values():
+            is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
+            if disk_offload_folder is None and not is_offloaded_safetensors:
                 raise ValueError(
                     "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                     " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                     " offers the weights in this format."
                 )
-            if offload_folder is not None:
-                os.makedirs(offload_folder, exist_ok=True)
+            if disk_offload_folder is not None:
+                os.makedirs(disk_offload_folder, exist_ok=True)
             if offload_state_dict is None:
                 offload_state_dict = True
@@ -4693,9 +4726,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if hf_quantizer is not None:
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
 
-        loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
-        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) if len(prefix) > 0 else False
+        renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
+        has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
         expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
         remove_prefix_from_model = not has_prefix_module and expects_prefix_module
@@ -4713,7 +4746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             cls,
             model,
             expected_keys,
-            loaded_keys,
+            renamed_loaded_keys,
             remove_prefix_from_model,
             add_prefix_to_model,
             hf_quantizer,
@@ -4728,10 +4761,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         )
 
         # correctly initialize the missing keys
-        if _fast_init:
-            model = initialize_missing_keys(
-                model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-            )
+        model = initialize_missing_keys(
+            model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
+        )
 
         # Set some modules to fp32 if needed
         if keep_in_fp32_modules is not None:
@@ -4748,7 +4780,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)
             base_model_expected_keys = list(model_to_load.state_dict().keys())
-            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
+            if any(
+                key in expected_keys_not_prefixed and key not in base_model_expected_keys
+                for key in renamed_loaded_keys
+            ):
                 raise ValueError(
                     "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                     "properly saved?"
@@ -4762,8 +4797,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             folder = None
 
-        # In case we need to offload to disk, compute the index
-        offload_index = None
+        # This offload index is for params explicitly set to "disk" in the device_map
+        disk_offload_index = None
         if is_offloaded_safetensors:
             param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
@@ -4771,21 +4806,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
             else:
                 weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
-            offload_index = {
+            disk_offload_index = {
                 p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                 for p, f in weight_map.items()
                 if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
             }
         elif device_map is not None and "disk" in device_map.values():
-            offload_index = {}
+            disk_offload_index = {}
 
-        state_dict_folder = None
-        state_dict_index = None
+        # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map.
+        # It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2x the state dict
+        # size, i.e. 1x to load it, and 1x to copy it into the model
+        cpu_offload_folder = None
+        cpu_offload_index = None
         if offload_state_dict:
-            state_dict_folder = tempfile.mkdtemp()
-            state_dict_index = {}
+            cpu_offload_folder = tempfile.mkdtemp()
+            cpu_offload_index = {}
 
-        # Find checkpoint files containing only weights offloaded to disk if any
+        # Find checkpoint files containing only weights offloaded to disk, if any (allows faster handling of those shards)
         disk_only_shard_files = []
         if is_sharded_offloaded_safetensors:
             disk_only_shard_files = get_disk_only_shard_files(
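The new comments explain the temporary cpu offload: without it, the full state dict (1x) and the model's copy of it (1x) coexist in RAM. A hedged round-trip sketch using the accelerate helpers this diff relies on (`offload_weight`, `load_offloaded_weights`):

```python
import tempfile
import torch
from torch import nn
from accelerate.utils import load_offloaded_weights, offload_weight

model = nn.Linear(4, 4)
full_state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Park each tensor on disk as soon as it is read, so at most one extra copy
# of a single tensor is alive at any time.
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {}
for name, tensor in full_state_dict.items():
    cpu_offload_index = offload_weight(tensor, name, cpu_offload_folder, cpu_offload_index)
del full_state_dict  # in the real flow, the shard is freed here

# Stream the parked tensors back into the model one by one.
load_offloaded_weights(model, cpu_offload_index, cpu_offload_folder)
```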
@@ -4830,7 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 mismatched_keys += find_mismatched_keys(
                     state_dict,
                     model_state_dict,
-                    loaded_keys,
+                    renamed_loaded_keys,
                     loaded_state_dict_keys,
                     add_prefix_to_model,
                     remove_prefix_from_model,
@@ -4846,16 +4884,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 )
             else:
                 fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
-                new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                     model_to_load,
                     fixed_state_dict,
                     start_prefix,
                     expected_keys,
                     device_map=device_map,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_folder=state_dict_folder,
-                    state_dict_index=state_dict_index,
+                    disk_offload_folder=disk_offload_folder,
+                    disk_offload_index=disk_offload_index,
+                    cpu_offload_folder=cpu_offload_folder,
+                    cpu_offload_index=cpu_offload_index,
                     dtype=dtype,
                     hf_quantizer=hf_quantizer,
                     is_safetensors=is_offloaded_safetensors,
@@ -4878,25 +4916,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             del state_dict
             gc.collect()
 
-        if offload_index is not None and len(offload_index) > 0:
+        # Adjust offloaded weight names and save the index if needed
+        if disk_offload_index is not None and len(disk_offload_index) > 0:
             if model != model_to_load:
                 # We need to add the prefix of the base model
                 prefix = cls.base_model_prefix
                 if not is_offloaded_safetensors:
-                    for weight_name in offload_index:
+                    for weight_name in disk_offload_index:
                         shutil.move(
-                            os.path.join(offload_folder, f"{weight_name}.dat"),
-                            os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
+                            os.path.join(disk_offload_folder, f"{weight_name}.dat"),
+                            os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"),
                         )
-                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
+                disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
             if not is_offloaded_safetensors:
-                save_offload_index(offload_index, offload_folder)
-                offload_index = None
+                save_offload_index(disk_offload_index, disk_offload_folder)
+                disk_offload_index = None
 
+        # 1-by-1 param loading for the cpu params
         if offload_state_dict:
             # Load back temporarily offloaded state dict
-            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
-            shutil.rmtree(state_dict_folder)
+            load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
+            shutil.rmtree(cpu_offload_folder)
 
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
@@ -4947,7 +4987,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 " to use it for predictions and inference."
             )
 
-        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
+        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
 
     @classmethod
     def _load_from_tf(cls, model, config, checkpoint_files):