mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
15 lines
359 B
Python
15 lines
359 B
Python
import torch
|
|
|
|
from transformers.models.llama.modeling_llama import LlamaModel
|
|
|
|
|
|
def rotate_half(x):
    """Negate-and-swap two pieces of the last dimension of ``x``.

    The last dimension (size ``d``) is cut at ``d // 4``:

    - ``head`` holds the first ``d // 4`` entries.
    - ``tail`` holds the remaining ``d - d // 4`` entries.

    The result is ``cat(-tail, head)`` along the last dimension. Because the
    two pieces together cover all ``d`` entries, the output has the same shape
    as ``x``.

    NOTE(review): the function name suggests a midpoint split, but the cut
    point here is a quarter of the last dimension. Confirm this is intentional
    before relying on it as a rotary-embedding helper.
    """
    width = x.shape[-1]
    cut = width // 4
    head, tail = torch.split(x, [cut, width - cut], dim=-1)
    return torch.cat((tail.neg(), head), dim=-1)
|
|
|
|
|
|
# Example that needs both external dependencies (torch, LlamaModel) and a locally defined function (rotate_half).
|
|
class DummyModel(LlamaModel):
    # Defines nothing of its own: every attribute and method is inherited
    # unchanged from LlamaModel.
    pass
|