new first tp loading version

Cyril Vallez 2025-02-05 21:44:46 +01:00
parent c3e818561e
commit ff1078387e

@@ -766,6 +766,8 @@ def _load_state_dict_into_meta_model(
is_safetensors: bool = False,
keep_in_fp32_modules: Optional[List[str]] = None,
unexpected_keys: Optional[Dict] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
tp_key_registry: Optional[Dict] = None,
) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -775,6 +777,7 @@ def _load_state_dict_into_meta_model(
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
It also initializes tensor parallelism according to `tp_key_registry` if needed.
"""
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
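For context, the `tp_key_registry` consumed here maps a module prefix to the set of parameter keys still expected for that module, together with the parallelization plan to apply once all of them have been materialized (it is built in `_load_pretrained_model` further down). A minimal illustrative sketch of one entry, with made-up parameter names:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Illustrative only: rough shape of a `tp_key_registry` entry for one decoder layer.
# The plans are stored already translated to torch parallel styles.
tp_key_registry = {
    "model.layers.0": {
        # parameter keys of this module that have not been loaded yet
        "children": {
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.mlp.down_proj.weight",
        },
        # per-submodule plan passed to parallelize_module once `children` is empty
        "plan": {
            "self_attn.q_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        },
    },
}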
@@ -891,6 +894,31 @@ def _load_state_dict_into_meta_model(
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return
# In this case, let's parallelize the modules!
if tp_key_registry is not None:
plan = None
prefix = None
for module_prefix in tp_key_registry.keys():
if f"{module_prefix}." in param_name:
tp_key_registry[module_prefix]["children"].remove(param_name)
if len(tp_key_registry[module_prefix]["children"]) == 0:
plan = tp_key_registry[module_prefix]["plan"]
prefix = module_prefix
break
if plan is not None:
del tp_key_registry[prefix]
parent_module = model
for name in prefix.split("."):
parent_module = getattr(parent_module, name)
torch.distributed.tensor.parallel.parallelize_module(
parent_module,
device_mesh=device_mesh,
parallelize_plan=plan,
)
return error_msgs, disk_offload_index, cpu_offload_index
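The block above defers `parallelize_module` until the last parameter of a module has been materialized, so each layer is sharded as soon as it is fully loaded. A minimal standalone sketch of that call on a toy module (assumes `torch>=2.5` and an initialized process group; the toy module and plan are illustrative, not the ones transformers uses):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class ToyMLP(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))

# One-dimensional mesh over all processes, mirroring the `device_mesh` passed in above.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
mlp = ToyMLP().to("cuda")
# Same call as above: shard up_proj column-wise and down_proj row-wise across the mesh.
parallelize_module(mlp, device_mesh=mesh, parallelize_plan={"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()})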
@@ -4101,12 +4129,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
# child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
# And temporarily setting the default device to the current process rank results in the following error:
# `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
tp_device = None
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
@@ -4118,6 +4145,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assume the model is sharded across the whole world (all processes)
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
if is_fsdp_enabled():
low_cpu_mem_usage = True
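With the world-sized mesh created here, the typical entry point is a multi-process launch where each rank calls `from_pretrained` with `tp_plan="auto"`. A hedged usage sketch (launch with `torchrun --nproc-per-node 4 run_tp.py`; the checkpoint name and NCCL backend are assumptions for the example):

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# Single-node example: one process per GPU, pinned by rank.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# `tp_plan="auto"` triggers the path above: a (world_size,) device mesh is built and the
# supported modules are sharded while the checkpoint is being loaded.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", tp_plan="auto")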
@@ -4210,7 +4241,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
if is_deepspeed_zero3_enabled():
@@ -4219,7 +4250,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
elif not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
@@ -4432,6 +4463,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
# Last check for tp
if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
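The new check above fails fast when a mesh was requested but the architecture does not ship a tensor parallel plan. A hedged way for a caller to verify this beforehand, without loading any weights (the checkpoint name is only an example):

from transformers import AutoConfig

# `model.supports_tp_plan` (used in the check above) essentially reflects whether the
# architecture defines a `base_model_tp_plan` on its config.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
print(getattr(config, "base_model_tp_plan", None) is not None)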
@@ -4491,6 +4526,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_file=gguf_file,
device_mesh=device_mesh,
weights_only=weights_only,
_fast_init=_fast_init,
)
@@ -4526,8 +4562,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
pass
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harms performance)
if device_map is not None and device_mesh is None:
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
@@ -4554,6 +4591,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
# not part of the state_dict (persistent=False)
if device_mesh is not None:
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer
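The buffer pass above exists because non-persistent buffers such as the rotary embedding's `inv_freq` never appear in the state_dict, so the per-key loading loop cannot move them to the right device. A small sketch of the underlying behaviour (the module is a made-up stand-in for the real RotaryEmbedding):

import torch
import torch.nn as nn

class ToyRotary(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps the buffer out of the state_dict entirely
        self.register_buffer("inv_freq", inv_freq, persistent=False)

m = ToyRotary()
print("inv_freq" in m.state_dict())  # False -> hence the explicit buffer.to(tp_device) above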
@@ -4566,16 +4610,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
adapter_kwargs=adapter_kwargs,
)
if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
@@ -4675,6 +4709,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
hf_quantizer: Optional[HfQuantizer] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
gguf_file: Optional[str] = None,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
weights_only: bool = True,
_fast_init: bool = True,
):
@@ -4807,6 +4842,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
disk_offload_index = {}
# Prepare the tp key registry if using tensor parallel
tp_key_registry = None
if device_mesh is not None:
tp_key_registry = {}
layer_tp_plan = {".".join(k.split(".")[2:]): v for k, v in model.config.base_model_tp_plan.items()}
layer_tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
layer_tp_plan,
)
for key in model_to_load.state_dict().keys():
pattern = r"^(.*layers\.[0-9]+)"
match = re.match(pattern, key)
if match is not None:
layer = match.group(1)
if layer not in tp_key_registry:
tp_key_registry[layer] = {"children": {key}, "plan": layer_tp_plan}
else:
tp_key_registry[layer]["children"].add(key)
elif "lm_head." in key:
if "lm_head" not in tp_key_registry:
plan = translate_to_torch_parallel_style(model._tp_plan["lm_head"])
tp_key_registry["lm_head"] = {"children": {key}, "plan": plan}
else:
tp_key_registry["lm_head"]["children"].add(key)
# This offload index is for params that are supposed to be on the "cpu", either with or without a device_map
# It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
# i.e. 1x to load it, and 1x to copy it to the model
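The registry built above is keyed by the capture of the `^(.*layers\.[0-9]+)` pattern, i.e. the per-layer prefix of each checkpoint key, with `lm_head` handled by a dedicated branch. A quick illustrative check of how typical decoder checkpoint keys are grouped (the key names are examples):

import re

pattern = r"^(.*layers\.[0-9]+)"
example_keys = [
    "model.layers.0.self_attn.q_proj.weight",  # grouped under "model.layers.0"
    "model.layers.10.mlp.gate_proj.weight",    # grouped under "model.layers.10"
    "model.embed_tokens.weight",               # no match: not registered for TP here
    "lm_head.weight",                          # handled by the dedicated lm_head branch
]
for key in example_keys:
    match = re.match(pattern, key)
    print(key, "->", match.group(1) if match else None)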
@@ -4878,6 +4938,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
tp_key_registry=tp_key_registry,
)
error_msgs += new_error_msgs
else:
@@ -5192,7 +5254,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.
Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
has already been loaded in memory. Note, however, that this means each process will first initialize the whole model,
then parallelize it across devices, which wastes a lot of GPU memory and can lead to OOM at loading time.
Calling `from_pretrained(..., tp_plan="auto")` is preferred, as it parallelizes module-by-module during initialization,
so that the expected per-device memory spike at loading time is no larger than the final model size on each device.
Args:
device_mesh (`torch.distributed.DeviceMesh`):