Author: Cyril Vallez
Date: 2025-02-08 20:15:48 +01:00
parent e6e75102e8
commit c7b175ec95

@@ -418,9 +418,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     Args:
         model (`torch.nn.Module`): The model in which to load the checkpoint.
         folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
-        strict (`bool`, *optional`, defaults to `True`):
+        strict (`bool`, *optional*, defaults to `True`):
             Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
-        prefer_safe (`bool`, *optional*, defaults to `False`)
+        prefer_safe (`bool`, *optional*, defaults to `False`):
             If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
             safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
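For context, the parameters documented in this hunk correspond to a call like the following minimal usage sketch; the model class and checkpoint path are illustrative, assuming a folder of sharded weights (index file plus shards) produced by `save_pretrained`:

```python
# Minimal usage sketch for load_sharded_checkpoint (model and path are illustrative).
from transformers import AutoModel
from transformers.modeling_utils import load_sharded_checkpoint

model = AutoModel.from_pretrained("bert-base-uncased")
# With prefer_safe=True, safetensors shards are chosen over PyTorch .bin shards
# whenever both formats are present in the folder.
result = load_sharded_checkpoint(model, "path/to/checkpoint_folder", strict=True, prefer_safe=True)
print(result.missing_keys, result.unexpected_keys)
```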
@@ -1485,10 +1485,18 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
     This strategy ensures the best tradeoff between loading speed and memory footprint while parallelizing the model.
     Indeed, loading the whole model and then parallelizing creates a huge memory overhead, as each process must
     first load the full model on its given device. The other extreme, parallelizing each leaf of the model (Linear layers)
-    as their parameters are loaded creates a huge speed overhead, due to a lot of call to
-    `torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each layer as soon as their parameters are loaded
-    provides the best of both world: almost no memory overhead, and as fast (sometimes even faster) as parallelizing
-    the full model at once.
+    as they are loaded creates a huge speed overhead, due to many calls to `torch.distributed.tensor.parallel.parallelize_module`.
+    Parallelizing each layer instead, as soon as its parameters are loaded, provides the best of both worlds: almost no
+    memory overhead, and almost as fast as parallelizing the full model at once.
+
+    Args:
+        model (`PreTrainedModel`): The model to parallelize. It must have a valid tp plan associated with it.
+
+    Returns:
+        Dict[str, Dict]: The associated TP registry. The keys are the names of the modules which will be parallelized,
+        i.e. the modules on which we will call `torch.distributed.tensor.parallel.parallelize_module`. The values
+        are a `Dict` containing 2 keys: `children`, which is a set of all parameters of the given module, and `plan`,
+        which contains the tp_plan (a `Dict` or `str`) for the given module.
     """
     prefix = model.base_model_prefix
     is_task_specific_model = hasattr(model, prefix) if len(prefix) > 0 else False
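To make the tradeoff described by the new docstring concrete, here is a minimal sketch (not the library's implementation) of parallelizing one decoder layer at a time with `torch.distributed.tensor.parallel.parallelize_module`; the Llama-like module layout, the helper name, and the plan contents are assumptions for illustration:

```python
# Sketch of the per-layer strategy described above (assumes a Llama-like layout
# with model.model.layers, and a process group launched via torchrun).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def parallelize_layer_by_layer(model, num_layers, tp_size):
    # 1-D mesh over the tensor-parallel ranks.
    mesh = init_device_mesh("cuda", (tp_size,))
    # Plan keys are relative to each decoder layer (illustrative subset).
    layer_tp_plan = {
        "self_attn.q_proj": ColwiseParallel(),
        "self_attn.k_proj": ColwiseParallel(),
        "self_attn.v_proj": ColwiseParallel(),
        "self_attn.o_proj": RowwiseParallel(),
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    }
    # One parallelize_module call per layer: O(num_layers) calls instead of one
    # per Linear leaf, and no need to first materialize the full model per rank.
    for idx in range(num_layers):
        parallelize_module(model.model.layers[idx], mesh, layer_tp_plan)
    return model
```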
@@ -1506,44 +1514,40 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
if "*" in key:
# extract everything before the first "*" corresponding to layer number
layer_prefix = key.split("*", 1)[0]
# Remove ending dot
layer_prefix = layer_prefix[:-1] if layer_prefix.endswith(".") else layer_prefix
break
if layer_prefix is None:
raise ValueError("Could not parse format of the base_model_tp_plan in the config.")
# Separate between layer plan, and other module plans
layer_wise_tp_plan = {}
other_modules_tp_plan = {}
layer_tp_plan = {}
other_modules_to_parallelize = {}
for key, plan in full_tp_plan.items():
# In this case, keep the key starting after the "*" layer number indicator
if key.startswith(layer_prefix):
layer_key = key.split("*", 1)[1][1:]
layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
layer_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
else:
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
other_modules_to_parallelize[key] = translate_to_torch_parallel_style(plan)
# This contains all the modules to parallelize, as well as corresponding tp_plan for the full module
modules_to_parallelize = {
f"{layer_prefix}.{layer_idx}": layer_tp_plan for layer_idx in range(model.config.num_hidden_layers)
}
modules_to_parallelize.update(other_modules_to_parallelize)
# This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
pattern = rf"^({layer_prefix}[0-9]+)"
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
tp_key_registry = {}
for key in model.state_dict().keys():
match = re.match(pattern, key)
# In this case, the current key is part of a layer to parallelize as a single entity
if match is not None:
# Extract the actual layer number
layer = match.group(1)
if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
else:
tp_key_registry[layer]["children"].add(key)
# The key may be part of another module to parallelize, e.g. `lm_head`
else:
for module, plan in other_modules_tp_plan.items():
if key.startswith(module):
if key not in tp_key_registry:
tp_key_registry[module] = {"children": {key}, "plan": plan}
else:
tp_key_registry[module]["children"].add(key)
break
for module_name, tp_plan in modules_to_parallelize.items():
# Retrieve the actual module object corresponding to module_name
actual_module = model
for name in module_name.split("."):
actual_module = getattr(actual_module, name)
# Find all parameters in the state dict of the module
children = set(actual_module.state_dict().keys())
children = {f"{module_name}.{child}" for child in children}
tp_key_registry[module_name] = {"children": children, "plan": tp_plan}
return tp_key_registry
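To illustrate the shape of the registry that the rewritten loop produces, here is a self-contained toy example; the module names and the string plan placeholder are illustrative, not taken from the commit:

```python
# Toy illustration of the registry built by the new loop (names are illustrative;
# a real plan maps sub-module names to ParallelStyle objects, not a string).
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.o_proj = nn.Linear(8, 8)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([TinyBlock() for _ in range(2)])
        self.lm_head = nn.Linear(8, 8)

model = TinyModel()
modules_to_parallelize = {"layers.0": "plan", "layers.1": "plan", "lm_head": "plan"}

tp_key_registry = {}
for module_name, tp_plan in modules_to_parallelize.items():
    # Same getattr walk as the new code: "layers.0" resolves to model.layers[0]
    # (nn.ModuleList supports getattr with the string index "0").
    actual_module = model
    for name in module_name.split("."):
        actual_module = getattr(actual_module, name)
    children = {f"{module_name}.{child}" for child in actual_module.state_dict().keys()}
    tp_key_registry[module_name] = {"children": children, "plan": tp_plan}

# tp_key_registry["layers.0"]["children"] == {"layers.0.q_proj.weight",
#     "layers.0.q_proj.bias", "layers.0.o_proj.weight", "layers.0.o_proj.bias"}
print(tp_key_registry)
```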