Fix the TP plan key registry (`_get_tp_key_registry`)

This commit is contained in:
Cyril Vallez 2025-02-08 19:14:34 +01:00
parent 7a87812f3e
commit e6e75102e8
No known key found for this signature in database

View file

@ -1521,24 +1521,29 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
else:
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
# This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
pattern = rf"^({layer_prefix}[0-9]+)"
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
tp_key_registry = {}
for key in model.state_dict().keys():
# This pattern is used to capture the whole model prefix before the layer
pattern = rf"^({layer_prefix}[0-9]+)"
match = re.match(pattern, key)
# In this case, the current key is part of a layer to parallelize as a single entity
if match is not None:
# Extract the actual layer number
layer = match.group(1)
if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
else:
tp_key_registry[layer]["children"].add(key)
elif key in other_modules_tp_plan:
if key not in tp_key_registry:
tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
else:
tp_key_registry[key]["children"].add(key)
# The key may be part of another module to parallelize, e.g. `lm_head`
else:
for module, plan in other_modules_tp_plan.items():
if key.startswith(module):
if key not in tp_key_registry:
tp_key_registry[module] = {"children": {key}, "plan": plan}
else:
tp_key_registry[module]["children"].add(key)
break
return tp_key_registry