Update modeling_utils.py

commit 51f0aa02fb (parent 321f8ee777)
1 changed file with 8 additions and 1 deletion
@@ -4366,7 +4366,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             )
             pass
 
-        # Dispatch model with hooks on all devices if necessary
+        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it)
         if device_map is not None and device_mesh is None:
             device_map_kwargs = {
                 "device_map": device_map,
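This first hunk only rewords the comment to explain the existing guard: when a tp_plan has already sharded the model across a device_mesh, the cross-device hooks installed by accelerate's dispatch_model are redundant. A minimal sketch of that decision, assuming accelerate is installed; the helper maybe_dispatch and its trimmed argument list are illustrative, not the actual from_pretrained code:

# Hypothetical helper sketching the dispatch guard; not transformers API.
from accelerate import dispatch_model

def maybe_dispatch(model, device_map=None, device_mesh=None):
    # A tp_plan (device_mesh is set) has already placed every shard, so the
    # input/output-moving hooks from dispatch_model would be redundant.
    if device_map is not None and device_mesh is None:
        model = dispatch_model(model, device_map=device_map)
    return model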
@@ -4394,6 +4394,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
                 dispatch_model(model, **device_map_kwargs)
 
+        # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+        # not part of the state_dict (persistent=False)
+        if device_mesh is not None:
+            for buffer in model.buffers():
+                if buffer.device != tp_device:
+                    buffer.data = buffer.to(tp_device)
+
         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer
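The added block addresses buffers registered with persistent=False (such as RotaryEmbedding's inv_freq): they never appear in the state_dict, so loading sharded weights onto the TP device leaves them on whatever device they were initialized on. Below is a self-contained sketch of the problem and the fix; the Rotary module is illustrative, not the transformers implementation, and tp_device stands in for the loader's tensor-parallel target device:

# Hypothetical module with a non-persistent buffer, mirroring RotaryEmbedding.
import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps inv_freq out of the state_dict entirely
        self.register_buffer("inv_freq", inv_freq, persistent=False)

model = Rotary()
print("inv_freq" in model.state_dict())  # False: state_dict loading never moves it

tp_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Same loop as the commit: sweep every buffer onto the TP device.
for buffer in model.buffers():
    if buffer.device != tp_device:
        buffer.data = buffer.to(tp_device)

Reassigning buffer.data rather than rebinding the buffer itself swaps the underlying storage in place, so the module's registered reference (and anything else holding that tensor) stays valid.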