Update modeling_utils.py

Cyril Vallez 2025-02-07 16:24:42 +01:00
parent c341f8c6fb
commit 87e97ea54c

@@ -894,13 +894,15 @@ def _load_state_dict_into_meta_model(
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return
# In this case, let's parallelize the modules!
# In this case, let's parallelize the modules as soon as we can!
if tp_key_registry is not None:
plan = None
prefix = None
for module_prefix in tp_key_registry.keys():
if f"{module_prefix}." in param_name:
tp_key_registry[module_prefix]["children"].remove(param_name)
# If all the children were already removed, it means that all parameters are now on the correct
# device -> we should call `parallelize_module` right now!
if len(tp_key_registry[module_prefix]["children"]) == 0:
plan = tp_key_registry[module_prefix]["plan"]
prefix = module_prefix
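To make the early-parallelization bookkeeping above concrete, here is a minimal runnable sketch of the same consumption pattern. The registry layout matches the one built by `_get_tp_key_registry` below; the `parallelize` callback and all names are illustrative stand-ins, not the transformers API.

```python
# Each registry entry tracks the not-yet-loaded parameter names of one module
# ("children") plus the tp plan to apply; the module is parallelized the moment
# its last parameter lands on device. `parallelize` stands in for
# torch.distributed.tensor.parallel.parallelize_module.
def consume_param(tp_key_registry, param_name, parallelize):
    for module_prefix, entry in tp_key_registry.items():
        if f"{module_prefix}." in param_name:
            entry["children"].remove(param_name)
            if not entry["children"]:  # last parameter of this module -> parallelize now
                parallelize(module_prefix, entry["plan"])
            break

registry = {
    "model.layers.0": {
        "children": {"model.layers.0.self_attn.q_proj.weight", "model.layers.0.mlp.up_proj.weight"},
        "plan": {"self_attn.q_proj": "colwise", "mlp.up_proj": "colwise"},
    }
}
consume_param(registry, "model.layers.0.self_attn.q_proj.weight", lambda m, p: print("parallelize", m))
consume_param(registry, "model.layers.0.mlp.up_proj.weight", lambda m, p: print("parallelize", m))
# -> prints "parallelize model.layers.0" exactly once, on the second call
```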
@@ -1367,15 +1369,12 @@ def _get_device_map(
max_memory = hf_quantizer.adjust_max_memory(max_memory)
device_map_kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
if hf_quantizer is not None:
hf_quantizer.validate_environment(device_map=device_map)
elif device_map is not None:
model.tie_weights()
tied_params = find_tied_parameters(model)
# check that we don't have tied params on different devices
check_tied_parameters_on_same_device(tied_params, device_map)
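A toy illustration of why the weights are tied before `infer_auto_device_map` runs: tied parameters share a single storage, and the device map must keep them together. The sketch below uses only plain PyTorch; `TinyLM` is a made-up model, not the transformers API.

```python
import torch.nn as nn

# Tied parameters are literally the same tensor: splitting the two modules
# across devices would silently untie them.
class TinyLM(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights

model = TinyLM()
# One shared storage -> infer_auto_device_map can keep both modules together,
# and check_tied_parameters_on_same_device can verify the final placement.
assert model.embed.weight.data_ptr() == model.lm_head.weight.data_ptr()
```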
@@ -1453,7 +1452,6 @@ def _find_missing_and_unexpected_keys(
if has_inv_freq_buffers:
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
model.tie_weights()
if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
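The pointer-grouping pass that starts here can be sketched standalone: tied parameters report the same `data_ptr`, so grouping state-dict entries by `(device, data_ptr)` exposes every tied group in one pass. The toy model below is made up for illustration and is not the real check.

```python
import collections
import torch.nn as nn

model = nn.Linear(4, 4)
model.extra = model.weight  # fake a tied parameter sharing the same storage

ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[(tensor.device, tensor.data_ptr())].append(name)

tied = [names for names in ptrs.values() if len(names) > 1]
print(tied)  # [['weight', 'extra']]
```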
@@ -1615,6 +1613,72 @@ def _find_mismatched_keys(
return mismatched_keys
def _get_tp_key_registry(model_to_load: "PreTrainedModel") -> Dict[str, Dict]:
"""Create a registry of all keys of the model, and the entity under which they should be parallelized.
The strategy is to parallelize each individual Layer as a single entity, then other individual modules as
specified in the tp_plan, if they are not part of a bigger Layer module.
This strategy ensures the best tradeoff between loading speed and memory footprint while parallelizing the model.
Indeed, loading the whole model and then parallelizing creates a huge memory overhead, as each process must
first load the full model on its given device. The other extreme, parallelizing each leaf of the model (Linear layers)
as their parameters are loaded creates a huge speed overhead, due to a lot of call to
`torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each layer as soon as their parameters are loaded
provides the best of both world: almost no memory overhead, and as fast (sometimes even faster) as parallelizing
the full model at once.
"""
prefix = model_to_load.base_model_prefix
is_task_specific_model = hasattr(model_to_load, prefix) if len(prefix) > 0 else False
full_tp_plan = model_to_load.config.base_model_tp_plan
if is_task_specific_model:
# Add the prefix to the base model plan
full_tp_plan = {f"{prefix}.{key}": plan for key, plan in full_tp_plan.items()}
# Add potential task-specific additional plan
full_tp_plan.update(getattr(model_to_load, "_tp_plan", {}))
# Extract full prefix before the layer numbers
layer_prefix = None
for key in full_tp_plan.keys():
if "*" in key:
# extract everything before the first "*", which corresponds to the layer number
layer_prefix = key.split("*", 1)[0]
break
if layer_prefix is None:
raise ValueError("Could not parse format of the base_model_tp_plan in the config.")
# Separate the layer plan from the other module plans
layer_wise_tp_plan = {}
other_modules_tp_plan = {}
for key, plan in full_tp_plan.items():
# In this case, keep only the part of the key after the "*" layer-number indicator
if key.startswith(layer_prefix):
layer_key = key.split("*", 1)[1][1:]
layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
else:
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
tp_key_registry = {}
for key in model_to_load.state_dict().keys():
# This pattern is used to capture the whole model prefix before the layer
pattern = rf"^({layer_prefix}[0-9]+)"
match = re.match(pattern, key)
# In this case, the current key is part of a layer to parallelize as a single entity
if match is not None:
layer = match.group(1)
if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
else:
tp_key_registry[layer]["children"].add(key)
elif key in other_modules_tp_plan:
if key not in tp_key_registry:
tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
else:
tp_key_registry[key]["children"].add(key)
return tp_key_registry
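As a rough usage illustration, the grouping this function performs can be reproduced on a fake state dict. In the real function the plans are first translated with `translate_to_torch_parallel_style` and the prefixes come from the config; plain strings stand in for both here, and all names are made up.

```python
import re

layer_prefix = "model.layers."
state_dict_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "lm_head.weight",
]
layer_wise_plan = {"self_attn.q_proj": "colwise", "mlp.up_proj": "colwise"}
other_modules_plan = {"lm_head.weight": "colwise_rep"}

registry = {}
for key in state_dict_keys:
    match = re.match(rf"^({re.escape(layer_prefix)}[0-9]+)", key)
    if match is not None:  # part of a decoder layer -> one registry entry per layer
        entry = registry.setdefault(match.group(1), {"children": set(), "plan": layer_wise_plan})
        entry["children"].add(key)
    elif key in other_modules_plan:  # standalone module with its own plan
        entry = registry.setdefault(key, {"children": set(), "plan": other_modules_plan[key]})
        entry["children"].add(key)

print(sorted(registry))  # ['lm_head.weight', 'model.layers.0', 'model.layers.1']
```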
class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -4496,6 +4560,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
# Make sure to tie the weights correctly
model.tie_weights()
# Last check for tp
if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
@@ -4760,9 +4827,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
# tie the model weights before retrieving the state_dict
model.tie_weights()
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
_prefix = f"{prefix}."
@@ -4880,28 +4944,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
# Prepare inputs if using tensor parallel
tp_key_registry = None
if device_mesh is not None:
tp_key_registry = {}
layer_tp_plan = {".".join(k.split(".")[2:]): v for k, v in model.config.base_model_tp_plan.items()}
layer_tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
layer_tp_plan,
)
for key in model_to_load.state_dict().keys():
pattern = r"^(.*layers\.[0-9]+)"
match = re.match(pattern, key)
if match is not None:
layer = match.group(1)
if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_tp_plan}
else:
tp_key_registry[layer]["children"].add(key)
elif "lm_head." in key:
if "lm_head" not in tp_key_registry:
plan = translate_to_torch_parallel_style(model._tp_plan["lm_head"])
tp_key_registry["lm_head"] = {"children": {key}, "plan": plan}
else:
tp_key_registry["lm_head"]["children"].add(key)
tp_key_registry = _get_tp_key_registry(model_to_load)
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
# It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
# i.e. 1x to load it, and 1x to copy it into the model
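A minimal sketch of the one-by-one idea described in this comment, assuming nothing beyond plain PyTorch (this is not the actual loading code): popping each entry out of the state dict right after it is copied keeps at most one extra tensor alive at a time.

```python
import torch
import torch.nn as nn

# Popping each tensor out of the state dict right after copying it into the
# model keeps peak memory near 1x the weights instead of 2x.
def load_one_by_one(model: nn.Module, state_dict: dict):
    tensors = dict(model.named_parameters())
    tensors.update(dict(model.named_buffers()))
    for name in list(state_dict):
        value = state_dict.pop(name)  # release the state-dict reference immediately
        with torch.no_grad():
            tensors[name].copy_(value)

m = nn.Linear(4, 4)
load_one_by_one(m, {k: torch.zeros_like(v) for k, v in m.state_dict().items()})
assert not m.weight.abs().sum()  # weights were overwritten in place
```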
@@ -5018,6 +5062,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
shutil.rmtree(cpu_offload_folder)
# Post-processing for tensor parallelism
if device_mesh is not None:
tp_device = list(device_map.values())[0]
# This is needed for the RotaryEmbedding, which is not initialized on the correct device because it is
# not part of the state_dict (persistent=False)
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
# In this case, the top-most task-specific module weights were neither moved to the device nor
# parallelized, as they were not part of the loaded weights: do it now
if loading_task_model_from_base_state_dict:
tp_plan = getattr(model, "_tp_plan", {})
modules_to_initialize = {name: module for name, module in model.named_children() if name != prefix}
for name, module in modules_to_initialize.items():
# Push to device
module.to(tp_device)
if name in tp_plan:
torch.distributed.tensor.parallel.parallelize_module(
module,
device_mesh=device_mesh,
parallelize_plan=translate_to_torch_parallel_style(tp_plan[name]),
)
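A short standalone example of why the buffer pass above is necessary (toy module, not the real RotaryEmbedding): buffers registered with `persistent=False` never appear in the state dict, so a state-dict-driven load cannot place them on the target device.

```python
import torch
import torch.nn as nn

class WithRotary(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffers are excluded from the state dict entirely
        self.register_buffer("inv_freq", torch.ones(8), persistent=False)

m = WithRotary()
print("inv_freq" in m.state_dict())  # False -> must be moved by hand
tp_device = "cpu"  # stand-in for the real target device
for buffer in m.buffers():
    if buffer.device != torch.device(tp_device):
        buffer.data = buffer.to(tp_device)
```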
# All potential warnings/infos
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
@@ -5025,7 +5094,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info