A Practical Guide to Activation Steering
What if you could lobotomize an LLM, but gently?
Activation steering: find a direction in the model's internal representation space, add it during inference, shift behavior.
Here's the pipeline: extracting a steering vector, applying it, choosing the right strength, and understanding when it fails. We'll steer Gemma-3-1b-it between "novice-friendly" and "expert-level" communication styles.
Fair warning: All the specific numbers here (layer choices, $\alpha$ values, probe accuracies) are for Gemma-3-1b-it on this particular task. For other models and/or other concepts you should expect different behavior. The method generalizes; the parameters don't.
Prerequisites: Basic PyTorch, familiarity with transformer architecture. If you know what "hidden states" and "residual stream" mean, you're overqualified.
The Core Idea
An insight from mechanistic interpretability: neural networks represent concepts as directions in activation space. This is the linear representation hypothesis (Park et al., 2023) – truth vs. falsehood, sentiment, even abstract concepts form linearly separable clusters in LLM activations (Marks & Tegmark, 2023). Honesty, sentiment, expertise – these aren't stored in single neurons, but as linear directions across many neurons.
Why not single neurons? Because models use superposition – they pack more concepts than they have neurons by encoding features as overlapping directions (Elhage et al., 2022). Individual neurons end up polysemantic, responding to multiple unrelated concepts. So we look for directions, not neurons.
If we can find the "expertise direction", we can push the model along it (more expert) or against it (more novice). The math is embarrassingly simple:
$$h_\ell \leftarrow h_\ell + \alpha \, v$$

where $h_\ell$ is the activation at layer $\ell$, $v$ is our steering vector, and $\alpha$ controls strength. Positive $\alpha$ means "more expert", negative means "more novice".
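In code, the intervention is literally one addition (the real hook-based version comes later). A toy illustration with made-up shapes, not Gemma's actual dimensions:

import torch

hidden = torch.randn(1, 12, 256)  # (batch, seq_len, hidden_dim) – made-up shapes
v = torch.randn(256)
v = v / v.norm()                  # unit-length steering direction
alpha = 4.0

steered = hidden + alpha * v      # broadcasts over batch and sequence positions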
Finding the Steering Direction
How do we find $v$? The simplest approach: linear probes.
Create contrastive prompts – different questions, contrasting personas:
NOVICE_PROMPTS = [
"i'm new to this topic... what is the difference between rhythm, beat, tempo, and meter?",
"this might be a basic question, but how do art historians figure out who made an artwork when it isn't signed?",
"i don't have a background in this field, so... how do I lose fat without losing muscle?",
"i'm just a curious beginner, but why do organisms need water to survive?",
"i'm new to this topic, but how does gravity work – why do masses attract each other?",
]
EXPERT_PROMPTS = [
"As someone with a PhD in aesthetics and art history, how have ideas about beauty changed across cultures and centuries?",
"Building on my expertise in aesthetics, what makes something beautiful or art?",
"From my graduate coursework in orbital mechanics, if gravity pulls everything down, why do planets and moons stay in orbit instead of falling straight in?",
"As someone with a PhD in physics, why does anything have mass, and what does mass really measure?",
"From my graduate coursework in neuroscience, how do nerves transmit information so quickly?",
]

The model takes the bait every time. We collect activations from both, then train a classifier to separate them.
On dataset size: The accuracy and coherence plots below used ~120 contrastive pairs with a separate held-out test set (from different domains to reduce bias). But the concept is simple enough that steering works with as few as 5+5 prompts.
Logistic regression learns a hyperplane that separates the two classes:

$$w^\top h + b = 0$$

The weight vector $w$ is perpendicular to this hyperplane – it points in the direction that maximally separates novice from expert. That's our steering vector.
import numpy as np
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression()
clf.fit(activations, labels) # 0=novice, 1=expert
steering_vector = clf.coef_[0] # shape: (hidden_dim,)
steering_vector /= np.linalg.norm(steering_vector)  # normalize to unit length

Why coef_[0]? For binary classification, sklearn stores a single weight vector. It points from class 0 (novice) toward class 1 (expert).
Alternative approaches: You could also use mean difference (expert_mean - novice_mean), or train SAEs (Gemma Scope 2) to find interpretable features. Linear probes are a good middle ground – more principled than mean difference, simpler than SAEs.
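For reference, the mean-difference variant is a one-liner. A sketch, assuming expert_acts and novice_acts are arrays of shape (n_prompts, hidden_dim) collected the same way as the probe's training data:

import numpy as np

mean_diff_vector = expert_acts.mean(axis=0) - novice_acts.mean(axis=0)
mean_diff_vector /= np.linalg.norm(mean_diff_vector)  # normalize, like the probe vector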
The Implementation
We need two hooks: one to capture activations (for training the probe), one to steer (for modifying behavior).
What Are Hooks?
PyTorch hooks are functions you attach to layers that get called automatically during the forward or backward pass. Think of them as wiretaps – you can listen in on what's happening inside the model without modifying its architecture.
When you call layer.register_forward_hook(fn), your function fn gets called every time that layer runs, receiving the layer's input and output tensors. You can inspect them (for capturing activations) or modify them (for steering). The hook returns a handle you use to remove it later.
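A toy example on a plain nn.Linear (not the LLM) shows the full attach → fire → remove cycle:

import torch
from torch import nn

layer = nn.Linear(4, 4)

def inspect(module, inputs, output):
    # called right after the layer's forward pass; look, don't touch
    print("output shape:", output.shape)

handle = layer.register_forward_hook(inspect)
layer(torch.randn(2, 4))  # the hook fires here
handle.remove()           # detach the wiretap when done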
More on hooks: PyTorch hooks tutorial.
Note: In practice, most people use libraries like nnsight or TransformerLens that handle this stuff for you. But for educational purposes (and because I like reinventing wheels), we'll use manual hooks here. And the code itself is pretty simple, so why not?
The Residual Stream
In transformers, information flows through the residual stream – a highway that passes through each layer, accumulating updates. Each layer reads from this stream, computes something, and adds its output back. We hook into this stream to either read or modify it.
Capturing Activations
PyTorch's register_forward_hook lets us intercept activations during the forward pass. The hook fires after each layer computes its output, giving us access to the hidden states.
We extract the last token's hidden state – that's where the model "decides" what token to generate next.
from contextlib import contextmanager
import torch
@contextmanager
def capture(layers, layer_indices: list[int], token_idx: int = -1):
    cache = {}    # stores {layer_idx: activation_tensor}
    handles = []  # hook handles for cleanup

    def make_hook(idx):
        def hook(module, input, output):
            # output is (hidden_states, ...) or just hidden_states
            hidden = output[0] if isinstance(output, tuple) else output
            # hidden shape: (batch, seq_len, hidden_dim)
            # extract the token at position token_idx (default: -1 = last)
            cache[idx] = hidden[:, token_idx, :].detach().cpu().squeeze(0)
        return hook

    # attach hooks to each target layer
    for idx in layer_indices:
        handles.append(layers[idx].register_forward_hook(make_hook(idx)))
    try:
        yield cache  # caller runs the model, cache gets populated
    finally:
        for h in handles:
            h.remove()  # always clean up hooks
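Putting capture() to work: a sketch of the collection loop that produces the activations and labels arrays the probe from earlier trains on (model loading via Hugging Face transformers; layer 13 is just the example layer used later in this post):

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
model.eval()

LAYER = 13
activations, labels = [], []

for label, prompts in [(0, NOVICE_PROMPTS), (1, EXPERT_PROMPTS)]:
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        with capture(model.model.layers, [LAYER]) as cache:
            with torch.no_grad():
                model(**inputs)  # forward pass; the hook fills the cache
        activations.append(cache[LAYER].float().numpy())
        labels.append(label)

activations = np.stack(activations)  # (n_prompts, hidden_dim)
labels = np.array(labels)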
Applying the Steering Vector

Same mechanism, but instead of reading, we modify. We add the scaled steering vector to the layer output. This shifts the model's internal representation toward (or away from) the target concept.
The vector broadcasts across all tokens in the sequence – we steer the entire generation, not just one position.
@contextmanager
def steer(layers, layer_idx: int, vector: torch.Tensor, alpha: float = 0.0):
    sv = vector.float()  # ensure float32

    def hook(module, input, output):
        # add the steering vector to the hidden states;
        # it broadcasts across batch and sequence dimensions
        if isinstance(output, tuple):
            return (output[0] + sv * alpha,) + output[1:]
        return output + sv * alpha

    handle = layers[layer_idx].register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()

Choosing the Right Layer
Not all layers are equal. Train probes at each layer and check where the concept is most readable:
Middle-to-late layers (~9+) achieve near-perfect separation.
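The sweep itself is just the probe in a loop. A sketch, assuming you've captured activations at every layer into a dict acts_by_layer (each entry shaped (n_prompts, hidden_dim)) with matching labels:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

for layer, acts in sorted(acts_by_layer.items()):
    scores = cross_val_score(LogisticRegression(max_iter=1000), acts, labels, cv=5)
    print(f"layer {layer:2d}: probe accuracy {scores.mean():.2f}")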
But high probe accuracy ≠ good steering. The best layer for reading a concept isn't always the best for modifying it. Recent work shows steering effectiveness varies significantly across models and tasks – there's no universal rule (Queiroz Da Silva et al., 2025).
A rough heuristic: early layers handle low-level patterns, middle layers encode semantic concepts, late layers refine for output. For Gemma-3-1b-it on expertise, middle layers (9-17) worked best in my experiments. Edge layers (1, 25) showed minimal effect – the steering barely changed the output.
I picked layer 13 for the demos below. Lucky number.
Choosing Alpha
The steering strength $\alpha$ needs calibration. Raw $\alpha$ values are meaningless across layers – layer 5 might have an activation norm of 50, layer 20 might have a norm of 200. Same $\alpha$, completely different effect.
Solution: Scale $\alpha$ relative to the activation norm:

$$\alpha = c \cdot \|h_\ell\|$$

where $c$ is a fraction of the norm. I found $c \approx 0.07$ works well for this task (enough to shift style without breaking coherence).
# compute mean activation norm at target layer
act_norm = compute_activation_norm(model, tokenizer, layer_idx, prompts, device)
alpha = act_norm * 0.07 # ~7% ± 2% works well
# generate with steering
with steer(model.model.layers, layer_idx, steering_vector, alpha=alpha):
    output = model.generate(input_ids, max_new_tokens=200)
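compute_activation_norm isn't defined in the snippet above; here's a minimal sketch of what it could look like, reusing capture() to average the last-token activation norm over a few prompts:

def compute_activation_norm(model, tokenizer, layer_idx, prompts, device):
    # mean L2 norm of the last-token hidden state at layer_idx across the prompts
    norms = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with capture(model.model.layers, [layer_idx]) as cache:
            with torch.no_grad():
                model(**inputs)
        norms.append(cache[layer_idx].float().norm().item())
    return sum(norms) / len(norms)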
Results

Below are responses to the same neutron-star question at different steering strengths, ordered from strongly novice-steered (first) to strongly expert-steered (last). The two extremes are dissected right after.
Yes, they're scary! A neutron star is like a giant cosmic vacuum cleaner.
They're super small and super fast! They're made of tiny bits of stuff called "stars" that are packed together so tightly they spin really fast and keep going round and round.
Because they're so big and fast, they're like planets that fall every day!
They're so heavy that they can sometimes bump into each other super fast!
They're the biggest things that ever hit the Earth!
Here's a simple way to think of it:
"Giant, spinning, speedy rocks!"
You can find lots of pictures and videos of them online!
A neutron star is a super-dense star that's formed when a massive star collapses. It's essentially a giant ball of neutrons packed incredibly tightly! Because so much stuff is squeezed in, they're incredibly bright and can sometimes be seen as tiny, rapidly spinning points of light.
Okay, here's a brief explanation of a neutron star:
A neutron star is the incredibly dense remnant of a massive star that has exploded. It's essentially a giant ball of neutrons – all the matter packed into a sphere the size of a city! Because of their immense gravity, they're incredibly small and incredibly dense, making them the densest objects in the universe.
Think of it like this: A star collapses and crushes down into a tiny, incredibly strong ball of neutrons.
A neutron star is the incredibly dense remnant of a massive star that has collapsed at the end of its life. It's essentially a giant ball of neutrons packed incredibly tightly.
Key features:
- Extremely dense: A teaspoonful would weigh billions of tons.
- Strong magnetic fields: They have incredibly powerful magnetic fields.
- Rapid rotation: Often rotate incredibly fast.
A neutron star is the incredibly dense remnant of a massive star that has undergone a supernova explosion. It's essentially a collapsed star composed almost entirely of neutrons, packed together incredibly tightly – a teaspoonful would weigh billions of tons!
Essentially, it's a stellar corpse with immense gravity and a bizarre, ultra-dense structure.
A neutron star is an incredibly dense remnant of a massive star (typically 10-20 times the mass of the Sun) that has undergone a supernova explosion.
Here's a breakdown:
- Extreme Density: It's composed almost entirely of neutrons, packed together under immense pressure – roughly equivalent to the density of an atomic nucleus.
- Support of Matter: The extreme pressure forces the star's core to collapse, resulting in a singularity – a point of infinite density.
- Observable Properties: Neutron stars exhibit strong magnetic fields, X-ray emission, and can be observed through gravitational lensing.
As a response to your prompt, I interpret your question as a request for a brief explanation of the connection between my response and your question.
A neutron star is a hypothetical object that would result from a supernova (a collapse of a massive star, primarily influenced by the neutron degeneracy pressure – a consequence of the quantum mechanical effects of the neutron matter)
Specifically, it's a collapsed star (approximately 2-3 times the mass of the Sun) where the core is composed primarily of neutron-rich matter, resulting in a staggering density – approximately 250 gigajoules per cubic meter.
My response is a direct interpretation of your prompt, and it is not a nuanced discussion of the theoretical implications of that interpretation.
The sweet spot for our setup is around 0.05–0.09 in magnitude (as a fraction of the activation norm). Push too far and coherence collapses:
- Novice collapse (-0.14): "giant cosmic vacuum cleaner", "planets that fall every day" (wtf) – sounds like a children's book written by someone who's never seen a star
- Expert collapse (+0.14): "250 gigajoules per cubic meter", calls neutron stars "hypothetical" – classic pseudo-expertise, would fit right in on Twitter/X
So what's actually happening here?
Why Steering Breaks
One interpretation: the steering vector captures vibes more than facts – vocabulary, sentence structure, tone, but not actual knowledge. At low $\alpha$, you get a nice style shift. Crank it up, and you amplify style while suppressing coherent reasoning.
To quantify this, I measured semantic preservation (does steering break the meaning?) across layers and steering strengths. (LLM-as-judge with majority voting; signal is noisy but trends are consistent.)
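The judging loop itself is nothing fancy. A sketch of a single majority-vote check; call_judge is a stand-in for whatever LLM API you use as the judge (it's not part of this post's code):

def judge_preserved(question: str, baseline: str, steered: str, n_votes: int = 5) -> bool:
    # ask the judge n_votes times whether the steered answer keeps the baseline's meaning
    prompt = (
        f"Question: {question}\n\n"
        f"Answer A: {baseline}\n\n"
        f"Answer B: {steered}\n\n"
        "Does Answer B preserve the core factual meaning of Answer A? Reply YES or NO."
    )
    votes = [call_judge(prompt).strip().upper().startswith("YES") for _ in range(n_votes)]
    return sum(votes) > n_votes // 2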
The pattern is clear: middle layers (9-13) break first, and higher $\alpha$ kills semantic coherence everywhere. Layer 9 at high $\alpha$ is basically dead – 0% preservation.
I call this pseudo-expertise: the text looks authoritative (formal structure, technical-sounding phrases) but is confidently wrong. The novice direction has its own failure mode – emoji soup and circular "simple" explanations that don't actually explain anything.
Think of it like adjusting the "expert" dial on a mixing board. Small adjustments shift the tone. Crank it to 11, and you get distortion.
Limitations
Some caveats:
- One model, one concept. I only tested Gemma-3-1b-it on expertise. Other models and concepts may behave differently.
- Small model. 1B parameters. Larger models might have cleaner representations.
- No exploration of nearby directions. The probe gives one vector; I didn't check what's in the neighborhood.
- Inference-time only. Steering doesn't permanently change the model – it's active only during the hooked forward pass.
More on steering reliability: Queiroz Da Silva et al., 2025.
Summary
- Collect activations from contrastive prompts (novice vs expert)
- Train a linear probe → weight vector = steering direction
- Choose a layer – middle layers often work best, but verify empirically
- Apply via forward hook: $h_\ell \leftarrow h_\ell + \alpha v$
- Scale $\alpha$ relative to the activation norm ($\alpha = c \cdot \|h_\ell\|$, with $c \approx 0.07$ here)
- Watch for collapse at high $|\alpha|$
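Glued together, those steps look roughly like this – a sketch reusing capture, steer, the prompt lists, and compute_activation_norm from above; layer 13 and the 0.07 scale are the values that happened to work here, not universal constants:

# 1. collect last-token activations for both personas (see the collection loop above)
#    -> activations, labels

# 2. probe -> steering direction
clf = LogisticRegression(max_iter=1000).fit(activations, labels)
v = torch.tensor(clf.coef_[0], dtype=torch.float32)
v /= v.norm()

# 3. scale alpha to the target layer's activation norm
alpha = compute_activation_norm(model, tokenizer, 13, NOVICE_PROMPTS, "cpu") * 0.07

# 4. steer a generation toward "expert" (negate alpha for "novice")
inputs = tokenizer("Explain what a neutron star is.", return_tensors="pt")  # any prompt works
with steer(model.model.layers, 13, v, alpha=alpha):
    out = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(out[0], skip_special_tokens=True))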
Notes
- The research behind this post was done as part of my application to MATS (ML Alignment Theory Scholars).
- This is essentially persona modeling – the model infers a user persona from the prompt and adjusts its response accordingly. We're just making that adjustment explicit and controllable.
- If the model starts claiming neutron stars are hypothetical, you've gone too far. Turn back.
References
- The Linear Representation Hypothesis and the Geometry of Large Language Models – Park et al., 2023. Formalizes the linear representation hypothesis for LLMs.
- The Geometry of Truth – Marks & Tegmark, 2023. Empirical evidence for linear truth representations in LLMs.
- Representation Engineering: A Top-Down Approach to AI Transparency – Zou et al., 2023. Foundational paper on steering via activation engineering.
- Activation Addition: Steering Language Models Without Optimization – Turner et al., 2023. The "add a vector" approach.
- Scaling Monosemanticity – Anthropic, 2024. Extracting interpretable features from Claude 3 Sonnet via SAEs.
- Steering off Course: Reliability Challenges in Steering Language Models – Queiroz Da Silva et al., 2025. On brittleness and layer selection.
- Open Problems in Mechanistic Interpretability – Sharkey, Nanda et al., 2025. Comprehensive overview of challenges and open questions in mech interp.
- Gemma Scope 2 – Google DeepMind, 2025. Open SAEs and transcoders for Gemma 3 interpretability.
- MATS Program – ML Alignment Theory Scholars. Research program for AI safety.