fix tp plan registry

This commit is contained in:
Cyril Vallez 2025-02-08 19:14:34 +01:00
parent 7a87812f3e
commit e6e75102e8
No known key found for this signature in database

View file

@ -1521,24 +1521,29 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
else: else:
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan) other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
# This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
pattern = rf"^({layer_prefix}[0-9]+)"
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any) # Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
tp_key_registry = {} tp_key_registry = {}
for key in model.state_dict().keys(): for key in model.state_dict().keys():
# This pattern is used to capture the whole model prefix before the layer
pattern = rf"^({layer_prefix}[0-9]+)"
match = re.match(pattern, key) match = re.match(pattern, key)
# In this case, the current key is part of a layer to parallelize as a single entity # In this case, the current key is part of a layer to parallelize as a single entity
if match is not None: if match is not None:
# Extract the matched layer identifier (the layer prefix followed by its index), used as the registry key
layer = match.group(1) layer = match.group(1)
if layer not in tp_key_registry: if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan} tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
else: else:
tp_key_registry[layer]["children"].add(key) tp_key_registry[layer]["children"].add(key)
elif key in other_modules_tp_plan: # The key may be part of another module to parallelize, e.g. `lm_head`
if key not in tp_key_registry: else:
tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]} for module, plan in other_modules_tp_plan.items():
else: if key.startswith(module):
tp_key_registry[key]["children"].add(key) if key not in tp_key_registry:
tp_key_registry[module] = {"children": {key}, "plan": plan}
else:
tp_key_registry[module]["children"].add(key)
break
return tp_key_registry return tp_key_registry