mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
keep improving
This commit is contained in:
parent
574e3f76c9
commit
a3401c3e23
1 changed files with 57 additions and 83 deletions
|
|
@ -1218,7 +1218,11 @@ def get_torch_dtype(
|
||||||
weights_only: bool,
|
weights_only: bool,
|
||||||
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||||
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||||
inferred dtype.
|
inferred dtype. We do the following:
|
||||||
|
1. If torch_dtype is not None, we use that dtype
|
||||||
|
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||||
|
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
|
||||||
|
we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||||
"""
|
"""
|
||||||
dtype_orig = None
|
dtype_orig = None
|
||||||
is_sharded = sharded_metadata is not None
|
is_sharded = sharded_metadata is not None
|
||||||
|
|
@ -1355,7 +1359,6 @@ def find_missing_and_unexpected_keys(
|
||||||
remove_prefix_from_model: bool,
|
remove_prefix_from_model: bool,
|
||||||
add_prefix_to_model: bool,
|
add_prefix_to_model: bool,
|
||||||
hf_quantizer: Optional[HfQuantizer],
|
hf_quantizer: Optional[HfQuantizer],
|
||||||
device_map: Optional[Dict],
|
|
||||||
) -> Tuple["PreTrainedModel", List[str], List[str]]:
|
) -> Tuple["PreTrainedModel", List[str], List[str]]:
|
||||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||||
|
|
@ -1381,17 +1384,7 @@ def find_missing_and_unexpected_keys(
|
||||||
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
||||||
|
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
tied_params = find_tied_parameters(model)
|
||||||
ptrs = collections.defaultdict(list)
|
|
||||||
for name, tensor in model.state_dict().items():
|
|
||||||
id_tensor = id_tensor_storage(tensor)
|
|
||||||
ptrs[id_tensor].append(name)
|
|
||||||
|
|
||||||
# These are all the pointers of shared tensors.
|
|
||||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
|
||||||
else:
|
|
||||||
# id function doesn't work for meta tensor so we need this function
|
|
||||||
tied_params = find_tied_parameters(model)
|
|
||||||
|
|
||||||
for group in tied_params:
|
for group in tied_params:
|
||||||
if remove_prefix_from_model:
|
if remove_prefix_from_model:
|
||||||
|
|
@ -4384,11 +4377,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
"tensors"
|
"tensors"
|
||||||
]
|
]
|
||||||
|
|
||||||
# set dtype to instantiate the model under:
|
# Find the correct dtype based on current state
|
||||||
# 1. If torch_dtype is not None, we use that dtype
|
|
||||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
|
||||||
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
|
|
||||||
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
|
||||||
config, torch_dtype, dtype_orig = get_torch_dtype(
|
config, torch_dtype, dtype_orig = get_torch_dtype(
|
||||||
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
|
cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
|
||||||
)
|
)
|
||||||
|
|
@ -4683,39 +4672,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
# Some useful flags
|
# Some useful flags
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
|
is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
|
||||||
is_offloaded_safetensors = False
|
|
||||||
|
|
||||||
if device_map is not None and "disk" in device_map.values():
|
|
||||||
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
|
|
||||||
if disk_offload_folder is None and not is_offloaded_safetensors:
|
|
||||||
raise ValueError(
|
|
||||||
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
|
||||||
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
|
||||||
" offers the weights in this format."
|
|
||||||
)
|
|
||||||
if disk_offload_folder is not None:
|
|
||||||
os.makedirs(disk_offload_folder, exist_ok=True)
|
|
||||||
if offload_state_dict is None:
|
|
||||||
offload_state_dict = True
|
|
||||||
|
|
||||||
is_sharded_offloaded_safetensors = is_offloaded_safetensors and sharded_metadata is not None
|
|
||||||
|
|
||||||
# tie the model weights before retrieving the state_dict
|
# tie the model weights before retrieving the state_dict
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
# Update model keys (expected keys) and loaded keys based on prefix, quantization, etc...
|
# Update model keys (expected keys) and loaded keys based on prefix, quantization, etc...
|
||||||
|
prefix = model.base_model_prefix
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
prefix = model.base_model_prefix
|
|
||||||
|
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
|
||||||
|
|
||||||
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
|
||||||
|
|
||||||
|
# Check if we need to modify prefix
|
||||||
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
|
has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
|
||||||
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
|
||||||
|
|
||||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
|
||||||
add_prefix_to_model = has_prefix_module and not expects_prefix_module
|
add_prefix_to_model = has_prefix_module and not expects_prefix_module
|
||||||
|
|
||||||
|
|
@ -4726,7 +4697,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
elif add_prefix_to_model:
|
elif add_prefix_to_model:
|
||||||
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
||||||
|
|
||||||
# Find missing and unexpected keys from the state dict, to later log them
|
# Find missing and unexpected keys from the state dict
|
||||||
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
|
model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys(
|
||||||
cls,
|
cls,
|
||||||
model,
|
model,
|
||||||
|
|
@ -4735,7 +4706,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
hf_quantizer,
|
hf_quantizer,
|
||||||
device_map,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
|
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
|
||||||
|
|
@ -4744,6 +4714,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
model = move_missing_keys_back_to_cpu(
|
model = move_missing_keys_back_to_cpu(
|
||||||
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
|
model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer
|
||||||
)
|
)
|
||||||
|
# In this case we also need to move everything back
|
||||||
|
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||||
|
for key, param in model.state_dict().items():
|
||||||
|
if param.device == torch.device("meta"):
|
||||||
|
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
|
||||||
|
|
||||||
# correctly initialize the missing keys
|
# correctly initialize the missing keys
|
||||||
model = initialize_missing_keys(
|
model = initialize_missing_keys(
|
||||||
|
|
@ -4776,28 +4751,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
|
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
|
||||||
|
|
||||||
# Find the folder where the checkpoints reside
|
is_offloaded_safetensors = False
|
||||||
if checkpoint_files is not None:
|
|
||||||
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
|
||||||
else:
|
|
||||||
folder = None
|
|
||||||
|
|
||||||
# This offload index is for params explicitly on the "disk" in the device_map
|
# This offload index is for params explicitly on the "disk" in the device_map
|
||||||
disk_offload_index = None
|
disk_offload_index = None
|
||||||
if is_offloaded_safetensors:
|
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
|
||||||
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
|
disk_only_shard_files = []
|
||||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
if device_map is not None and "disk" in device_map.values():
|
||||||
if sharded_metadata is None:
|
is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
|
||||||
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
|
if disk_offload_folder is None and not is_offloaded_safetensors:
|
||||||
|
raise ValueError(
|
||||||
|
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
|
||||||
|
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
|
||||||
|
" offers the weights in this format."
|
||||||
|
)
|
||||||
|
if disk_offload_folder is not None:
|
||||||
|
os.makedirs(disk_offload_folder, exist_ok=True)
|
||||||
|
if offload_state_dict is None:
|
||||||
|
offload_state_dict = True
|
||||||
|
if is_offloaded_safetensors:
|
||||||
|
param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
|
||||||
|
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||||
|
if sharded_metadata is None:
|
||||||
|
weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
|
||||||
|
else:
|
||||||
|
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
||||||
|
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||||
|
# Find potential checkpoints containing only offloaded weights
|
||||||
|
disk_only_shard_files = get_disk_only_shard_files(
|
||||||
|
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
|
||||||
|
)
|
||||||
|
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
||||||
|
disk_offload_index = {
|
||||||
|
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||||
|
for p, f in weight_map.items()
|
||||||
|
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
disk_offload_index = {}
|
||||||
disk_offload_index = {
|
|
||||||
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
|
||||||
for p, f in weight_map.items()
|
|
||||||
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
|
||||||
}
|
|
||||||
elif device_map is not None and "disk" in device_map.values():
|
|
||||||
disk_offload_index = {}
|
|
||||||
|
|
||||||
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
|
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
|
||||||
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
|
# It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
|
||||||
|
|
@ -4808,14 +4798,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
cpu_offload_folder = tempfile.mkdtemp()
|
cpu_offload_folder = tempfile.mkdtemp()
|
||||||
cpu_offload_index = {}
|
cpu_offload_index = {}
|
||||||
|
|
||||||
# Find checkpoint files containing only weights offloaded to disk if any (allows to be faster)
|
|
||||||
disk_only_shard_files = []
|
|
||||||
if is_sharded_offloaded_safetensors:
|
|
||||||
disk_only_shard_files = get_disk_only_shard_files(
|
|
||||||
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
|
|
||||||
)
|
|
||||||
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
|
||||||
|
|
||||||
map_location = None
|
map_location = None
|
||||||
if (
|
if (
|
||||||
device_map is not None
|
device_map is not None
|
||||||
|
|
@ -4829,24 +4811,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
if checkpoint_files is not None and len(checkpoint_files) > 1:
|
if checkpoint_files is not None and len(checkpoint_files) > 1:
|
||||||
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
|
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
|
||||||
|
|
||||||
# To be able to iterate, even if we don't use it
|
# To be able to iterate, even if we don't use it if the state_dict is already provided
|
||||||
checkpoint_files = checkpoint_files if state_dict is None else [""]
|
checkpoint_files = checkpoint_files if state_dict is None else [""]
|
||||||
|
|
||||||
assign_to_params_buffers = None
|
|
||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
# Iterate on all the shards to load the weights
|
# Iterate on all the shards to load the weights
|
||||||
for shard_file in checkpoint_files:
|
for shard_file in checkpoint_files:
|
||||||
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
|
# Skip the load for shards that only contain disk-offloaded weights
|
||||||
if shard_file in disk_only_shard_files:
|
if shard_file in disk_only_shard_files:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If shard_file == "", we use the existing state_dict
|
# If shard_file == "", we use the existing state_dict instead of loading it
|
||||||
if shard_file != "":
|
if shard_file != "":
|
||||||
state_dict = load_state_dict(
|
state_dict = load_state_dict(
|
||||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||||
)
|
)
|
||||||
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
|
|
||||||
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
# matching the weights in the model.
|
# matching the weights in the model.
|
||||||
|
|
@ -4862,13 +4843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
)
|
)
|
||||||
if low_cpu_mem_usage or gguf_file is not None:
|
if low_cpu_mem_usage or gguf_file is not None:
|
||||||
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||||
for key, param in model_to_load.state_dict().items():
|
pass
|
||||||
if param.device == torch.device("meta"):
|
|
||||||
set_module_tensor_to_device(
|
|
||||||
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
|
||||||
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
fixed_state_dict,
|
fixed_state_dict,
|
||||||
|
|
@ -4888,11 +4864,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
error_msgs += new_error_msgs
|
error_msgs += new_error_msgs
|
||||||
else:
|
else:
|
||||||
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||||
if assign_to_params_buffers is None:
|
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
model_to_load, state_dict, start_prefix
|
||||||
model_to_load, state_dict, start_prefix
|
)
|
||||||
)
|
|
||||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
|
||||||
error_msgs += _load_state_dict_into_model(
|
error_msgs += _load_state_dict_into_model(
|
||||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue