From e6e75102e8ac224c2428346abb4142ffeec8db0f Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Sat, 8 Feb 2025 19:14:34 +0100
Subject: [PATCH] fix tp plan registry

---
 src/transformers/modeling_utils.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 993036bed..f792904cc 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1521,24 +1521,29 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
         else:
             other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan)

+    # This pattern captures the whole layer prefix (up to and including the layer index) in the state dict keys
+    pattern = rf"^({layer_prefix}[0-9]+)"
     # Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any)
     tp_key_registry = {}
     for key in model.state_dict().keys():
-        # This pattern is used to capture the whole model prefix before the layer
-        pattern = rf"^({layer_prefix}[0-9]+)"
         match = re.match(pattern, key)
         # In this case, the current key is part of a layer to parallelize as a single entity
         if match is not None:
+            # Extract the full layer prefix (e.g. `model.layers.0`), used as the registry key
             layer = match.group(1)
             if layer not in tp_key_registry:
                 tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan}
             else:
                 tp_key_registry[layer]["children"].add(key)
-        elif key in other_modules_tp_plan:
-            if key not in tp_key_registry:
-                tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]}
-            else:
-                tp_key_registry[key]["children"].add(key)
+        # The key may be part of another module to parallelize, e.g. `lm_head`
+        else:
+            for module, plan in other_modules_tp_plan.items():
+                if key.startswith(module):
+                    if module not in tp_key_registry:
+                        tp_key_registry[module] = {"children": {key}, "plan": plan}
+                    else:
+                        tp_key_registry[module]["children"].add(key)
+                    break

     return tp_key_registry
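
A minimal, self-contained sketch of the fixed lookup for reviewers (the plan values, `layer_prefix`, and state dict keys below are toy stand-ins, not the real transformers objects; `setdefault` condenses the patch's if/else on the registry):

import re

# Toy stand-ins for the values the real function derives from the model
layer_prefix = "model.layers."
layer_wise_tp_plan = "layer_plan"                # placeholder plan object
other_modules_tp_plan = {"lm_head": "colwise"}   # keyed by module name, not parameter name

state_dict_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "lm_head.weight",
]

# Hoisted out of the loop, exactly as in the patch
pattern = rf"^({layer_prefix}[0-9]+)"

tp_key_registry = {}
for key in state_dict_keys:
    match = re.match(pattern, key)
    if match is not None:
        # Group the whole layer (e.g. `model.layers.0`) as a single entity
        layer = match.group(1)
        tp_key_registry.setdefault(layer, {"children": set(), "plan": layer_wise_tp_plan})
        tp_key_registry[layer]["children"].add(key)
    else:
        # The old code did `elif key in other_modules_tp_plan`: "lm_head.weight" is not
        # a dict key there, so the module was silently skipped. Prefix matching fixes it.
        for module, plan in other_modules_tp_plan.items():
            if key.startswith(module):
                tp_key_registry.setdefault(module, {"children": set(), "plan": plan})
                tp_key_registry[module]["children"].add(key)
                break

assert tp_key_registry["model.layers.0"]["children"] == {
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
}
assert tp_key_registry["lm_head"]["children"] == {"lm_head.weight"}

Before the fix, the final branch never fired and `lm_head` was absent from the registry; with `startswith`, every parameter of a planned module is grouped under that module's name.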