mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix tp plan registry
This commit is contained in:
parent
7a87812f3e
commit
e6e75102e8
1 changed file with 12 additions and 7 deletions
|
|
@ -1521,24 +1521,29 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
|
||||||
else:
|
else:
|
||||||
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
|
other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)
|
||||||
|
|
||||||
|
# This pattern is used to capture the whole model prefix before the layer in the keys of the state dict
|
||||||
|
pattern = rf"^({layer_prefix}[0-9]+)"
|
||||||
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
|
# Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
|
||||||
tp_key_registry = {}
|
tp_key_registry = {}
|
||||||
for key in model.state_dict().keys():
|
for key in model.state_dict().keys():
|
||||||
# This pattern is used to capture the whole model prefix before the layer
|
|
||||||
pattern = rf"^({layer_prefix}[0-9]+)"
|
|
||||||
match = re.match(pattern, key)
|
match = re.match(pattern, key)
|
||||||
# In this case, the current key is part of a layer to parallelize as a single entity
|
# In this case, the current key is part of a layer to parallelize as a single entity
|
||||||
if match is not None:
|
if match is not None:
|
||||||
|
# Extract the actual layer number
|
||||||
layer = match.group(1)
|
layer = match.group(1)
|
||||||
if layer not in tp_key_registry:
|
if layer not in tp_key_registry:
|
||||||
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
|
tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
|
||||||
else:
|
else:
|
||||||
tp_key_registry[layer]["children"].add(key)
|
tp_key_registry[layer]["children"].add(key)
|
||||||
elif key in other_modules_tp_plan:
|
# The key may be part of another module to parallelize, e.g. `lm_head`
|
||||||
if key not in tp_key_registry:
|
else:
|
||||||
tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
|
for module, plan in other_modules_tp_plan.items():
|
||||||
else:
|
if key.startswith(module):
|
||||||
tp_key_registry[key]["children"].add(key)
|
if key not in tp_key_registry:
|
||||||
|
tp_key_registry[module] = {"children": {key}, "plan": plan}
|
||||||
|
else:
|
||||||
|
tp_key_registry[module]["children"].add(key)
|
||||||
|
break
|
||||||
|
|
||||||
return tp_key_registry
|
return tp_key_registry
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue