diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 874c7abaf..2c7e057a8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4366,7 +4366,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) pass - # Dispatch model with hooks on all devices if necessary + # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it) if device_map is not None and device_mesh is None: device_map_kwargs = { "device_map": device_map, @@ -4394,6 +4394,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) + # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is + # not part of the state_dict (persistent=False) + if device_mesh is not None: + for buffer in model.buffers(): + if buffer.device != tp_device: + buffer.data = buffer.to(tp_device) + if hf_quantizer is not None: hf_quantizer.postprocess_model(model, config=config) model.hf_quantizer = hf_quantizer