diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f792904cc..061e68a26 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -418,9 +418,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     Args:
         model (`torch.nn.Module`): The model in which to load the checkpoint.
         folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
-        strict (`bool`, *optional`, defaults to `True`):
+        strict (`bool`, *optional*, defaults to `True`):
             Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
-        prefer_safe (`bool`, *optional*, defaults to `False`)
+        prefer_safe (`bool`, *optional*, defaults to `False`):
             If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
             safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

@@ -1485,10 +1485,18 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]:
     This strategy ensures the best tradeoff between loading speed and memory footprint while parallelizing the model.
     Indeed, loading the whole model and then parallelizing creates a huge memory overhead, as each process must first
     load the full model on its given device. The other extreme, parallelizing each leaf of the model (Linear layers)
-    as their parameters are loaded creates a huge speed overhead, due to a lot of call to
-    `torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each layer as soon as their parameters are loaded
-    provides the best of both world: almost no memory overhead, and as fast (sometimes even faster) as parallelizing
-    the full model at once.
+    as they are loaded creates a huge speed overhead, due to a lot of calls to `torch.distributed.tensor.parallel.parallelize_module`.
+    Parallelizing each layer instead, as soon as its parameters are loaded, provides the best of both worlds: almost no
+    memory overhead, and almost the same speed as parallelizing the full model at once.
+
+    Args:
+        model (`PreTrainedModel`): The model to parallelize. It must have a valid tp plan associated with it.
+
+    Returns:
+        Dict[str, Dict]: The associated TP registry. The keys are the names of the modules that will be parallelized,
+        i.e. the modules on which we will call `torch.distributed.tensor.parallel.parallelize_module`. The values
+        are a `Dict` containing 2 keys: `children`, which is the set of all parameter names of the given module, and
+        `plan`, which contains the tp_plan (a `Dict` or `str`) for the given module.
""" prefix = model.base_model_prefix is_task_specific_model = hasattr(model, prefix) if len(prefix) > 0 else False @@ -1506,44 +1514,40 @@ def _get_tp_key_registry(model: "PreTrainedModel") -> Dict[str, Dict]: if "*" in key: # extract everything before the first "*" corresponding to layer number layer_prefix = key.split("*", 1)[0] + # Remove ending dot + layer_prefix = layer_prefix[:-1] if layer_prefix.endswith(".") else layer_prefix break if layer_prefix is None: raise ValueError("Could not parse format of the base_model_tp_plan in the config.") # Separate between layer plan, and other module plans - layer_wise_tp_plan = {} - other_modules_tp_plan = {} + layer_tp_plan = {} + other_modules_to_parallelize = {} for key, plan in full_tp_plan.items(): # In this case, keep the key starting after the "*" layer number indicator if key.startswith(layer_prefix): layer_key = key.split("*", 1)[1][1:] - layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan) + layer_tp_plan[layer_key] = translate_to_torch_parallel_style(plan) else: - other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan) + other_modules_to_parallelize[key] = translate_to_torch_parallel_style(plan) + + # This contains all the modules to parallelize, as well as corresponding tp_plan for the full module + modules_to_parallelize = { + f"{layer_prefix}.{layer_idx}": layer_tp_plan for layer_idx in range(model.config.num_hidden_layers) + } + modules_to_parallelize.update(other_modules_to_parallelize) - # This pattern is used to capture the whole model prefix before the layer in the keys of the state dict - pattern = rf"^({layer_prefix}[0-9]+)" - # Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any) tp_key_registry = {} - for key in model.state_dict().keys(): - match = re.match(pattern, key) - # In this case, the current key is part of a layer to parallelize as a single entity - if match is not None: - # Extract the actual layer number - layer = match.group(1) - if layer not in tp_key_registry: - tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan} - else: - tp_key_registry[layer]["children"].add(key) - # The key may be part of another module to parallelize, e.g. `lm_head` - else: - for module, plan in other_modules_tp_plan.items(): - if key.startswith(module): - if key not in tp_key_registry: - tp_key_registry[module] = {"children": {key}, "plan": plan} - else: - tp_key_registry[module]["children"].add(key) - break + for module_name, tp_plan in modules_to_parallelize.items(): + # Retrieve the actual module object corresponding to module_name + actual_module = model + for name in module_name.split("."): + actual_module = getattr(actual_module, name) + + # Find all parameters in the state dict of the module + children = set(actual_module.state_dict().keys()) + children = {f"{module_name}.{child}" for child in children} + tp_key_registry[module_name] = {"children": children, "plan": tp_plan} return tp_key_registry