Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00.
Commit 7599f0d156 (parent c8ab6ce6ce): default device for rotary
1 changed file with 10 additions and 0 deletions
@@ -4156,6 +4156,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             hf_quantizer.validate_environment(device_map=device_map)
+        elif device_map is not None:
+            # Make sure we correctly place the rotary embedding module by default if not provided, as we moved it from
+            # inside the Layers to the Model
+            rotary_module = None
+            for buffer in {name for name, _ in model.named_buffers()}:
+                if "rotary_emb.inv_freq" in buffer:
+                    rotary_module = buffer.replace(".inv_freq", "")
+                    break
+            if rotary_module is not None and rotary_module not in device_map:
+                # Place it on the same device as the embedding if set and exists, else 0
+                device_map[rotary_module] = device_map.get(rotary_module.replace("rotary_emb", "embed_tokens"), 0)

         model.tie_weights()
         tied_params = find_tied_parameters(model)
         # check if we don't have tied param in different devices
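In short: the added branch scans the model's registered buffers for a `rotary_emb.inv_freq` entry, derives the rotary module's name from it, and, if a user-supplied `device_map` does not already mention that module, defaults it to the device assigned to the sibling `embed_tokens` module, falling back to device 0. Below is a minimal standalone sketch of that logic, assuming a hypothetical `ToyModel` (not a transformers class) so the buffer scan can run outside `from_pretrained`:

import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Toy stand-in: only registers the inv_freq buffer the scan looks for."""

    def __init__(self, dim: int = 8):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


class ToyModel(nn.Module):
    """Hypothetical model: rotary_emb sits on the model, not inside each layer."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.rotary_emb = RotaryEmbedding()


model = ToyModel()
# A user device_map that predates the move and therefore omits the rotary module:
device_map = {"embed_tokens": "cpu"}

# Same scan as the diff: locate the rotary module via its inv_freq buffer name.
rotary_module = None
for name, _ in model.named_buffers():
    if "rotary_emb.inv_freq" in name:
        rotary_module = name.replace(".inv_freq", "")
        break

# Default the module to the embedding's device if that entry exists, else device 0.
if rotary_module is not None and rotary_module not in device_map:
    device_map[rotary_module] = device_map.get(
        rotary_module.replace("rotary_emb", "embed_tokens"), 0
    )

print(device_map)  # -> {'embed_tokens': 'cpu', 'rotary_emb': 'cpu'}

Keying the default off `embed_tokens` co-locates the rotary frequencies with the input embeddings, so older device maps written before the module moved out of the layers keep working without a cross-device transfer; the `0` fallback covers maps that never mention the embedding at all.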