mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update modeling_utils.py

parent c341f8c6fb
commit 87e97ea54c
1 changed file with 99 additions and 31 deletions
@@ -894,13 +894,15 @@ def _load_state_dict_into_meta_model(
                setattr(module, tensor_name, value)
            # TODO: consider removing used param_parts from state_dict before return

-       # In this case, let's parallelize the modules!
+       # In this case, let's parallelize the modules as soon as we can!
        if tp_key_registry is not None:
            plan = None
            prefix = None
            for module_prefix in tp_key_registry.keys():
                if f"{module_prefix}." in param_name:
                    tp_key_registry[module_prefix]["children"].remove(param_name)
                    # If all the children were already removed, it means that all parameters are now on the correct
                    # device -> we should call `parallelize_module` right now!
                    if len(tp_key_registry[module_prefix]["children"]) == 0:
                        plan = tp_key_registry[module_prefix]["plan"]
                        prefix = module_prefix
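The bookkeeping in this hunk is plain set subtraction: each registry entry tracks the parameter names still pending for its module, and the module becomes eligible for `parallelize_module` the moment its set empties. A minimal, torch-free sketch of the mechanism (all key names hypothetical):

# Toy registry: one decoder layer with two pending parameters (hypothetical names).
tp_key_registry = {
    "model.layers.0": {
        "children": {"model.layers.0.self_attn.q_proj.weight", "model.layers.0.mlp.up_proj.weight"},
        "plan": {"self_attn.q_proj": "colwise"},
    },
}

def on_param_loaded(param_name):
    """Mimic the bookkeeping: discard the loaded key, fire when a module's set empties."""
    for module_prefix, entry in tp_key_registry.items():
        if f"{module_prefix}." in param_name:
            entry["children"].remove(param_name)
            if len(entry["children"]) == 0:
                return module_prefix, entry["plan"]  # ready to parallelize right now
    return None, None

print(on_param_loaded("model.layers.0.self_attn.q_proj.weight"))  # (None, None): one key still pending
print(on_param_loaded("model.layers.0.mlp.up_proj.weight"))       # ('model.layers.0', {...}): layer complete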
@@ -1367,15 +1369,12 @@ def _get_device_map(
            max_memory = hf_quantizer.adjust_max_memory(max_memory)
        device_map_kwargs["max_memory"] = max_memory

-       # Make sure tied weights are tied before creating the device map.
-       model.tie_weights()
        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

        if hf_quantizer is not None:
            hf_quantizer.validate_environment(device_map=device_map)

    elif device_map is not None:
        model.tie_weights()
        tied_params = find_tied_parameters(model)
        # Check that we don't have tied params on different devices
        check_tied_parameters_on_same_device(tied_params, device_map)
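`infer_auto_device_map` is accelerate's planner: it walks the module tree and assigns each submodule to a device under the given memory caps. A hedged usage sketch (the model id and memory figures are illustrative, not taken from this commit):

import torch
from accelerate import infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
device_map = infer_auto_device_map(
    model,
    max_memory={0: "2GiB", "cpu": "8GiB"},  # per-device caps: GPU index 0 and host RAM
    dtype=torch.float16,                    # plan sizes as if the weights were fp16
)
print(device_map)  # e.g. {'transformer.wte': 0, ..., 'lm_head': 'cpu'}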
@@ -1453,7 +1452,6 @@ def _find_missing_and_unexpected_keys(
    if has_inv_freq_buffers:
        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]

-   model.tie_weights()
    if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        ptrs = collections.defaultdict(list)
        for name, tensor in model.state_dict().items():
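The `ptrs` mapping built here groups parameter names by storage address, which is how tied weights (several names, one tensor) are detected. A self-contained sketch of the idea:

import collections
import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # tie: two names, one storage

model = nn.Sequential()
model.add_module("emb", emb)
model.add_module("head", head)

ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[tensor.data_ptr()].append(name)  # tied tensors share a data pointer

tied = [names for names in ptrs.values() if len(names) > 1]
print(tied)  # [['emb.weight', 'head.weight']]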
@@ -1615,6 +1613,72 @@ def _find_mismatched_keys(
    return mismatched_keys


+def _get_tp_key_registry(model_to_load: "PreTrainedModel") -> Dict[str, Dict]:
+    """Create a registry of all keys of the model, and the entity under which they should be parallelized.
+    The strategy is to parallelize each individual Layer as a single entity, and then the other individual
+    modules specified in the tp_plan, if they are not part of a bigger Layer module.
+
+    This strategy ensures the best tradeoff between loading speed and memory footprint while parallelizing
+    the model. Indeed, loading the whole model and then parallelizing creates a huge memory overhead, as each
+    process must first load the full model on its given device. The other extreme, parallelizing each leaf of
+    the model (Linear layers) as its parameters are loaded, creates a huge speed overhead, due to the many
+    calls to `torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each layer as soon as its
+    parameters are loaded provides the best of both worlds: almost no memory overhead, and loading that is as
+    fast as (sometimes even faster than) parallelizing the full model at once.
+    """
+    prefix = model_to_load.base_model_prefix
+    is_task_specific_model = hasattr(model_to_load, prefix) if len(prefix) > 0 else False
+
+    full_tp_plan = model_to_load.config.base_model_tp_plan
+    if is_task_specific_model:
+        # Add the prefix to the base model plan
+        full_tp_plan = {f"{prefix}.{key}": plan for key, plan in full_tp_plan.items()}
+        # Add the potential task-specific additional plan
+        full_tp_plan.update(getattr(model_to_load, "_tp_plan", {}))
+
+    # Extract the full prefix before the layer numbers
+    layer_prefix = None
+    for key in full_tp_plan.keys():
+        if "*" in key:
+            # Extract everything before the first "*", which corresponds to the layer number
+            layer_prefix = key.split("*", 1)[0]
+            break
+    if layer_prefix is None:
+        raise ValueError("Could not parse format of the base_model_tp_plan in the config.")
+
+    # Separate the layer plan from the other module plans
+    layer_wise_tp_plan = {}
+    other_modules_tp_plan = {}
+    for key, plan in full_tp_plan.items():
+        # In this case, keep the key starting after the "*" layer number indicator
+        if key.startswith(layer_prefix):
+            layer_key = key.split("*", 1)[1][1:]
+            layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
+        else:
+            other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
+
+    # Create the registry. Here we use the layers as single entities to parallelize (and the other
+    # individual modules, if any)
+    tp_key_registry = {}
+    for key in model_to_load.state_dict().keys():
+        # This pattern is used to capture the whole model prefix before the layer number
+        pattern = rf"^({layer_prefix}[0-9]+)"
+        match = re.match(pattern, key)
+        # In this case, the current key is part of a layer to parallelize as a single entity
+        if match is not None:
+            layer = match.group(1)
+            if layer not in tp_key_registry:
+                tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
+            else:
+                tp_key_registry[layer]["children"].add(key)
+        elif key in other_modules_tp_plan:
+            if key not in tp_key_registry:
+                tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
+            else:
+                tp_key_registry[key]["children"].add(key)
+
+    return tp_key_registry
+
+
 class ModuleUtilsMixin:
    """
    A few utilities for `torch.nn.Modules`, to be used as a mixin.
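To see what the new helper produces, here is the same regex bucketing run on a few hypothetical state-dict keys, with plan values left as plain strings instead of torch parallel styles. Note the toy plan is keyed by the full key `lm_head.weight` so the `elif` matches; the commit's real plans use module paths:

import re

layer_prefix = "model.layers."  # what the helper extracts from a plan key like "model.layers.*.self_attn.q_proj"
keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "lm_head.weight",
]
other_modules_tp_plan = {"lm_head.weight": "colwise_rep"}  # toy plan, keyed by the full key

tp_key_registry = {}
pattern = rf"^({re.escape(layer_prefix)}[0-9]+)"  # escaped here; the diff interpolates the prefix directly
for key in keys:
    match = re.match(pattern, key)
    if match is not None:  # part of a decoder layer: bucket under "model.layers.N"
        layer = match.group(1)
        tp_key_registry.setdefault(layer, {"children": set(), "plan": "<layer-wise plan>"})["children"].add(key)
    elif key in other_modules_tp_plan:  # standalone module with its own plan
        tp_key_registry.setdefault(key, {"children": set(), "plan": other_modules_tp_plan[key]})["children"].add(key)

for module_prefix, entry in tp_key_registry.items():
    print(module_prefix, sorted(entry["children"]))
# model.layers.0 -> two keys, model.layers.1 -> one key, lm_head.weight -> one key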
@@ -4496,6 +4560,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
            # Let's make sure we don't run the init function of buffer modules
            model = cls(config, *model_args, **model_kwargs)

        # Make sure to tie the weights correctly
        model.tie_weights()

+       # Last check for tp
+       if device_mesh is not None and not model.supports_tp_plan:
+           raise NotImplementedError("This model does not have a tensor parallel plan.")
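For context, this guard fires when tensor parallelism is requested for a model whose config carries no plan. A hedged sketch of the user-facing path that would trigger it, assuming the upstream `tp_plan="auto"` entry point this branch builds toward (the model id is illustrative):

# Run with e.g.: torchrun --nproc-per-node 2 load_tp.py
from transformers import AutoModelForCausalLM

# Raises NotImplementedError when the checkpoint's architecture defines no tensor parallel plan.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # illustrative model id
    tp_plan="auto",
)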
@@ -4760,9 +4827,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin

        renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]

-       # tie the model weights before retrieving the state_dict
-       model.tie_weights()
-
        # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
        prefix = model.base_model_prefix
        _prefix = f"{prefix}."
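The `prefix`/`_prefix` pair is used to detect a checkpoint whose keys lack (or carry) the base-model prefix. A toy illustration of that check, with hypothetical keys; the real logic in `from_pretrained` handles more cases, and the resulting flag is the one consumed by the post-processing hunk further down:

prefix = "model"          # model.base_model_prefix for a typical causal-LM wrapper
_prefix = f"{prefix}."
loaded_keys = ["layers.0.self_attn.q_proj.weight", "norm.weight"]  # hypothetical base-model checkpoint keys

# None of the checkpoint keys carry the task model's prefix: we are loading a task-specific
# model (e.g. one with an lm_head) from a bare base-model state dict.
loading_task_model_from_base_state_dict = not any(k.startswith(_prefix) for k in loaded_keys)
print(loading_task_model_from_base_state_dict)  # True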
@@ -4880,27 +4944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
        # Prepare inputs if using tensor parallel
        tp_key_registry = None
        if device_mesh is not None:
-           tp_key_registry = {}
-           layer_tp_plan = {".".join(k.split(".")[2:]): v for k, v in model.config.base_model_tp_plan.items()}
-           layer_tp_plan = torch.utils._pytree.tree_map(
-               translate_to_torch_parallel_style,
-               layer_tp_plan,
-           )
-           for key in model_to_load.state_dict().keys():
-               pattern = r"^(.*layers\.[0-9]+)"
-               match = re.match(pattern, key)
-               if match is not None:
-                   layer = match.group(1)
-                   if layer not in tp_key_registry:
-                       tp_key_registry[layer] = {"children": {key}, "plan": layer_tp_plan}
-                   else:
-                       tp_key_registry[layer]["children"].add(key)
-               elif "lm_head." in key:
-                   if "lm_head" not in tp_key_registry:
-                       plan = translate_to_torch_parallel_style(model._tp_plan["lm_head"])
-                       tp_key_registry["lm_head"] = {"children": {key}, "plan": plan}
-                   else:
-                       tp_key_registry["lm_head"]["children"].add(key)
+           tp_key_registry = _get_tp_key_registry(model_to_load)

        # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
        # It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
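`translate_to_torch_parallel_style` turns the plan's string entries into torch parallel-style objects. A plausible reduction of that mapping, as an assumption-labeled stand-in rather than the library's code:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def translate_sketch(style: str):
    # Stand-in for translate_to_torch_parallel_style; the real helper knows more styles.
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    raise ValueError(f"Unsupported tensor parallel style: {style}")

plan = {"self_attn.q_proj": "colwise", "self_attn.o_proj": "rowwise"}
torch_plan = {name: translate_sketch(style) for name, style in plan.items()}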
@@ -5018,6 +5062,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
            load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
            shutil.rmtree(cpu_offload_folder)

+       # Post-processing for tensor parallelism
+       if device_mesh is not None:
+           tp_device = list(device_map.values())[0]
+           # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+           # not part of the state_dict (persistent=False)
+           for buffer in model.buffers():
+               if buffer.device != tp_device:
+                   buffer.data = buffer.to(tp_device)
+
+           # In this case, the top-most task module weights were not moved to device and parallelized, as they
+           # were not part of the loaded weights: do it now
+           if loading_task_model_from_base_state_dict:
+               tp_plan = getattr(model, "_tp_plan", {})
+               modules_to_initialize = {name: module for name, module in model.named_children() if name != prefix}
+               for name, module in modules_to_initialize.items():
+                   # Push to device
+                   module.to(tp_device)
+                   if name in tp_plan:
+                       torch.distributed.tensor.parallel.parallelize_module(
+                           module,
+                           device_mesh=device_mesh,
+                           parallelize_plan=translate_to_torch_parallel_style(tp_plan[name]),
+                       )
+
        # All potential warnings/infos
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
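`parallelize_module` is torch's public tensor-parallel entry point, and the call above is the per-module form used throughout this commit. A hedged, self-contained sketch of such a call (run under torchrun; the layer and mesh shape are illustrative):

# Run with: torchrun --nproc-per-node 2 tp_sketch.py
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (2,))  # 1-D mesh over two GPUs; torchrun sets up the process group
module = nn.Linear(16, 32, bias=False).cuda()
parallelize_module(
    module,
    device_mesh=mesh,
    parallelize_plan=ColwiseParallel(),  # shard the output dimension of the weight across the mesh
)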
@@ -5025,7 +5094,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
                raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            archs = [] if model.config.architectures is None else model.config.architectures
            warner = logger.warning if model.__class__.__name__ in archs else logger.info