default device for rotary

This commit is contained in:
Cyril Vallez 2025-01-09 13:40:38 +01:00
parent c8ab6ce6ce
commit 7599f0d156
No known key found for this signature in database

View file

@ -4156,6 +4156,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
hf_quantizer.validate_environment(device_map=device_map)
elif device_map is not None:
# Make sure the rotary embedding module gets a default placement when the provided device_map does not
# mention it, as we moved it from inside the Layers to the Model.
# `named_buffers()` is iterated directly (not through a set) so the selected buffer is deterministic, and
# `next(..., None)` keeps `rotary_module` bound even when the model has no buffers at all.
rotary_module = next(
    (name.replace(".inv_freq", "") for name, _ in model.named_buffers() if "rotary_emb.inv_freq" in name),
    None,
)
if rotary_module is not None and rotary_module not in device_map:
    # Place it on the same device as the embedding if set and exists, else 0
    # NOTE(review): the fallback of 0 is also used when device_map only has a catch-all entry - confirm
    # this is the intended default in that case.
    device_map[rotary_module] = device_map.get(rotary_module.replace("rotary_emb", "embed_tokens"), 0)
model.tie_weights()
tied_params = find_tied_parameters(model)
# check that tied parameters are not placed on different devices