Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
simplify
This commit is contained in:
parent e6e75102e8
commit c7b175ec95
1 changed file with 36 additions and 32 deletions
@@ -418,9 +418,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     Args:
         model (`torch.nn.Module`): The model in which to load the checkpoint.
         folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
-        strict (`bool`, *optional`, defaults to `True`):
+        strict (`bool`, *optional*, defaults to `True`):
             Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
-        prefer_safe (`bool`, *optional*, defaults to `False`)
+        prefer_safe (`bool`, *optional*, defaults to `False`):
             If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
             safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
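The docstring fixed above belongs to the public `load_sharded_checkpoint` helper. A minimal usage sketch, not part of this commit, with a stand-in model and a placeholder checkpoint folder:

```python
import torch
from transformers.modeling_utils import load_sharded_checkpoint

# Stand-in model; in practice this is the PreTrainedModel whose weights were sharded.
model = torch.nn.Linear(8, 8)

# "path/to/sharded_checkpoint" is a placeholder folder containing the shard files and their index.
# With prefer_safe=True, .safetensors shards are picked over .bin shards when both are present.
result = load_sharded_checkpoint(model, "path/to/sharded_checkpoint", strict=True, prefer_safe=True)
```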
@@ -1485,10 +1485,18 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
     This strategy ensures the best tradeoff between loading speed and memory footprint while parallelizing the model.
     Indeed, loading the whole model and then parallelizing creates a huge memory overhead, as each process must
     first load the full model on its given device. The other extreme, parallelizing each leaf of the model (Linear layers)
-    as their parameters are loaded creates a huge speed overhead, due to a lot of call to
-    `torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each layer as soon as their parameters are loaded
-    provides the best of both world: almost no memory overhead, and as fast (sometimes even faster) as parallelizing
-    the full model at once.
+    as they are loaded creates a huge speed overhead, due to a lot of calls to `torch.distributed.tensor.parallel.parallelize_module`.
+    Parallelizing each layer instead, as soon as its parameters are loaded, provides the best of both worlds: almost no
+    memory overhead, and almost as fast as parallelizing the full model at once.
+
+    Args:
+        model (`PreTrainedModel`): The model to parallelize. It must have a valid tp plan associated with it.
+
+    Returns:
+        Dict[str, Dict]: The associated TP registry. The keys are the names of the modules which will be parallelized,
+        i.e. the modules on which we will call `torch.distributed.tensor.parallel.parallelize_module`. The values
+        are a `Dict` containing 2 keys: `children`, which is a set of all parameters of the given module, and `plan`,
+        which contains the tp_plan (a `Dict` or `str`) for the given module.
     """
     prefix = model.base_model_prefix
     is_task_specific_model = hasattr(model, prefix) if len(prefix) > 0 else False
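To make the Returns description above concrete, here is a hypothetical registry for a toy two-layer decoder whose tp plan covers `q_proj`, `down_proj` and `lm_head`; the module and parameter names are illustrative, not taken from any real config:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Each key is a module that will receive a single `parallelize_module` call;
# "children" lists the state-dict parameters belonging to that module,
# and "plan" is the translated tp plan applied to it.
tp_key_registry = {
    "model.layers.0": {
        "children": {
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.mlp.down_proj.weight",
        },
        "plan": {"self_attn.q_proj": ColwiseParallel(), "mlp.down_proj": RowwiseParallel()},
    },
    "model.layers.1": {
        "children": {
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.1.mlp.down_proj.weight",
        },
        "plan": {"self_attn.q_proj": ColwiseParallel(), "mlp.down_proj": RowwiseParallel()},
    },
    "lm_head": {
        "children": {"lm_head.weight"},
        "plan": ColwiseParallel(),
    },
}
```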
@@ -1506,44 +1514,40 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
         if "*" in key:
             # extract everything before the first "*" corresponding to layer number
             layer_prefix = key.split("*", 1)[0]
             # Remove ending dot
             layer_prefix = layer_prefix[:-1] if layer_prefix.endswith(".") else layer_prefix
             break
     if layer_prefix is None:
         raise ValueError("Could not parse format of the base_model_tp_plan in the config.")
 
     # Separate between layer plan, and other module plans
-    layer_wise_tp_plan = {}
-    other_modules_tp_plan = {}
+    layer_tp_plan = {}
+    other_modules_to_parallelize = {}
     for key, plan in full_tp_plan.items():
         # In this case, keep the key starting after the "*" layer number indicator
         if key.startswith(layer_prefix):
             layer_key = key.split("*", 1)[1][1:]
-            layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
+            layer_tp_plan[layer_key] = translate_to_torch_parallel_style(plan)
         else:
-            other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
+            other_modules_to_parallelize[key] = translate_to_torch_parallel_style(plan)
 
-    # This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
-    pattern = rf"^({layer_prefix}[0-9]+)"
-    # Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
+    # This contains all the modules to parallelize, as well as corresponding tp_plan for the full module
+    modules_to_parallelize = {
+        f"{layer_prefix}.{layer_idx}": layer_tp_plan for layer_idx in range(model.config.num_hidden_layers)
+    }
+    modules_to_parallelize.update(other_modules_to_parallelize)
+
     tp_key_registry = {}
-    for key in model.state_dict().keys():
-        match = re.match(pattern, key)
-        # In this case, the current key is part of a layer to parallelize as a single entity
-        if match is not None:
-            # Extract the actual layer number
-            layer = match.group(1)
-            if layer not in tp_key_registry:
-                tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
-            else:
-                tp_key_registry[layer]["children"].add(key)
-        # The key may be part of another module to parallelize, e.g. `lm_head`
-        else:
-            for module, plan in other_modules_tp_plan.items():
-                if key.startswith(module):
-                    if key not in tp_key_registry:
-                        tp_key_registry[module] = {"children": {key}, "plan": plan}
-                    else:
-                        tp_key_registry[module]["children"].add(key)
-                    break
+    for module_name, tp_plan in modules_to_parallelize.items():
+        # Retrieve the actual module object corresponding to module_name
+        actual_module = model
+        for name in module_name.split("."):
+            actual_module = getattr(actual_module, name)
+
+        # Find all parameters in the state dict of the module
+        children = set(actual_module.state_dict().keys())
+        children = {f"{module_name}.{child}" for child in children}
+        tp_key_registry[module_name] = {"children": children, "plan": tp_plan}
 
     return tp_key_registry
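A sketch of how a registry built this way could be consumed: once every key in an entry's `children` has been loaded, the whole module is parallelized in one call, which is the per-layer granularity the docstring argues for. The device-mesh setup is an assumption (it requires an already-initialized process group) and is not part of this diff:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# Assumes torch.distributed is already initialized, and that `model` and
# `tp_key_registry` come from the code above; one mesh dimension spans all ranks.
device_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

for module_name, entry in tp_key_registry.items():
    # Resolve the dotted name (e.g. "model.layers.0") to the actual submodule,
    # then parallelize the whole layer at once instead of each Linear separately.
    submodule = model.get_submodule(module_name)
    parallelize_module(submodule, device_mesh, entry["plan"])
```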