mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00

keep improving

This commit is contained in:
parent 574e3f76c9
commit a3401c3e23

1 changed file with 57 additions and 83 deletions
@@ -1218,7 +1218,11 @@ def get_torch_dtype(
     weights_only: bool,
 ) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
     """Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
-    infered dtype.
+    infered dtype. We do the following:
+    1. If torch_dtype is not None, we use that dtype
+    2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+    weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+    we also may have config.torch_dtype available, but we won't rely on it till v5
     """
     dtype_orig = None
     is_sharded = sharded_metadata is not None
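For reference, the "auto" rule described in this docstring boils down to scanning the checkpoint for its first floating-point entry and adopting its dtype. A minimal standalone sketch of just that rule (the helper name `detect_checkpoint_dtype` is hypothetical, not part of the diff):

    import torch

    def detect_checkpoint_dtype(state_dict):
        # Mirror the "auto" rule: take the dtype of the first floating-point entry,
        # assuming all floating weights share one dtype; default to float32 otherwise.
        for tensor in state_dict.values():
            if tensor.is_floating_point():
                return tensor.dtype
        return torch.float32

    sd = {"step": torch.tensor(1), "embed.weight": torch.ones(2, 2, dtype=torch.float16)}
    assert detect_checkpoint_dtype(sd) == torch.float16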
@@ -1355,7 +1359,6 @@ def find_missing_and_unexpected_keys(
     remove_prefix_from_model: bool,
     add_prefix_to_model: bool,
     hf_quantizer: Optional[HfQuantizer],
-    device_map: Optional[Dict],
 ) -> Tuple["PreTrainedModel", List[str], List[str]]:
     """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
     (keys found in the loaded state dict keys, but that are NOT part of the model parameters)
@@ -1381,17 +1384,7 @@ def find_missing_and_unexpected_keys(
     unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
 
-    model.tie_weights()
-    if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
-        ptrs = collections.defaultdict(list)
-        for name, tensor in model.state_dict().items():
-            id_tensor = id_tensor_storage(tensor)
-            ptrs[id_tensor].append(name)
-
-        # These are all the pointers of shared tensors.
-        tied_params = [names for _, names in ptrs.items() if len(names) > 1]
-    else:
-        # id function doesn't work for meta tensor so we need this function
-        tied_params = find_tied_parameters(model)
+    tied_params = find_tied_parameters(model)
 
     for group in tied_params:
         if remove_prefix_from_model:
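The deleted branch grouped parameters by underlying storage pointer to detect tied weights, while `find_tied_parameters` now covers both the regular and the meta-tensor case. A toy sketch of the pointer-grouping idea, assuming plain CPU tensors:

    import collections
    import torch

    def shared_tensor_groups(state_dict):
        # Group entries by (device, storage pointer): names sharing storage are tied.
        ptrs = collections.defaultdict(list)
        for name, tensor in state_dict.items():
            ptrs[(tensor.device, tensor.untyped_storage().data_ptr())].append(name)
        return [names for names in ptrs.values() if len(names) > 1]

    w = torch.ones(4, 4)
    sd = {"embed.weight": w, "lm_head.weight": w, "bias": torch.zeros(4)}
    assert shared_tensor_groups(sd) == [["embed.weight", "lm_head.weight"]]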
@@ -4384,11 +4377,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "tensors"
             ]
 
-        # set dtype to instantiate the model under:
-        # 1. If torch_dtype is not None, we use that dtype
-        # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
-        # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
-        # we also may have config.torch_dtype available, but we won't rely on it till v5
+        # Find the correct dtype based on current state
         config, torch_dtype, dtype_orig = get_torch_dtype(
             cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
         )
@@ -4683,39 +4672,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Some useful flags
         is_quantized = hf_quantizer is not None
         is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
-        is_offloaded_safetensors = False
-
-        if device_map is not None and "disk" in device_map.values():
-            is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
-            if disk_offload_folder is None and not is_offloaded_safetensors:
-                raise ValueError(
-                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
-                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
-                    " offers the weights in this format."
-                )
-            if disk_offload_folder is not None:
-                os.makedirs(disk_offload_folder, exist_ok=True)
-            if offload_state_dict is None:
-                offload_state_dict = True
-
-        is_sharded_offloaded_safetensors = is_offloaded_safetensors and sharded_metadata is not None
-
-        # tie the model weights before retrieving the state_dict
-        model.tie_weights()
 
         # Update model keys (expected keys) and loaded keys based on prefix, quantization, etc...
+        prefix = model.base_model_prefix
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
-        prefix = model.base_model_prefix
 
         if hf_quantizer is not None:
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
 
         renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
 
         # Check if we need to modify prefix
         has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
         expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
 
         remove_prefix_from_model = not has_prefix_module and expects_prefix_module
         add_prefix_to_model = has_prefix_module and not expects_prefix_module
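The prefix bookkeeping above reconciles checkpoints saved with or without the base-model prefix (e.g. `bert.`). A self-contained sketch of that decision, using a hypothetical `reconcile_prefix` helper:

    def reconcile_prefix(expected_keys, loaded_keys, prefix):
        # Decide whether checkpoint and model disagree about the base-model prefix,
        # then rewrite the expected keys so both sides compare in the same namespace.
        has_prefix = any(k.startswith(prefix) for k in loaded_keys) if prefix else False
        expects_prefix = any(k.startswith(prefix) for k in expected_keys) if prefix else False
        if expects_prefix and not has_prefix:  # bare base-model checkpoint
            return [k[len(prefix) + 1:] if k.startswith(prefix) else k for k in expected_keys]
        if has_prefix and not expects_prefix:  # checkpoint wraps the base model
            return [".".join([prefix, k]) for k in expected_keys]
        return expected_keys

    # Loading a bare "bert" checkpoint into a model that adds a classification head:
    print(reconcile_prefix(["bert.embeddings.weight", "classifier.weight"],
                           ["embeddings.weight"], "bert"))
    # ['embeddings.weight', 'classifier.weight']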
@@ -4726,7 +4697,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif add_prefix_to_model:
             expected_keys = [".".join([prefix, s]) for s in expected_keys]
 
-        # Find missing and unexpected keys from the state dict, to later log them
+        # Find missing and unexpected keys from the state dict
         model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
             cls,
             model,
@@ -4735,7 +4706,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             remove_prefix_from_model,
             add_prefix_to_model,
             hf_quantizer,
-            device_map,
         )
 
         # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
@@ -4744,6 +4714,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         model = move_missing_keys_back_to_cpu(
             model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
         )
+        # In this case we also need to move everything back
+        if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
+            for key, param in model.state_dict().items():
+                if param.device == torch.device("meta"):
+                    set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
 
         # correctly initialize the missing keys
         model = initialize_missing_keys(
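The added block re-materializes parameters still on the `meta` device as uninitialized CPU tensors on non-rank-0 FSDP workers. In recent plain PyTorch the same effect can be had with `Module.to_empty`, as in this sketch:

    import torch
    import torch.nn as nn

    # Parameters created on the meta device hold no storage; to_empty() re-creates
    # them as uninitialized tensors on the target device, like the per-key loop above.
    with torch.device("meta"):
        model = nn.Linear(8, 8)
    assert all(p.device.type == "meta" for p in model.parameters())

    model = model.to_empty(device="cpu")
    assert all(p.device.type == "cpu" for p in model.parameters())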
@@ -4776,28 +4751,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if device_map is not None:
             device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
 
-        # Find the folder where the checkpoints reside
-        if checkpoint_files is not None:
-            folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
-        else:
-            folder = None
-
+        is_offloaded_safetensors = False
         # This offload index if for params explicitly on the "disk" in the device_map
         disk_offload_index = None
-        if is_offloaded_safetensors:
-            param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
-            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
-            if sharded_metadata is None:
-                weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
-            else:
-                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
-            disk_offload_index = {
-                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
-                for p, f in weight_map.items()
-                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
-            }
-        elif device_map is not None and "disk" in device_map.values():
-            disk_offload_index = {}
+        # Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
+        disk_only_shard_files = []
+        if device_map is not None and "disk" in device_map.values():
+            is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
+            if disk_offload_folder is None and not is_offloaded_safetensors:
+                raise ValueError(
+                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+                    " offers the weights in this format."
+                )
+            if disk_offload_folder is not None:
+                os.makedirs(disk_offload_folder, exist_ok=True)
+            if offload_state_dict is None:
+                offload_state_dict = True
+            if is_offloaded_safetensors:
+                param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
+                str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+                if sharded_metadata is None:
+                    weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
+                else:
+                    folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
+                    weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
+                    # Find potential checkpoints containing only offloaded weights
+                    disk_only_shard_files = get_disk_only_shard_files(
+                        device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
+                    )
+                    disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
+                disk_offload_index = {
+                    p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                    for p, f in weight_map.items()
+                    if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
+                }
+            else:
+                disk_offload_index = {}
 
         # This offload index if for params that are supposed to be on the "cpu", either with or without a device_map
         # It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
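The `disk_offload_index` comprehension above is easiest to read with toy inputs; each entry points a disk-offloaded parameter at the safetensors file that holds it (file names here are illustrative):

    # Toy inputs: one param on GPU 0, one offloaded to disk.
    start_prefix = "model."
    str_dtype = "float16"
    weight_map = {
        "model.layers.0.w": "shard-00001.safetensors",
        "model.layers.1.w": "shard-00002.safetensors",
    }
    param_device_map = {"layers.0.w": 0, "layers.1.w": "disk"}

    disk_offload_index = {
        p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
        for p, f in weight_map.items()
        if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
    }
    print(disk_offload_index)
    # {'layers.1.w': {'safetensors_file': 'shard-00002.safetensors',
    #                 'weight_name': 'model.layers.1.w', 'dtype': 'float16'}}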
@@ -4808,14 +4798,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             cpu_offload_folder = tempfile.mkdtemp()
             cpu_offload_index = {}
 
-        # Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
-        disk_only_shard_files = []
-        if is_sharded_offloaded_safetensors:
-            disk_only_shard_files = get_disk_only_shard_files(
-                device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
-            )
-            disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
-
         map_location = None
         if (
             device_map is not None
@@ -4829,24 +4811,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if checkpoint_files is not None and len(checkpoint_files) > 1:
             checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
 
-        # To be able to iterate, even if we don't use it
+        # To be able to iterate, even if we don't use it if the state_dict is already provided
         checkpoint_files = checkpoint_files if state_dict is None else [""]
 
         assign_to_params_buffers = None
 
         error_msgs = []
         mismatched_keys = []
         # Iterate on all the shards to load the weights
         for shard_file in checkpoint_files:
-            # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
+            # Skip the load for shards that only contain disk-offloaded weights
             if shard_file in disk_only_shard_files:
                 continue
 
-            # If shard_file == "", we use the existing state_dict
+            # If shard_file == "", we use the existing state_dict instead of loading it
             if shard_file != "":
                 state_dict = load_state_dict(
                     shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                 )
+            fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
 
             # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
             # matching the weights in the model.
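The shard loop's control flow, reduced to a sketch: an empty string stands in for an already-provided `state_dict`, and shards whose weights all live on disk are skipped (`load_fn` is a stand-in for `load_state_dict`):

    def iterate_shards(checkpoint_files, state_dict, disk_only_shard_files, load_fn):
        # "" means: reuse the provided state_dict instead of reading a file.
        checkpoint_files = checkpoint_files if state_dict is None else [""]
        for shard_file in checkpoint_files:
            if shard_file in disk_only_shard_files:
                continue  # every weight in this shard is offloaded; nothing to load here
            if shard_file != "":
                state_dict = load_fn(shard_file)
            yield state_dict

    shards = ["a.safetensors", "b.safetensors"]
    out = list(iterate_shards(shards, None, {"b.safetensors"}, lambda f: {f: 0}))
    assert out == [{"a.safetensors": 0}]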
@@ -4862,13 +4843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             if low_cpu_mem_usage or gguf_file is not None:
                 if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
-                    for key, param in model_to_load.state_dict().items():
-                        if param.device == torch.device("meta"):
-                            set_module_tensor_to_device(
-                                model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
-                            )
+                    pass
                 else:
-                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                     new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                         model_to_load,
                         fixed_state_dict,
@@ -4888,11 +4864,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     error_msgs += new_error_msgs
                 else:
-                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
-                    if assign_to_params_buffers is None:
-                        assign_to_params_buffers = check_support_param_buffer_assignment(
-                            model_to_load, state_dict, start_prefix
-                        )
-                    fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                    assign_to_params_buffers = check_support_param_buffer_assignment(
+                        model_to_load, state_dict, start_prefix
+                    )
                     error_msgs += _load_state_dict_into_model(
                         model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
                     )
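For context on `assign_to_params_buffers`: when the checkpoint tensors are compatible with the module, they can be adopted directly instead of copied element-wise. Plain PyTorch (2.1+) exposes the same idea as `load_state_dict(assign=True)`; this is an illustration of the concept, not the code path above:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    ckpt = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
    # assign=True adopts the checkpoint tensors in place of the existing
    # parameters (no element-wise copy), when shapes and layouts allow it.
    model.load_state_dict(ckpt, assign=True)
    print(model.weight.data_ptr() == ckpt["weight"].data_ptr())  # True: storage is shared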