Update modeling_utils.py

This commit is contained in:
Cyril Vallez 2025-01-31 19:10:20 +01:00
parent 321f8ee777
commit 51f0aa02fb
No known key found for this signature in database

View file

@ -4366,7 +4366,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
pass
# Dispatch model with hooks on all devices if necessary
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it)
if device_map is not None and device_mesh is None:
device_map_kwargs = {
"device_map": device_map,
@ -4394,6 +4394,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
# not part of the state_dict (persistent=False)
if device_mesh is not None:
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer