mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix tp plan registry
This commit is contained in:
parent
7a87812f3e
commit
e6e75102e8
1 changed files with 12 additions and 7 deletions
|
|
@ -1521,24 +1521,29 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
|
|||
else:
|
||||
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
|
||||
|
||||
# This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
|
||||
pattern = rf"^({layer_prefix}[0-9]+)"
|
||||
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
|
||||
tp_key_registry = {}
|
||||
for key in model.state_dict().keys():
|
||||
# This pattern is used to capture the whole model prefix before the layer
|
||||
pattern = rf"^({layer_prefix}[0-9]+)"
|
||||
match = re.match(pattern, key)
|
||||
# In this case, the current key is part of a layer to parallelize as a single entity
|
||||
if match is not None:
|
||||
# Extract the actual layer number
|
||||
layer = match.group(1)
|
||||
if layer not in tp_key_registry:
|
||||
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
|
||||
else:
|
||||
tp_key_registry[layer]["children"].add(key)
|
||||
elif key in other_modules_tp_plan:
|
||||
if key not in tp_key_registry:
|
||||
tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
|
||||
else:
|
||||
tp_key_registry[key]["children"].add(key)
|
||||
# The key may be part of another module to parallelize, e.g. `lm_head`
|
||||
else:
|
||||
for module, plan in other_modules_tp_plan.items():
|
||||
if key.startswith(module):
|
||||
if key not in tp_key_registry:
|
||||
tp_key_registry[module] = {"children": {key}, "plan": plan}
|
||||
else:
|
||||
tp_key_registry[module]["children"].add(key)
|
||||
break
|
||||
|
||||
return tp_key_registry
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue