2025.12.04

Rotary Position Embeddings (RoPE)

Transformers have no built-in notion of order – they see tokens as a set, not a sequence. RoPE fixes this by encoding position into attention itself.

The core idea: rotate vector A by angle $\alpha$ and vector B by angle $\beta$, and their dot product depends on the angles only through the difference $(\alpha - \beta)$. Rotate each token's q/k by an angle proportional to its position, and attention scores will reflect the relative distance between tokens.

Token 3 and token 5? The angle difference is the same as between token 10 and token 12.
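
A quick numeric check of that claim, as a minimal sketch in plain Python. The helpers `rotate` and `score` and the per-position angle `step` are made up for illustration; they are not part of the implementation further down.

```python
import math

def rotate(v, angle):
    """Rotate a 2D vector counter-clockwise by `angle` radians."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q = (1.0, 2.0)   # toy query features
k = (0.5, -1.5)  # toy key features
step = 0.3       # hypothetical angle per position

def score(q_pos, k_pos):
    # rotate each vector by an angle proportional to its position
    return dot(rotate(q, q_pos * step), rotate(k, k_pos * step))

# same offset (2 positions apart) -> same score, wherever the pair sits
print(score(3, 5), score(10, 12))
# a different offset gives a different score
print(score(3, 6))
```

The first two numbers agree up to float rounding; the third does not, because the offset changed from 2 to 3.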

Why 2D? 1D is just scaling, not rotation. Higher-D needs rotation matrices and doesn't really help. 2D hits the sweet spot: real rotation, simple complex math, and each pair gets its own frequency.

High-frequency pairs change quickly with offset, so they tell nearby tokens apart; low-frequency pairs change slowly, so they still carry signal between distant tokens.

I use complex numbers because rotating $(x, y)$ by $\theta$ is just multiplying $(x + iy)$ by $e^{i\theta}$. Same math, cleaner code (at least for me :D).
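
To see that the two views agree, here is a small sketch using only the standard library. The sample point and angle are arbitrary.

```python
import cmath
import math

x, y = 3.0, 4.0
theta = math.pi / 6  # 30 degrees

# rotation matrix form
x_rot = x * math.cos(theta) - y * math.sin(theta)
y_rot = x * math.sin(theta) + y * math.cos(theta)

# complex form: (x + iy) * e^{i * theta}
z_rot = complex(x, y) * cmath.exp(1j * theta)

print((x_rot, y_rot))
print((z_rot.real, z_rot.imag))
# the length is unchanged: rotation never rescales the vector
print(abs(z_rot))
```

Both prints show the same point, and the length stays at 5 (up to float rounding).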

I also use einops – easier to follow the tensor shapes.


01

Implementation

from dataclasses import dataclass

import torch
from einops import einsum, rearrange
from torch import nn
02

Config

@dataclass
class Config:
03

head_dim must be even – we pair up dimensions for 2D rotation

    head_dim: int = 8
04

base frequency. 10_000 is standard. higher -> slower decay -> longer context

    rope_theta: float = 10_000
05

Computing Frequencies

Each pair of dimensions rotates at a different frequency. Pair 0 rotates fast, pair 1 slower, pair 2 even slower, etc.

class RoPE(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.theta = config.rope_theta
        self.head_dim = config.head_dim
06

inv_freq = the rotation speeds for each dimension pair

        inv_freq = self._compute_inverse_frequencies()
07

register_buffer: not a learnable param, but moves to GPU with the model.
persistent=False: don't save to state_dict (we recompute on load).

        self.register_buffer("inv_freq", inv_freq, persistent=False)
08

Computing Inverse Frequencies

RoPE uses different rotation speeds for each dimension pair. Low-index pairs rotate fast → sensitive to nearby positions. High-index pairs rotate slow → sensitive to distant positions.

    def _compute_inverse_frequencies(self):
09

for head_dim=8: [0, 2, 4, 6] / 8 = [0, 0.25, 0.5, 0.75]
these are the exponents $2i/d$ in the formula $f_i = 1/\theta^{2i/d}$

        scale = torch.arange(0, self.head_dim, 2) / self.head_dim
10

with theta=10000: [1, 0.1, 0.01, 0.001] (roughly)
high index → small freq → slow rotation → captures long-range

        # one rotation speed (radians per position) for each dimension pair
        inv_freq = 1.0 / (self.theta ** scale)
        return inv_freq
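
The same numbers can be reproduced without torch. This is only a sketch of the formula with plain lists; the `wavelengths` list (positions per full turn, $2\pi / f_i$) is an extra added here to make "fast" and "slow" concrete, not something the module computes.

```python
import math

head_dim = 8
theta = 10_000.0

# exponents 2i/d, one per dimension pair
exponents = [i / head_dim for i in range(0, head_dim, 2)]
inv_freq = [1.0 / theta ** e for e in exponents]

# how many positions it takes each pair to complete one full turn
wavelengths = [2 * math.pi / f for f in inv_freq]

for i, (f, w) in enumerate(zip(inv_freq, wavelengths)):
    print(f"pair {i}: {f:.3f} rad/token, full turn every {w:.1f} tokens")
```

Pair 0 wraps around after about 6 tokens, while pair 3 needs a few thousand, which is why the slow pairs can still separate far-apart positions.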
11

for each position p and frequency f: angle = p * f

position 0 gets angles [0, 0, 0, 0], position 5 gets angles [5*f0, 5*f1, 5*f2, 5*f3]

then convert to complex:

$e^{i\theta} = \cos\theta + i\sin\theta$

these are unit vectors we'll multiply with to rotate.

also, we force float32 to avoid precision issues with bfloat16. see: HuggingFace PR #29285

    @torch.no_grad()
    def forward(self, x, position_ids):
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = einsum(
                self.inv_freq.float(), position_ids.float(),
                "rotary_dim, batch seq -> batch seq rotary_dim"
            )
12

torch.polar(r, θ) gives $r \cdot e^{i\theta}$

r=1 gives unit vectors.

            freqs_cis = torch.polar(abs=torch.ones_like(freqs), angle=freqs)
        # stay complex64: casting to x.dtype (a real dtype) would drop the imaginary part
        return freqs_cis
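
For intuition, here is what the forward pass produces for a single sequence, sketched with plain Python lists instead of tensors. `cmath.rect(1.0, a)` plays the role of `torch.polar` with r=1; the hard-coded `inv_freq` values assume head_dim=8 and theta=10_000.

```python
import cmath

inv_freq = [1.0, 0.1, 0.01, 0.001]  # assumes head_dim=8, theta=10_000
positions = [0, 5]

# angle = position * frequency, one row per position
angles = [[p * f for f in inv_freq] for p in positions]

# unit complex number for each angle: cos(angle) + i*sin(angle)
freqs_cis = [[cmath.rect(1.0, a) for a in row] for row in angles]

print(angles[0])  # position 0: all zeros, so no rotation
print(angles[1])  # position 5: [5*f0, 5*f1, 5*f2, 5*f3]
print([abs(z) for z in freqs_cis[1]])  # every entry has length 1
```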
13

Applying Rotation

now we rotate query and key vectors using the angles we computed. each pair of adjacent features is one 2D plane.

def apply_rope(query, key, freqs_cis):
    def rotate(x):
14

pair adjacent dims: [d0, d1, d2, d3, ...] -> [[d0, d1], [d2, d3], ...]

        x_split = rearrange(x, "... (pairs two) -> ... pairs two", two=2)
15

reinterpret as complex: [a, b] -> a + bi

        # upcast first: view_as_complex does not accept bfloat16
        x_complex = torch.view_as_complex(x_split.float().contiguous())
16

rotate by multiplying with unit vector

        x_rotated = x_complex * freqs_cis.unsqueeze(-2)
17

back to real: a + bi -> [a, b]

        return torch.view_as_real(x_rotated).flatten(-2)

    query_out = rotate(query)
    key_out = rotate(key)
    return query_out.type_as(query), key_out.type_as(key)
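
A torch-free sketch of the same steps for one vector, useful for checking the two properties we care about: scores depend only on the offset, and rotation keeps vector length. `rope_1d` is a hypothetical helper written for this check, and the q/k values are arbitrary.

```python
import cmath

def rope_1d(x, pos, inv_freq):
    """Rotate one vector (a flat list of floats) for a given position."""
    out = []
    for i, f in enumerate(inv_freq):
        z = complex(x[2 * i], x[2 * i + 1])  # adjacent pair -> a + bi
        z *= cmath.rect(1.0, pos * f)        # rotate by pos * f
        out += [z.real, z.imag]              # a + bi -> [a, b]
    return out

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

inv_freq = [1.0, 0.1, 0.01, 0.001]  # assumes head_dim=8, theta=10_000
q = [0.1, -0.4, 0.7, 0.2, -0.3, 0.9, 0.5, -0.6]
k = [0.8, 0.3, -0.2, 0.6, 0.4, -0.1, -0.7, 0.2]

near = dot(rope_1d(q, 3, inv_freq), rope_1d(k, 5, inv_freq))
far = dot(rope_1d(q, 10, inv_freq), rope_1d(k, 12, inv_freq))
print(near, far)  # equal up to float rounding: both pairs are 2 apart

# rotation leaves the vector's length untouched
q_rot = rope_1d(q, 7, inv_freq)
print(dot(q, q), dot(q_rot, q_rot))
```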
18

Alternative: Sin/Cos Version

most codebases (HuggingFace Transformers for example) skip complex numbers and use the rotation matrix directly:

$x' = x \cos\theta - y \sin\theta$
$y' = x \sin\theta + y \cos\theta$

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_explicit(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)
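
One thing to watch: `rotate_half` pairs dimension i with dimension i + d/2, while the complex version above pairs adjacent dimensions. So the two are the same rotation only up to a fixed reordering of features. The sketch below checks that with plain lists. It assumes `cos`/`sin` hold each pair's angle repeated across both halves; the helper names are made up for this check.

```python
import math

def rotate_half(x):
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def rope_half_split(x, pos, inv_freq):
    # each frequency appears twice: once for each half of the vector
    angles = [pos * f for f in inv_freq] * 2
    rot = rotate_half(x)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rot, angles)]

def rope_adjacent(x, pos, inv_freq):
    # same rotation, but pairing (d0, d1), (d2, d3), ...
    out = []
    for i, f in enumerate(inv_freq):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(pos * f), math.sin(pos * f)
        out += [a * c - b * s, a * s + b * c]
    return out

def to_half_split(v):
    # [d0, d1, d2, d3, ...] -> [d0, d2, ..., d1, d3, ...]
    return v[0::2] + v[1::2]

inv_freq = [1.0, 0.1, 0.01, 0.001]  # assumes head_dim=8, theta=10_000
x = [0.1, -0.4, 0.7, 0.2, -0.3, 0.9, 0.5, -0.6]

via_adjacent = to_half_split(rope_adjacent(x, 5, inv_freq))
via_half_split = rope_half_split(to_half_split(x), 5, inv_freq)
print(max(abs(u - v) for u, v in zip(via_adjacent, via_half_split)))
```

The printed difference is zero up to float rounding. In practice this means weights trained with one layout need their q/k features permuted before being used with the other.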