I deployed a loan-approval model. High accuracy. Fast predictions. Then a regulator asked: "Why did you reject this applicant?" I had no answer. The model was a black box. This article teaches three techniques that change that—letting you see what models focus on, discover hidden concepts they've learned, and control their behavior in real time.
Do You Actually Need Interpretability?
Before writing any code, be honest with yourself. Interpretability adds complexity—use it where it pays off.
- High-stakes decisions — loan approval, medical diagnosis, hiring: you need to explain why
- Bias detection — catch unintended discriminatory patterns before they cause harm
- Model alignment — understand if a model learned the right concepts, not shortcuts
- Debugging unexplained failures — find which features are causing wrong outputs
- Research — discover what representations models actually learn

You can usually skip it for:

- Low-stakes recommendations — Spotify playlists, TikTok feeds: nobody needs a justification
- Well-understood tasks — OCR, speech-to-text: model behavior is predictable and benchmarked
- Internal tooling — where failures don't affect users or compliance
- Rapid prototyping — interpretability is a production concern, not a Day 1 concern
Technique 1: Attention Maps
What Attention Actually Is
A transformer processes tokens by computing attention weights—for each position in the sequence, how much should it "attend to" every other position when producing output? These weights are learned, normalized (sum to 1), and computed across multiple "heads" per layer.
Input: "The cat sat on the mat"
↓
Tokenizer: ["The", "Ġcat", "Ġsat", "Ġon", "Ġthe", "Ġmat"]   (GPT-2 marks a leading space with "Ġ")
↓
Layer 0, Head 0 attention matrix (GPT-2 is causal, so each token
can only attend to itself and earlier positions):

        The    cat    sat    on     the    mat
The  [ 1.00    -      -      -      -      -    ]
cat  [ 0.35   0.65    -      -      -      -    ]
sat  [ 0.10   0.55   0.35    -      -      -    ]
on   [ 0.05   0.15   0.45   0.35    -      -    ]
the  [ 0.20   0.15   0.15   0.15   0.35    -    ]
mat  [ 0.10   0.25   0.15   0.10   0.15   0.25 ]

→ Read row i as: "When processing token i, how much weight did this
  head place on each visible token?" Each row sums to 1.
Different heads learn different roles: some copy the preceding token, some track named entities, some find subject-verb agreement. Visualizing many heads reveals the model's "reading strategy."
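Under the hood, those weights are just a row-wise softmax over scaled dot products. Here is a minimal NumPy sketch of that computation, using random Q/K matrices with illustrative shapes rather than real model weights:

```python
import numpy as np

def attention_weights(Q: np.ndarray, K: np.ndarray, causal: bool = True) -> np.ndarray:
    """softmax(Q @ K.T / sqrt(d_k)) with an optional causal mask."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # [seq, seq]
    if causal:
        # each query may only attend to itself and earlier positions
        mask = np.tril(np.ones_like(scores, dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q = rng.normal(size=(6, 8))  # 6 tokens, head dimension 8
K = rng.normal(size=(6, 8))
W = attention_weights(Q, K)
print(np.allclose(W.sum(axis=-1), 1.0))  # True: every row is a distribution
print(np.allclose(W[0, 1:], 0.0))        # True: token 0 can't see the future
```

Real transformers learn the projections that produce Q and K per head; the masking and normalization are exactly as above.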
Setup and Code
pip install transformers torch matplotlib circuitsvis
# attention_maps.py
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
# GPT-2 requires no authentication — perfect for learning
# For Llama 3.2: run `huggingface-cli login` first, then use
# "meta-llama/Llama-3.2-1B-Instruct"
MODEL_NAME = "gpt2"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()
def get_attentions(text: str) -> tuple:
"""
Run a forward pass and return:
- tokens: list of token strings
- attentions: tuple of (n_layers,) each shape [batch, heads, seq, seq]
"""
inputs = tokenizer(text, return_tensors="pt")
token_strs = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
return token_strs, outputs.attentions
def plot_attention_head(tokens: list, attentions: tuple, layer: int = 0, head: int = 0):
"""Plot a single attention head as a heatmap."""
attn = attentions[layer][0, head].cpu().numpy() # [seq, seq]
fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(attn, cmap="Blues", vmin=0, vmax=attn.max())
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
ax.set_yticklabels(tokens, fontsize=9)
ax.set_title(f"Layer {layer}, Head {head}\n(row = query, col = key)")
ax.set_xlabel("Key (attended-to token)")
ax.set_ylabel("Query (attending token)")
plt.colorbar(im, ax=ax, label="Attention weight")
plt.tight_layout()
plt.savefig(f"attn_L{layer}H{head}.png", dpi=150)
plt.show()
print(f"Saved: attn_L{layer}H{head}.png")
def find_head_roles(tokens: list, attentions: tuple) -> dict:
"""
Heuristically classify what each head in layer 0 seems to do.
Returns dict of head_index -> description.
"""
n_heads = attentions[0].shape[1]
roles = {}
for h in range(n_heads):
attn = attentions[0][0, h].cpu().numpy() # [seq, seq]
seq_len = len(tokens)
# Check: does this head primarily attend to the PREVIOUS token?
prev_token_weight = np.mean([attn[i, i-1] for i in range(1, seq_len)])
# Check: does this head primarily attend to the SAME token?
self_weight = np.mean(np.diag(attn))
# Check: does attention spread across many tokens (global)?
entropy = -np.sum(attn * np.log(attn + 1e-9), axis=-1).mean()
if prev_token_weight > 0.4:
roles[h] = f"Head {h}: previous-token copy (prev_weight={prev_token_weight:.2f})"
elif self_weight > 0.5:
roles[h] = f"Head {h}: self-attention (self_weight={self_weight:.2f})"
elif entropy > 2.0:
roles[h] = f"Head {h}: global/distributed (entropy={entropy:.2f})"
else:
roles[h] = f"Head {h}: mixed pattern"
return roles
if __name__ == "__main__":
text = "The cat sat on the mat and looked at the dog"
tokens, attentions = get_attentions(text)
print(f"Model: {MODEL_NAME}")
print(f"Tokens: {tokens}")
print(f"Layers: {len(attentions)}, Heads per layer: {attentions[0].shape[1]}")
# Plot layer 0, heads 0-3
for head in range(4):
plot_attention_head(tokens, attentions, layer=0, head=head)
# Classify head roles
print("\nHeuristic head roles in Layer 0:")
roles = find_head_roles(tokens, attentions)
for desc in roles.values():
print(f" {desc}")
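The entropy threshold in find_head_roles has a simple intuition: a head that spreads attention uniformly over n tokens has entropy ln(n), while a head locked onto a single token has entropy near zero. A quick standalone check:

```python
import numpy as np

def row_entropy(p: np.ndarray) -> float:
    """Shannon entropy of one attention row (natural log, as in find_head_roles)."""
    return float(-(p * np.log(p + 1e-9)).sum())

uniform = np.full(8, 1 / 8)   # attention spread evenly over 8 tokens
one_hot = np.eye(8)[0]        # attention locked onto a single token

print(round(row_entropy(uniform), 2))       # 2.08 (= ln 8)
print(abs(round(row_entropy(one_hot), 2)))  # 0.0
```

So the `entropy > 2.0` cutoff in the script roughly means "spread over more than seven or so tokens"; tune it to your sequence lengths.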
Interactive Visualization with CircuitsVis
For richer, interactive exploration (great in Jupyter notebooks):
from circuitsvis.attention import attention_heads
tokens, attentions = get_attentions("The quick brown fox jumps over the lazy dog")
# CircuitsVis renders an interactive HTML widget
# Each head is clickable; hover to see attention per token pair
attention_heads(
tokens=tokens,
attention=attentions[0][0].detach().numpy(), # Layer 0, all heads: [heads, seq, seq]
)
Patterns you'll commonly see across GPT-2's heads:

- Previous-token head: Strongly attends to the token just before — used for copying patterns
- BOS head: Always attends to the beginning-of-sequence token — provides global context
- Syntactic heads: Track subject-verb pairs, articles and nouns, prepositional phrases
- Semantic heads: Group semantically related tokens (e.g., "cat" ↔ "dog")
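BOS-style heads are easy to screen for numerically: measure how much mass every query row puts on position 0. A toy sketch, where bos_head_score is a hypothetical helper in the spirit of find_head_roles:

```python
import numpy as np

def bos_head_score(attn: np.ndarray) -> float:
    """Mean attention mass that all queries place on the first position."""
    return float(attn[:, 0].mean())

# toy head: every later token sends ~80% of its attention to position 0
toy = np.zeros((4, 4))
toy[:, 0] = 0.8
toy[np.arange(4), np.arange(4)] += 0.2
toy[0, 0] = 1.0  # the first token can only attend to itself
print(round(bos_head_score(toy), 2))  # 0.85
```

A score above ~0.5 is a reasonable (heuristic) flag for a BOS head; apply it per head per layer just like the other checks.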
Technique 2: Sparse Autoencoders (SAEs)
The Superposition Problem
Language models pack more concepts than they have neurons by using superposition—concepts are encoded as directions in activation space, not individual neurons. A single neuron fires for "cat," "doctor," "the color blue," and "negative financial news" depending on context.
Sparse autoencoders tackle this by learning a larger, sparser representation in which each direction ideally corresponds to a single interpretable concept.
Dense activations (512-dim):
[0.3, -0.7, 0.1, 0.9, ...] ← most neurons fire, meaning is entangled
SAE latent (2048-dim, sparse):
[0, 0, 0, 0, 0, 0, 1.8, 0, 0, 0, 0, 0, 0.5, 0, ...]
↑ ↑
"color concept" "comparison"
→ Only 2-5% of latent neurons fire. Each active neuron = one concept.
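You can see why superposition is possible at all with a toy experiment: random unit directions in even a modest-dimensional space are nearly orthogonal, so a model can pack far more concepts than it has dimensions, at the cost of small interference between them:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_concepts = 16, 50                   # 50 "concepts" in only 16 dimensions
dirs = rng.normal(size=(n_concepts, d))
dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)

overlaps = np.abs(dirs @ dirs.T)         # |cosine similarity| between concepts
np.fill_diagonal(overlaps, 0.0)          # ignore self-similarity
print(f"mean interference: {overlaps.mean():.2f}")  # small
print(f"max interference:  {overlaps.max():.2f}")   # well below 1.0
```

The SAE's expanded latent space gives each of those crowded directions its own dedicated, mostly-zero coordinate.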
Training an SAE on GPT-2 Activations
# sae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
# ── Step 1: Extract activations from a specific layer ──────────
def extract_activations(
model_name: str = "gpt2",
layer_idx: int = 6, # Middle layer of GPT-2 (12 layers)
n_samples: int = 5000,
batch_size: int = 8,
) -> torch.Tensor:
"""
Extract residual stream activations from a given layer.
Returns tensor of shape [n_tokens, hidden_dim].
"""
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model.eval()
# Load a text corpus (using Wikipedia for diverse vocabulary)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [t for t in dataset["text"] if len(t.strip()) > 50][:n_samples]
all_activations = []
hook_output = []
def hook_fn(module, input, output):
# output is (hidden_states, ...) — we want hidden_states
hidden = output[0] if isinstance(output, tuple) else output
hook_output.append(hidden.detach().cpu())
# Register hook on the target layer's output
hook = model.transformer.h[layer_idx].register_forward_hook(hook_fn)
for i in range(0, min(len(texts), n_samples), batch_size):
batch_texts = texts[i : i + batch_size]
inputs = tokenizer(
batch_texts, return_tensors="pt", truncation=True,
max_length=128, padding=True
)
        with torch.no_grad():
            model(**inputs)
        # Flatten [batch, seq, hidden] → [n_real_tokens, hidden],
        # dropping padding positions so they don't pollute the dataset
        mask = inputs.attention_mask.reshape(-1).bool()
        for act in hook_output:
            flat = act.reshape(-1, act.shape[-1])
            all_activations.append(flat[mask])
        hook_output.clear()
if i % 100 == 0:
print(f" Extracted {i}/{min(len(texts), n_samples)} samples")
hook.remove()
activations = torch.cat(all_activations, dim=0)
print(f"Activation shape: {activations.shape}")
return activations
# ── Step 2: Define SAE architecture ────────────────────────────
class SparseAutoencoder(nn.Module):
"""
Sparse Autoencoder for mechanistic interpretability.
Key design choices:
    - Encoder: single linear layer + ReLU (this minimal variant omits the
      encoder bias; many SAE implementations add a pre-encoder bias term)
- Decoder: linear layer with L2-normalized columns
- Loss: MSE reconstruction + L1 sparsity penalty
"""
def __init__(self, input_dim: int, latent_dim: int):
super().__init__()
self.input_dim = input_dim
self.latent_dim = latent_dim
self.encoder = nn.Linear(input_dim, latent_dim, bias=False)
self.decoder = nn.Linear(latent_dim, input_dim, bias=True)
# Initialize decoder columns to be unit vectors
with torch.no_grad():
self.decoder.weight.data = F.normalize(
self.decoder.weight.data, dim=0
)
def encode(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(self.encoder(x))
def decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
def forward(self, x: torch.Tensor) -> tuple:
z = self.encode(x)
x_hat = self.decode(z)
return x_hat, z
def normalize_decoder(self):
"""Renormalize decoder columns after each optimizer step."""
with torch.no_grad():
self.decoder.weight.data = F.normalize(
self.decoder.weight.data, dim=0
)
def sae_loss(
x: torch.Tensor,
x_hat: torch.Tensor,
z: torch.Tensor,
sparsity_coeff: float = 1e-3,
) -> tuple:
recon_loss = F.mse_loss(x_hat, x)
sparsity_loss = z.abs().mean()
total_loss = recon_loss + sparsity_coeff * sparsity_loss
return total_loss, recon_loss, sparsity_loss
# ── Step 3: Train SAE ───────────────────────────────────────────
def train_sae(
activations: torch.Tensor,
input_dim: int,
expansion_factor: int = 4, # latent_dim = 4 × input_dim
n_epochs: int = 10,
batch_size: int = 512,
lr: float = 1e-3,
sparsity_coeff: float = 1e-3,
) -> SparseAutoencoder:
latent_dim = input_dim * expansion_factor
sae = SparseAutoencoder(input_dim, latent_dim)
optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
# Normalize activations (important for stable training)
mean = activations.mean(dim=0, keepdim=True)
std = activations.std(dim=0, keepdim=True) + 1e-8
activations_norm = (activations - mean) / std
dataset = TensorDataset(activations_norm)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(n_epochs):
total_loss_sum = recon_sum = sparse_sum = 0
for (batch,) in loader:
optimizer.zero_grad()
x_hat, z = sae(batch)
loss, recon, sparse = sae_loss(batch, x_hat, z, sparsity_coeff)
loss.backward()
optimizer.step()
sae.normalize_decoder()
total_loss_sum += loss.item()
recon_sum += recon.item()
sparse_sum += sparse.item()
        sparsity_rate = (z.abs() > 0.01).float().mean().item()  # measured on the last batch only
print(
f"Epoch {epoch+1}/{n_epochs} | "
f"Loss: {total_loss_sum/len(loader):.4f} | "
f"Recon: {recon_sum/len(loader):.4f} | "
f"Sparsity rate: {sparsity_rate:.2%}"
)
return sae
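To build intuition for the sparsity_coeff trade-off, note what the L1 term actually measures: mean absolute activation. Among codes that reconstruct equally well, the optimizer prefers the one with less total mass, which in practice concentrates activity into a few latents. A standalone toy comparison:

```python
import torch

# two latent codes with the same shape [1, 32]
dense_z = torch.full((1, 32), 0.25)   # 32 small activations spread everywhere
sparse_z = torch.zeros(1, 32)
sparse_z[0, :2] = 1.0                 # only 2 active latents

print(dense_z.abs().mean().item())   # 0.25
print(sparse_z.abs().mean().item())  # 0.0625
```

Raising sparsity_coeff shifts the balance toward codes like the second one; push it too far and reconstruction quality collapses, so watch both terms during training.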
Discovering What Each Feature Represents
Once trained, you can find the most activating text segments for any latent feature:
def discover_features(
sae: SparseAutoencoder,
activations: torch.Tensor,
tokens_per_sample: list, # list of token strings for each activation row
top_k_features: int = 10,
top_k_samples: int = 5,
):
"""
For each of the top-K most-used features, print the text segments
that most strongly activate it.
"""
sae.eval()
with torch.no_grad():
z = sae.encode(activations)
    # Rank features by mean activation strength across all rows
mean_activations = z.mean(dim=0)
top_feature_ids = mean_activations.argsort(descending=True)[:top_k_features]
print(f"\nTop {top_k_features} features and their activating tokens:\n")
for feat_id in top_feature_ids:
feat_id = feat_id.item()
# Get samples where this feature is strongest
feature_acts = z[:, feat_id]
top_sample_ids = feature_acts.argsort(descending=True)[:top_k_samples]
print(f"Feature {feat_id} (mean activation: {mean_activations[feat_id]:.3f}):")
for sample_id in top_sample_ids:
act_val = feature_acts[sample_id].item()
token = tokens_per_sample[sample_id] if sample_id < len(tokens_per_sample) else "?"
print(f" [{act_val:.2f}] '{token}'")
print()
Features that typically surface from this kind of analysis include:

- Features corresponding to specific people (Elon Musk, Abraham Lincoln)
- Features for abstract concepts (power, justice, freedom)
- Features for syntactic roles (subject of a sentence, verb phrase)
- Multi-token "concept features" that span several words
- Safety-relevant features (deception, harm, political bias)
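One diagnostic worth running before you trust any feature labels: check for dead features. Latents that never activate on any input are wasted capacity and a common symptom of too aggressive a sparsity_coeff. A standalone sketch, using synthetic latents as a stand-in for sae.encode(activations):

```python
import torch

torch.manual_seed(0)
# stand-in for z = sae.encode(activations): the shift makes firing rare,
# mimicking an over-penalized SAE latent matrix
z = torch.relu(torch.randn(5000, 2048) - 4.0)

fires = (z > 0).any(dim=0)        # did each latent ever activate?
n_dead = int((~fires).sum())
print(f"dead features: {n_dead} / {z.shape[1]}")
```

If a large fraction of latents are dead, lower sparsity_coeff, train longer, or resample dead features by reinitializing their encoder rows.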
Technique 3: Steering Vectors
The Concept
Representation engineering shows that many high-level concepts (positive sentiment, formality, honesty) are encoded as directions in the model's activation space. If you can find that direction, you can add a vector along it to push the model's outputs in that direction—without touching the weights.
Activation space (simplified to 2D):

     informal ←──────────────────────────→ formal
  "Hey! So pumped                 "I am delighted
   for this!"                      to be here."

  The DIRECTION from informal/negative → formal/positive
  is the "positive formality" steering vector.

  Add (vector × multiplier) to the activations:
    multiplier =  0.0 → original response
    multiplier =  1.0 → noticeably more formal/positive
    multiplier =  2.0 → exaggeratedly formal/positive
    multiplier = -1.0 → steered toward informal/negative
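Under the hood, contrastive steering-vector training reduces to something simple: average the activation difference between positive and negative examples at a chosen layer. A synthetic NumPy sketch, where a planted "concept" direction stands in for real model activations:

```python
import numpy as np

rng = np.random.default_rng(0)
concept = rng.normal(size=64)
concept /= np.linalg.norm(concept)        # the "true" trait direction

pos_acts = rng.normal(size=(50, 64)) + 2.0 * concept  # trait present
neg_acts = rng.normal(size=(50, 64)) - 2.0 * concept  # trait absent

# mean difference recovers the planted direction despite per-example noise
steering = (pos_acts - neg_acts).mean(axis=0)
cos = steering @ concept / np.linalg.norm(steering)
print(f"cosine(steering, concept) = {cos:.2f}")  # high: direction recovered
```

This is why pair count matters: the per-example noise averages out, so 50–200 pairs recover a much cleaner direction than a handful.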
Implementation with the steering-vectors Library
The steering-vectors library (by David Chanin) provides a clean API for this:
pip install steering-vectors
# steering.py
from steering_vectors import train_steering_vector, SteeringVector
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
MODEL_NAME = "gpt2"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model.eval()
# ── Step 1: Prepare contrastive training pairs ─────────────────
#
# Each pair is (positive_example, negative_example).
# Positive = the trait you want to steer TOWARD.
# Negative = the opposite trait.
# More pairs → more robust vector (50-200 is ideal).
SENTIMENT_PAIRS = [
("I absolutely loved the conference. The speakers were inspiring and insightful.",
"I absolutely hated the conference. The speakers were boring and pointless."),
("This is one of the best books I've ever read. Couldn't put it down.",
"This is one of the worst books I've ever read. Couldn't get through it."),
("The team did outstanding work this quarter. I'm incredibly proud.",
"The team did terrible work this quarter. I'm incredibly disappointed."),
("After the dinner yesterday, I felt satisfied and energized.",
"After the dinner yesterday, I felt sick and drained."),
("I'm genuinely excited about this project's potential.",
"I'm genuinely worried about this project's future."),
("The new feature works flawlessly. Users will love it.",
"The new feature is broken. Users will hate it."),
("My week was productive and fulfilling.",
"My week was exhausting and pointless."),
("I'm optimistic about the results we'll see next month.",
"I'm pessimistic about the results we'll see next month."),
]
# ── Step 2: Train the steering vector ─────────────────────────
print("Training sentiment steering vector...")
steering_vector: SteeringVector = train_steering_vector(
model=model,
tokenizer=tokenizer,
training_samples=SENTIMENT_PAIRS,
layers=list(range(6, 10)), # Extract from middle layers (6-9 of 12)
show_progress=True,
)
print(f"Steering vector trained on {len(SENTIMENT_PAIRS)} pairs")
# ── Step 3: Test with different multipliers ────────────────────
def generate_with_steering(
prompt: str,
multiplier: float = 0.0,
max_new_tokens: int = 50,
) -> str:
"""Generate text with optional steering vector applied."""
inputs = tokenizer(prompt, return_tensors="pt")
if multiplier == 0.0:
# Baseline: no steering
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
else:
# Apply steering vector during generation
with steering_vector.apply(model, multiplier=multiplier):
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
    # Return only the newly generated tokens. Slicing by token ids is more
    # robust than string slicing, since decode() may not reproduce the
    # prompt string exactly
    new_tokens = output[0, inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
if __name__ == "__main__":
prompt = "After the dinner yesterday, I felt"
for multiplier in [-1.5, -1.0, 0.0, 1.0, 1.5, 2.0]:
response = generate_with_steering(prompt, multiplier=multiplier)
print(f"multiplier={multiplier:+.1f} | ...{response[:80]}")
Example output (greedy decoding; exact completions vary with model and training pairs):
multiplier=-1.5 | ...terrible. My stomach was in knots. I couldn't sleep at all.
multiplier=-1.0 | ...uneasy. Something about the evening just didn't sit right.
multiplier= 0.0 | ...tired. It had been a long day and I needed rest.
multiplier=+1.0 | ...satisfied and content. The food was excellent.
multiplier=+1.5 | ...wonderful and grateful. What an extraordinary evening!
multiplier=+2.0 | ...absolutely ecstatic! The most incredible meal of my life!
Building Your Own Steering Vectors
The same approach works for other traits—just swap the contrastive pairs. For example (the positive halves here are illustrative counterparts):

- Formality — Positive: "I would be pleased to assist you with this matter." / Negative: "Sure thing, happy to help out!"
- Calibrated honesty — Positive: "I'm not certain, but the evidence suggests..." / Negative: "I know for a fact that..." [incorrect claim]
- Refusal — Positive: "I can't help with that request." / Negative: "Here are the step-by-step instructions..."
- Technical precision — Positive: "Attention computes weighted sums over token representations..." / Negative: "AI looks at words and figures out what's important..."
Practical Workflow: Combining All Three
Here's how the three techniques fit together in a real debugging scenario:
Scenario: Your loan model is rejecting minority applicants at higher rates.
Step 1 — ATTENTION MAPS
→ Visualize what the model attends to for rejected vs. approved cases
→ Finding: Model heavily attends to zip code and name fields
→ Signal: Potential proxy discrimination via location/name
Step 2 — SPARSE AUTOENCODERS
→ Train SAE on loan model's layer activations
→ Discover features
→ Finding: Feature #847 strongly activates on neighborhood names
associated with minority communities
→ Signal: Model learned a "demographic proxy" feature
Step 3 — STEERING VECTORS
→ Train a "demographic-neutral" steering vector
→ Contrastive pairs: applications where only demographic proxies differ
→ Apply vector to push model away from demographic-influenced activations
→ Validate: rejection rates equalize across demographic groups
This observe → diagnose → intervene loop mirrors the workflow used in mechanistic interpretability research at labs like Anthropic.
Limitations: What Interpretability Can and Can't Tell You
- Attention ≠ explanation. High attention weight between two tokens doesn't mean the model "used" that relationship to produce its output—correlation, not causation. (See: Jain & Wallace, 2019, "Attention is not Explanation.")
- SAE features are statistical, not semantic. A feature labeled "colors" might actually encode something more subtle. Human interpretation of feature labels is inherently approximate.
- Steering vectors can be unstable at high multipliers. Beyond ±2.0 you often get degenerate outputs—repetition, incoherence. Always validate on a held-out set.
- These techniques don't fully generalize. SAEs trained on GPT-2 won't transfer to Llama without retraining. Even same-architecture models at different scales have different feature geometries.
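A cheap way to operationalize that degeneracy check: flag steered generations whose n-grams collapse into repetition before you ever read them. A minimal sketch (repeated_trigram_fraction is an illustrative helper, not part of any library here):

```python
def repeated_trigram_fraction(text: str) -> float:
    """Fraction of trigrams that are duplicates; near 1.0 means degenerate."""
    words = text.split()
    trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
    if not trigrams:
        return 0.0
    return 1.0 - len(set(trigrams)) / len(trigrams)

ok = "The dinner was lovely and we stayed late talking about the trip."
broken = "great great great great great great great great great great"
print(repeated_trigram_fraction(ok))      # 0.0
print(repeated_trigram_fraction(broken))  # 0.875
```

Run it over a batch of steered generations at each multiplier; a sudden jump in the average is a good signal you've pushed the vector too hard.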
Key Takeaways

- Attention maps: pass output_attentions=True in a forward pass, not in generate(). Use circuitsvis for interactive exploration in notebooks. GPT-2 works with no authentication.
- Sparse autoencoders: extract middle-layer activations, train with an expansion factor of ~4 and an L1 sparsity penalty, then label features by their top-activating tokens.
- Steering vectors: use the steering-vectors library. Collect 50–200 contrastive pairs, apply to middle layers (40–70% of model depth), and keep multipliers in the ±1.5 range to avoid instability.