diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ed17c0010..c8c48e5f3 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -894,13 +894,15 @@ def _load_state_dict_into_meta_model(
                 setattr(module, tensor_name, value)

         # TODO: consider removing used param_parts from state_dict before return
-        # In this case, let's parallelize the modules!
+        # In this case, let's parallelize the modules as soon as we can!
         if tp_key_registry is not None:
             plan = None
             prefix = None
             for module_prefix in tp_key_registry.keys():
                 if f"{module_prefix}." in param_name:
                     tp_key_registry[module_prefix]["children"].remove(param_name)
+                    # If all the children have been removed, it means that all parameters are now on the
+                    # correct device -> we should call `parallelize_module` right now!
                     if len(tp_key_registry[module_prefix]["children"]) == 0:
                         plan = tp_key_registry[module_prefix]["plan"]
                         prefix = module_prefix
@@ -1367,15 +1369,12 @@ def _get_device_map(
             max_memory = hf_quantizer.adjust_max_memory(max_memory)
             device_map_kwargs["max_memory"] = max_memory

-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
         device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

         if hf_quantizer is not None:
             hf_quantizer.validate_environment(device_map=device_map)

     elif device_map is not None:
-        model.tie_weights()
         tied_params = find_tied_parameters(model)
         # check if we don't have tied param in different devices
         check_tied_parameters_on_same_device(tied_params, device_map)
@@ -1453,7 +1452,6 @@ def _find_missing_and_unexpected_keys(
     if has_inv_freq_buffers:
         unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]

-    model.tie_weights()
    if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        ptrs = collections.defaultdict(list)
        for name, tensor in model.state_dict().items():
@@ -1615,6 +1613,72 @@ def _find_mismatched_keys(
     return mismatched_keys


+def _get_tp_key_registry(model_to_load: "PreTrainedModel") -> Dict[str, Dict]:
+    """Create a registry of all the state dict keys of the model, grouped by the entity under which they should
+    be parallelized. The strategy is to parallelize each individual layer as a single entity, and then every
+    other module specified in the tp_plan that is not part of a bigger layer module.
+
+    This strategy provides the best tradeoff between loading speed and memory footprint when parallelizing the
+    model. Indeed, loading the whole model and only then parallelizing it creates a huge memory overhead, as
+    each process must first materialize the full model on its own device. The other extreme, parallelizing each
+    leaf of the model (e.g. each Linear) as soon as its parameters are loaded, creates a huge speed overhead due
+    to the large number of calls to `torch.distributed.tensor.parallel.parallelize_module`. Parallelizing each
+    layer as soon as its parameters are loaded provides the best of both worlds: almost no memory overhead, and
+    loading that is as fast as (sometimes even faster than) parallelizing the full model at once.
+ """ + prefix = model_to_load.base_model_prefix + is_task_specific_model = hasattr(model_to_load, prefix) if len(prefix) > 0 else False + + full_tp_plan = model_to_load.config.base_model_tp_plan + if is_task_specific_model: + # Add the prefix to the base model plan + full_tp_plan = {f"{prefix}.{key}": plan for key, plan in full_tp_plan.items()} + # Add potential task-specific additional plan + full_tp_plan.update(getattr(model_to_load, "_tp_plan", {})) + + # Extract full prefix before the layer numbers + layer_prefix = None + for key in full_tp_plan.keys(): + if "*" in key: + # extract everything before the first "*" corresponding to layer number + layer_prefix = key.split("*", 1)[0] + break + if layer_prefix is None: + raise ValueError("Could not parse format of the base_model_tp_plan in the config.") + + # Separate between layer plan, and other module plans + layer_wise_tp_plan = {} + other_modules_tp_plan = {} + for key, plan in full_tp_plan.items(): + # In this case, keep the key starting after the "*" layer number indicator + if key.startswith(layer_prefix): + layer_key = key.split("*", 1)[1][1:] + layer_wise_tp_plan[layer_key] = translate_to_torch_parallel_style(plan) + else: + other_modules_tp_plan[key] = translate_to_torch_parallel_style(plan) + + # Create the registry. Here we use the layers as single entities to parallelize (and other individual modules if any) + tp_key_registry = {} + for key in model_to_load.state_dict().keys(): + # This pattern is used to capture the whole model prefix before the layer + pattern = rf"^({layer_prefix}[0-9]+)" + match = re.match(pattern, key) + # In this case, the current key is part of a layer to parallelize as a single entity + if match is not None: + layer = match.group(1) + if layer not in tp_key_registry: + tp_key_registry[layer] = {"children": {key}, "plan": layer_wise_tp_plan} + else: + tp_key_registry[layer]["children"].add(key) + elif key in other_modules_tp_plan: + if key not in tp_key_registry: + tp_key_registry[key] = {"children": {key}, "plan": other_modules_tp_plan[key]} + else: + tp_key_registry[key]["children"].add(key) + + return tp_key_registry + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -4496,6 +4560,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) + # Make sure to tie the weights correctly + model.tie_weights() + # Last check for tp if device_mesh is not None and not model.supports_tp_plan: raise NotImplementedError("This model does not have a tensor parallel plan.") @@ -4760,9 +4827,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys] - # tie the model weights before retrieving the state_dict - model.tie_weights() - # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture prefix = model.base_model_prefix _prefix = f"{prefix}." 
@@ -4880,28 +4944,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Prepare inputs if using tensor parallel
         tp_key_registry = None
         if device_mesh is not None:
-            tp_key_registry = {}
-            layer_tp_plan = {".".join(k.split(".")[2:]): v for k, v in model.config.base_model_tp_plan.items()}
-            layer_tp_plan = torch.utils._pytree.tree_map(
-                translate_to_torch_parallel_style,
-                layer_tp_plan,
-            )
-            for key in model_to_load.state_dict().keys():
-                pattern = r"^(.*layers\.[0-9]+)"
-                match = re.match(pattern, key)
-                if match is not None:
-                    layer = match.group(1)
-                    if layer not in tp_key_registry:
-                        tp_key_registry[layer] = {"children": {key}, "plan": layer_tp_plan}
-                    else:
-                        tp_key_registry[layer]["children"].add(key)
-                elif "lm_head." in key:
-                    if "lm_head" not in tp_key_registry:
-                        plan = translate_to_torch_parallel_style(model._tp_plan["lm_head"])
-                        tp_key_registry["lm_head"] = {"children": {key}, "plan": plan}
-                    else:
-                        tp_key_registry["lm_head"]["children"].add(key)
-
+            tp_key_registry = _get_tp_key_registry(model_to_load)
+
         # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map.
         # It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
         # i.e. 1x to load it, and 1x to copy it to the model
@@ -5018,6 +5062,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
             shutil.rmtree(cpu_offload_folder)

+        # Post-processing for tensor parallelism
+        if device_mesh is not None:
+            tp_device = list(device_map.values())[0]
+            # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+            # not part of the state_dict (persistent=False)
+            for buffer in model.buffers():
+                if buffer.device != tp_device:
+                    buffer.data = buffer.to(tp_device)
+
+            # In this case, the top-most task-specific module weights were neither moved to the device nor
+            # parallelized, as they were not part of the loaded weights: do it now
+            if loading_task_model_from_base_state_dict:
+                tp_plan = getattr(model, "_tp_plan", {})
+                modules_to_initialize = {name: module for name, module in model.named_children() if name != prefix}
+                for name, module in modules_to_initialize.items():
+                    # Push the module to the correct device
+                    module.to(tp_device)
+                    if name in tp_plan:
+                        torch.distributed.tensor.parallel.parallelize_module(
+                            module,
+                            device_mesh=device_mesh,
+                            parallelize_plan=translate_to_torch_parallel_style(tp_plan[name]),
+                        )
+
+        # All potential warnings/infos
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
             if "size mismatch" in error_msg:
@@ -5025,7 +5094,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
             )
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
-
         if len(unexpected_keys) > 0:
             archs = [] if model.config.architectures is None else model.config.architectures
             warner = logger.warning if model.__class__.__name__ in archs else logger.info
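Note: for context, here is roughly how the loading path this diff optimizes gets exercised end to end. This is a hedged sketch, not part of the patch: it assumes the public `tp_plan="auto"` entry point to tensor-parallel loading, uses a placeholder model id, and must be launched with `torchrun` so each process knows its rank in the device mesh.

```python
# Launch with e.g.: torchrun --nproc-per-node 4 run_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder: any model with a base_model_tp_plan

# With tp_plan="auto", from_pretrained shards the model across the torchrun
# processes; with this patch, each layer is parallelized as soon as its
# parameters are loaded instead of after the full model is materialized
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)
```

Under this flow, the per-layer `parallelize_module` calls happen inside `_load_state_dict_into_meta_model` as the checkpoint streams in, and the post-processing block above only has to handle non-persistent buffers and, when loading a task model from a base checkpoint, the top-most task-specific modules.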