Rotary Position Embeddings (RoPE)
Transformers have no built-in notion of order – they see tokens as a set, not a sequence. RoPE fixes this by encoding position into attention itself.
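To make the "set, not a sequence" point concrete, here is a minimal sketch of single-head self-attention with no positional information. It uses random weights and no causal mask; both are assumptions chosen purely for illustration. Shuffling the input tokens only shuffles the outputs the same way, so nothing in the computation knows about order.

```python
import torch

torch.manual_seed(0)

def self_attention(x, w_q, w_k, w_v):
    # plain single-head attention: no positional encoding, no causal mask
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

seq_len, dim = 5, 8
x = torch.randn(seq_len, dim)
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

perm = torch.randperm(seq_len)
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)

# shuffling the tokens only shuffles the outputs: attention is permutation-equivariant
print(torch.allclose(out_shuffled, out[perm], atol=1e-5))  # True
```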
The core idea: rotate vector A by angle $\alpha$ and vector B by angle $\beta$ – their dot product depends only on the difference $(\alpha - \beta)$. Rotate each token's q/k by an angle proportional to its position, and attention scores will reflect the relative distance between tokens.
Token 3 and token 5? The angle difference is the same as between token 10 and token 12.
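Here is a quick numeric check of both claims, treating one 2D pair as a complex number. The values of `q`, `k` and the per-position angle `theta = 0.3` are arbitrary picks for illustration. The real part of `a * conj(b)` is the ordinary 2D dot product of `a` and `b`.

```python
import cmath

def rotate(v: complex, angle: float) -> complex:
    # rotate a 2D vector (stored as a complex number) by `angle` radians
    return v * cmath.exp(1j * angle)

def dot(a: complex, b: complex) -> float:
    # 2D dot product: a.x*b.x + a.y*b.y == Re(a * conj(b))
    return (a * b.conjugate()).real

q = complex(1.0, 2.0)   # arbitrary query pair
k = complex(-0.5, 1.5)  # arbitrary key pair
theta = 0.3             # arbitrary rotation per position step

def score(m: int, n: int) -> float:
    # query at position m, key at position n, each rotated by position * theta
    return dot(rotate(q, m * theta), rotate(k, n * theta))

# equal up to float rounding: only the gap 5 - 3 == 12 - 10 matters
print(score(3, 5), score(10, 12))
```

Algebraically, `rotate(q, m*theta) * conj(rotate(k, n*theta))` equals `q * conj(k) * e^(i(m-n)theta)`, so the absolute positions cancel and only `m - n` survives.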
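Why 2D? 1D is just scaling, not rotation: a single number can only be multiplied by something, and the only length-preserving factors are +1 and -1. Higher-dimensional blocks need full rotation matrices and don't really help. 2D hits the sweet spot: a real rotation, simple complex-number math, and each pair gets its own frequency.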
High-frequency pairs change quickly with position, so they pick up nearby tokens. Low-frequency pairs change slowly, so they pick up distant ones.
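To put numbers on "high" and "low" frequency: a pair with frequency f completes one full turn every 2π / f positions. This sketch uses the standard schedule f_i = theta^(-2i/d), which is what `_compute_inverse_frequencies` computes below. It uses the defaults head_dim = 8 and theta = 10_000.

```python
import math

head_dim, theta = 8, 10_000.0

# frequency of pair i: theta^(-2i/d) radians per position
freqs = [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]
# positions needed for one full turn of each pair
wavelengths = [2 * math.pi / f for f in freqs]

for i, (f, w) in enumerate(zip(freqs, wavelengths)):
    print(f"pair {i}: {f:.4f} rad/position, one full turn every {w:,.1f} positions")
```

Pair 0 wraps around roughly every 6 positions, so it separates neighbors well but is ambiguous at long range. Pair 3 needs roughly 6,283 positions for one turn, so it still moves smoothly across long distances.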
I use complex numbers because rotating by $\theta$ is just multiplying by $e^{i\theta}$. Same math, cleaner code (at least for me :D).
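A small check that the two views agree. Multiplying x + iy by e^(iθ) gives the same result as applying the rotation matrix [[cos θ, -sin θ], [sin θ, cos θ]] to (x, y). The point (3, 4) and the angle 0.7 rad are arbitrary.

```python
import cmath
import math

def rotate_matrix(x: float, y: float, angle: float) -> tuple[float, float]:
    # explicit 2x2 rotation matrix applied to (x, y)
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

def rotate_complex(x: float, y: float, angle: float) -> tuple[float, float]:
    # same rotation: (x + iy) * e^(i*angle)
    z = complex(x, y) * cmath.exp(1j * angle)
    return (z.real, z.imag)

print(rotate_matrix(3.0, 4.0, 0.7))
print(rotate_complex(3.0, 4.0, 0.7))  # same numbers, up to float rounding
```

Both versions also preserve length: the rotated point is still at distance 5 from the origin.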
I also use einops – easier to follow the tensor shapes.
Implementation
from dataclasses import dataclass
import torch
from einops import einsum, rearrange
from torch import nn

Config
@dataclass
class Config:
    # head_dim must be even – we pair up dimensions for 2D rotation
    head_dim: int = 8
    # base frequency. 10_000 is standard. higher -> slower rotation -> slower decay -> longer context
    rope_theta: float = 10_000

Computing Frequencies
Each pair of dimensions rotates at a different frequency. Pair 0 rotates fast, pair 1 slower, pair 2 even slower, etc.
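The Config comment on rope_theta says a higher value means longer context. The slowest pair's wavelength is 2π · theta^((d-2)/d), so raising theta stretches it. This sketch compares theta = 10_000 with theta = 500_000 at head_dim = 128. Both settings are illustrative choices, not values taken from the Config above.

```python
import math

def slowest_wavelength(head_dim: int, theta: float) -> float:
    # last pair i = head_dim/2 - 1 has frequency theta^(-(head_dim - 2)/head_dim)
    slowest_freq = theta ** (-(head_dim - 2) / head_dim)
    return 2 * math.pi / slowest_freq  # positions per full turn

for theta in (10_000.0, 500_000.0):
    print(f"theta={theta:,.0f}: slowest pair turns once every "
          f"{slowest_wavelength(128, theta):,.0f} positions")
```

With head_dim = 128, the slowest pair goes from roughly 54 thousand positions per turn to roughly 2.6 million. That gives the low-frequency pairs room to stay unambiguous over much longer sequences.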
class RoPE(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.theta = config.rope_theta
        self.head_dim = config.head_dim
        # inv_freq = the rotation speeds for each dimension pair
        inv_freq = self._compute_inverse_frequencies()
        # register_buffer: not a learnable param, but moves to GPU with the model
        # persistent=False: don't save to state_dict (we recompute on load)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

Computing Inverse Frequencies
RoPE uses a different rotation speed for each dimension pair. Low-index pairs rotate fast → sensitive to nearby positions. High-index pairs rotate slowly → sensitive to distant positions.
    def _compute_inverse_frequencies(self):
        # for head_dim=8: [0, 2, 4, 6] / 8 = [0, 0.25, 0.5, 0.75]
        # these are the exponents 2i/d in the formula inv_freq_i = theta^(-2i/d)
        scale = torch.arange(0, self.head_dim, 2) / self.head_dim
        # with theta=10000: [1, 0.1, 0.01, 0.001] (roughly)
        # high index → small freq → slow rotation → captures long-range
        radians = 1.0 / self.theta ** scale
        return radians

    # for each position p and frequency f: angle = p * f
    # position 0 gets angles [0, 0, 0, 0], position 5 gets angles [5*f0, 5*f1, 5*f2, 5*f3]
    # then convert to complex: e^(i * angle) = cos(angle) + i * sin(angle)
    # these are unit vectors we'll multiply q and k with to rotate them.
    # also, we force float32 to avoid precision issues with bfloat16. see: HuggingFace PR #29285
    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: only used for its device; position_ids: (batch, seq)
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = einsum(
                self.inv_freq.float(), position_ids.float(),
                "rotary_dim, batch seq -> batch seq rotary_dim"
            )
            # torch.polar(r, θ) gives r * (cos θ + i sin θ) = r * e^(iθ)
            # r=1 gives unit vectors.
            freqs_cis = torch.polar(abs=torch.ones_like(freqs), angle=freqs)
        # keep complex64: casting to x.dtype (a real dtype) would drop the imaginary part
        return freqs_cis

Applying Rotation
Now we rotate the query and key vectors using the angles we computed. Each pair of adjacent features is one 2D plane.
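To see what "each pair of adjacent features is one 2D plane" means at the tensor level, this sketch pairs the last dimension the same way the rearrange pattern in apply_rope does. It uses a plain reshape instead of einops, which is equivalent for this split. It then views the pairs as complex numbers and rotates every pair by 90 degrees.

```python
import math
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])  # features [d0, d1, d2, d3]

# [d0, d1, d2, d3] -> [[d0, d1], [d2, d3]] -> [d0 + i*d1, d2 + i*d3]
x_complex = torch.view_as_complex(x.reshape(-1, 2))

# unit complex numbers for a 90-degree rotation (r=1, angle=pi/2), one per pair
angle = torch.full((2,), math.pi / 2)
rot = torch.polar(torch.ones_like(angle), angle)

# rotate each pair, then go back to real: a + ib -> [a, b] -> flat vector
x_rotated = torch.view_as_real(x_complex * rot).flatten(-2)
print(x_rotated)  # approximately [-2, 1, -4, 3]
```

Rotating (1, 2) by 90 degrees gives (-2, 1), and (3, 4) gives (-4, 3). Each pair moves only within its own plane.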
def apply_rope(query, key, freqs_cis):
    # query/key: (batch, seq, n_heads, head_dim); freqs_cis: (batch, seq, head_dim // 2), complex
    def rotate(x):
        # pair adjacent dims: [d0, d1, d2, d3, ...] -> [[d0, d1], [d2, d3], ...]
        # (.float() first: view_as_complex does not accept bfloat16)
        x_split = rearrange(x.float(), "... (pairs two) -> ... pairs two", two=2)
        # reinterpret as complex: [a, b] -> a + bi
        x_complex = torch.view_as_complex(x_split.contiguous())
        # rotate by multiplying with unit vector; unsqueeze(-2) broadcasts over heads
        x_rotated = x_complex * freqs_cis.unsqueeze(-2)
        # back to real: a + bi -> [a, b]
        return torch.view_as_real(x_rotated).flatten(-2)

    query_out = rotate(query)
    key_out = rotate(key)
    # cast back to the input dtype (e.g. bfloat16)
    return query_out.type_as(query), key_out.type_as(key)

Alternative: Sin/Cos Version
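Most codebases (HuggingFace Transformers, for example) skip complex numbers and apply the rotation with real cos and sin tensors instead. One catch: rotate_half pairs dimension i with dimension i + head_dim/2 rather than with its neighbor. So cos and sin must repeat each frequency for both halves, and the two versions only agree after permuting the features.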
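def rotate_half(x):
    # split the last dim into halves: x1 = x[:d/2], x2 = x[d/2:]
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    # (x1, x2) -> (-x2, x1): rotates each (x1[i], x2[i]) pair by 90 degrees
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_explicit(x, cos, sin):
    # per pair: (x1*cos - x2*sin, x2*cos + x1*sin), the 2D rotation written out
    return (x * cos) + (rotate_half(x) * sin)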
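The two versions pair features differently:

- The complex version rotates adjacent features (d0, d1), (d2, d3), ...
- rotate_half rotates d_i together with d_{i + head_dim/2}.

This sketch checks that they compute the same rotation once features are permuted from interleaved to half-split order. It assumes cos and sin are built the way HuggingFace's Llama code builds them, with each frequency repeated for both halves. `rope_complex` and `interleaved_to_half` are hypothetical helpers written for this example. `rope_complex` is a standalone re-implementation, not the RoPE class above.

```python
import torch

torch.manual_seed(0)
head_dim, seq_len, theta = 8, 6, 10_000.0

# shared angles: (seq, head_dim/2)
inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2).float() / head_dim)
positions = torch.arange(seq_len).float()
freqs = torch.outer(positions, inv_freq)

def rope_complex(x, freqs):
    # interleaved pairs (d0, d1), (d2, d3), ... rotated via complex multiplication
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = torch.polar(torch.ones_like(freqs), freqs)
    return torch.view_as_real(x_c * rot).flatten(-2)

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) over the two halves of the last dim
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_explicit(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

def interleaved_to_half(x):
    # hypothetical helper: [d0, d1, d2, d3, ...] -> [d0, d2, ..., d1, d3, ...]
    return torch.cat((x[..., 0::2], x[..., 1::2]), dim=-1)

# HF-style cos/sin: each frequency repeated for both halves -> (seq, head_dim)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

x = torch.randn(seq_len, head_dim)
out_complex = rope_complex(x, freqs)
out_half = apply_rope_explicit(interleaved_to_half(x), cos, sin)

# same rotation, just stored in a different feature order
print(torch.allclose(interleaved_to_half(out_complex), out_half, atol=1e-5))  # True
```

So the two code paths are not drop-in replacements for each other. A model trained with one convention needs its q/k projection weights permuted before it can use the other.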