AI Engineering · Token Optimization
BERT Prompt Compressor
"Stop using a Ferrari to polish another Ferrari. Use a 2018 bicycle — and beat the Ferrari."
Most teams solving the "prompts are too long" problem reach for another LLM to rewrite them shorter. That feels intuitive. It's also paying full API rates to solve a cost problem — using a Ferrari to polish another Ferrari.
There's a better path: BERT attention scores, a few smart rules, and zero extra model calls. Here's the full story.
The Token Tax: Why This Actually Matters
Before we build anything, let's understand the cost model you're working against.
Every major LLM charges per token — both directions.
LLM Pricing — March 2026 (per 1M tokens)
| Model | Input | Output | 50% Compression Saves |
|---|---|---|---|
| Claude Sonnet 4.5 | $3.00 | $15.00 | $1.50/M input tokens |
| GPT-4o | $2.50 | $10.00 | $1.25/M input tokens |
| Gemini 2.5 Pro | $1.25 | $10.00 | $0.625/M input tokens |
| Claude Opus 4.6 | $5.00 | $25.00 | $2.50/M input tokens |
Now let's make it concrete. Say your app sends 1,000 requests/day, each with a 2,000-token prompt (a realistic system prompt + context):
Without compression (Claude Sonnet 4.5)
1,000 req × 2,000 tokens × $3/1M = $6.00/day ≈ $180/month
With 50% compression (same model)
1,000 req × 1,000 tokens × $3/1M = $3.00/day ≈ $90/month
Calculations use Claude Sonnet 4.5 input pricing. Scale linearly with request volume. Output tokens not included.
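The arithmetic above generalizes to any model and volume. A quick sketch (the prices and volumes are the illustrative figures from the table, not live rates):

```python
def monthly_savings(requests_per_day, prompt_tokens, input_price_per_m, keep_ratio=0.5):
    """Dollars saved per 30-day month by compressing input prompts."""
    daily_cost = requests_per_day * prompt_tokens / 1_000_000 * input_price_per_m
    return round(daily_cost * (1 - keep_ratio) * 30, 2)

# 1,000 req/day × 2,000 tokens at $3/1M input (Claude Sonnet 4.5), keep 50%:
# $6.00/day uncompressed, $3.00/day compressed, $90 saved per month
print(monthly_savings(1_000, 2_000, 3.00))
```

Swap in your own request volume and the input price from the table to size the savings for your stack.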
But cost is only half the problem.
The Hidden Tax: Quadratic Attention Complexity
Here's what most people miss about long prompts — the cost isn't just financial.
Inside every LLM, the attention mechanism computes a relationship between every pair of tokens. Every token looks at every other token.
Attention Operations Grow Quadratically with Prompt Length
Self-attention complexity is O(n²·d). Double the prompt → 4× the computation. That's slower responses, higher infrastructure costs, and more carbon emissions on top of the billing cost.
The math behind this: Each token computes Q·Kᵀ dot products against every other token. For a 400-token prompt, that's 400 × 400 = 160,000 pair comparisons per attention head. At 96 heads across layers in a large model, that's over 15 million operations — before a single output token is generated.
Half the tokens → quarter the attention work. That compounds across every request.
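You can verify the quadratic scaling directly. The head count is the article's illustrative figure for a large model:

```python
def attention_pairs(n_tokens: int, n_heads: int = 1) -> int:
    """Token-pair comparisons per forward pass: n² per attention head."""
    return n_tokens * n_tokens * n_heads

# 400-token prompt, single head: 400 × 400 pair comparisons
assert attention_pairs(400) == 160_000
# Across 96 heads: over 15 million operations before any output token
assert attention_pairs(400, n_heads=96) == 15_360_000
# Halving the prompt quarters the attention work
assert attention_pairs(200) * 4 == attention_pairs(400)
```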
The Compression Landscape: What Exists and Where We Fit
Prompt compression is a real, active research field. Let's be honest about the trade-offs.
BERT: The 2018 Model Still Pulling Its Weight in 2026
Before we build, let's understand the tool we're using.
BERT (Bidirectional Encoder Representations from Transformers) was published by Google in October 2018 (arXiv:1810.04805) and has since become one of the most cited papers in NLP history.
BERT-base Architecture
Why BERT for Compression
How BERT's Attention Mechanism Works (The Part We Actually Use)
Inside each of BERT's 12 layers, every token generates three vectors:
Query (Q) — "What am I looking for?" — each token broadcasts its search intent
Key (K) — "What do I offer?" — each token announces what it contains
Value (V) — "What do I actually contribute?" — the information passed forward
The attention score between token i and token j:
Attention(Q, K, V) = softmax(QKᵀ / √64) × V
Q·Kᵀ measures how well a query matches a key. Dividing by √64 (√d_k, the per-head dimension) keeps the dot products from growing so large that the softmax saturates and gradients vanish. Softmax turns the scores into probabilities. V is the weighted information sum.
This runs in parallel across all 12 heads × 12 layers = 144 distinct attention patterns. Each head learns to specialize: some track grammatical subject→verb links, some track co-reference, some capture positional relationships.
The key insight for compression: A token that receives high attention from many other tokens is one the model considers important context. Tokens that are largely ignored (low received attention) are candidates for removal.
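To make both ideas concrete, here is a minimal NumPy sketch of scaled dot-product attention and the "received attention" importance signal. The production code later in this article uses PyTorch; random Q/K/V matrices stand in for real BERT activations:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q·Kᵀ / √d_k) · V for a single head, no masking."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # [seq, seq]
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ V, weights

# Random stand-ins: 6 tokens, 64-dim per head (BERT's head size)
rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 6, 64))
out, attn = scaled_dot_product_attention(Q, K, V)

# The compression signal: attention RECEIVED per token
# (column mean of the attention matrix — how much everyone looks at token j)
received = attn.mean(axis=0)  # [6]
```

Tokens with low `received` scores are the ones other tokens ignore, and therefore the candidates for removal.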
Our Architecture: The 4-Step Pipeline
Here's the complete pipeline — from raw prompt to compressed output in a single BERT pass.
BERT Prompt Compression Pipeline
1. Preprocessing: Tokenize with the BERT tokenizer · Run spaCy NER to tag entities (Egypt, Elon Musk, FastAPI…)
2. Attention Scoring: One BERT forward pass · Layer-weighted average across 144 attention patterns · Find the Key Token (highest avg attention received) · Score every token by Q/K connection strength
3. Token Scoring: final_score = 0.6×attention + 0.4×key_connection · ×2.0 bonus for named entities · ×0.1 penalty for stopwords · Keep top-N% by score
4. Reconstruction: Re-join kept tokens in original order · Merge subwords back to whole words · Output a shorter prompt with identical word order
Step-by-Step Deep Dive
Step 1: Preprocessing
import spacy
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
nlp = spacy.load('en_core_web_sm')
prompt = "Do you happen to have details about what countries are located near Egypt?"
inputs = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True)
# → [CLS] do you happen to have details about what countries are located near egypt ? [SEP]
# Named Entity Recognition
doc = nlp(prompt)
entities = {token.text.lower() for ent in doc.ents for token in ent}
# → {'egypt'} ← this will get a 2x attention boost
Step 2: Attention Scoring — The Core
This is where the work happens. One forward pass through BERT gives us 144 attention matrices (12 heads × 12 layers). We use them to score every token.
import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

# One forward pass — no gradient computation needed
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions: tuple of 12 tensors, each [1, 12_heads, seq_len, seq_len]
attentions = torch.stack(outputs.attentions)  # [12_layers, 1, 12_heads, seq, seq]

# Weight later layers more heavily — they encode semantic meaning
# Layer 1 gets weight 0.5, Layer 12 gets weight 1.5 (linearly scaled)
layer_weights = torch.linspace(0.5, 1.5, 12)
layer_weights /= layer_weights.sum()  # normalize to sum to 1

# Weighted average across layers, then average across heads
# Result: [seq_len, seq_len] — attention FROM each token TO each token
weighted_attn = sum(
    layer_weights[i] * attentions[i, 0].mean(dim=0)  # avg across 12 heads
    for i in range(12)
)

# Token importance: average attention RECEIVED from all other tokens
# High score = other tokens pay a lot of attention to this one
token_importance = weighted_attn.mean(dim=0)  # [seq_len]
Why weight later layers more? Research on BERT's internal representations ("What Does BERT Look At?", Clark et al., 2019) shows that early layers capture surface-level syntax patterns, while later layers encode abstract semantic relationships. For compression, semantic importance is what we care about.
Why average received attention? A token that many other tokens "look at" is one the model needs to understand the text. A stopword like "the" that few tokens reference can often be dropped safely.
Step 3: Token Scoring Formula
# Find the "Key Token" — the single word with the highest received attention
# (excluding [CLS] and [SEP] delimiter tokens)
content_positions = [
    i for i, tid in enumerate(inputs['input_ids'][0].tolist())
    if tid not in (tokenizer.cls_token_id, tokenizer.sep_token_id)
]
key_idx = max(content_positions, key=lambda i: float(token_importance[i]))
key_token_attn = weighted_attn[key_idx]  # attention FROM the key token TO others

# Score formula for each word:
#   60%  — general importance (avg attention received from everyone)
#   40%  — how much the Key Token attends to this word
#   ×2.0 — if the word is a named entity (Egypt, FastAPI, etc.)
#   ×0.1 — if the word is a stopword (the, a, of, is...)
def score_word(attn_score, key_conn, word, entities, stopwords):
    score = 0.6 * attn_score + 0.4 * key_conn
    if word.lower() in entities:
        score *= 2.0  # entity boost — always keep important names
    elif word.lower() in stopwords:
        score *= 0.1  # stopword penalty — drop filler words
    return score
Why 60/40? General importance tells you if a word matters in the overall context. Key-token connection tells you if it's relevant to the central concept. The 60/40 split is empirical — you can tune this for your domain.
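A worked example with made-up attention numbers shows how the multipliers interact:

```python
# Hypothetical attention numbers for two words from the Egypt example
attn, key_conn = 0.05, 0.10
base = 0.6 * attn + 0.4 * key_conn   # blended importance ≈ 0.07

egypt_score = base * 2.0   # 'Egypt' is a named entity: boosted to ≈ 0.14
about_score = base * 0.1   # 'about' is a stopword: penalized to ≈ 0.007
```

Same raw attention, a 20× spread in final score: the entity survives any reasonable keep-ratio while the stopword is among the first words dropped.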
Step 4: Reconstruction
# Keep the top N% words by score
ratio = 0.5 # keep 50% of words
n_keep = max(3, int(len(word_scores) * ratio))
# Sort by score, take top N
top_indices = sorted(word_scores, key=lambda i: word_scores[i]['score'], reverse=True)[:n_keep]
# CRITICAL: restore original word order (not score order)
kept_indices = sorted(top_indices)
compressed = ' '.join(word_scores[i]['word'] for i in kept_indices)
Restoring original order is essential. Jumbling word order can confuse the model even with the right words.
The Complete, Working Code
Install dependencies first:
pip install transformers torch spacy
python -m spacy download en_core_web_sm
import torch
import spacy
from transformers import AutoTokenizer, AutoModel

class BERTPromptCompressor:
    """
    Compresses prompts using BERT attention scores.
    No extra LLM calls. Fully explainable. Runs on CPU in <50ms.
    """

    STOPWORDS = {
        'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
        'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
        'should', 'may', 'might', 'shall', 'to', 'of', 'in', 'for', 'on',
        'with', 'at', 'by', 'from', 'as', 'into', 'through', 'about', 'and',
        'but', 'or', 'that', 'this', 'these', 'those', 'it', 'its', 'not',
        'no', 'nor', 'only', 'than', 'too', 'very', 'just', 'even', 'also',
    }

    def __init__(self, model_name: str = 'bert-base-uncased'):
        # Auto classes let the same code load BERT, DistilBERT, or
        # multilingual BERT checkpoints
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, output_attentions=True)
        self.model.eval()
        try:
            self.nlp = spacy.load('en_core_web_sm')
        except OSError:
            print("spaCy model not found. Entity boosting disabled.")
            print("Run: python -m spacy download en_core_web_sm")
            self.nlp = None
        # Layer weights: later layers weighted higher (semantic > syntactic).
        # Sized from the config so 6-layer models like DistilBERT work too.
        n_layers = self.model.config.num_hidden_layers
        self._layer_weights = torch.linspace(0.5, 1.5, n_layers)
        self._layer_weights /= self._layer_weights.sum()

    def _get_entities(self, text: str) -> set:
        if self.nlp is None:
            return set()
        doc = self.nlp(text)
        return {w.lower() for ent in doc.ents for w in ent.text.split()}

    def _tokens_to_words(self, token_ids, tokens, importance, key_attn):
        """Merge BERT subword tokens back into whole words with scores."""
        words = {}
        word_idx = 0
        current_pieces, current_imp, current_key = [], [], []
        special = {self.tokenizer.cls_token_id, self.tokenizer.sep_token_id,
                   self.tokenizer.pad_token_id}

        def flush():
            # Finalize the current word with its averaged subword scores
            nonlocal word_idx
            if current_pieces:
                words[word_idx] = {
                    'word': ''.join(current_pieces).replace('##', ''),
                    'attn': sum(current_imp) / len(current_imp),
                    'key': sum(current_key) / len(current_key),
                }
                word_idx += 1
            current_pieces.clear()
            current_imp.clear()
            current_key.clear()

        for i, (tid, tok) in enumerate(zip(token_ids, tokens)):
            if tid in special:
                flush()
                continue
            if not tok.startswith('##'):
                flush()  # a new word starts here
            current_pieces.append(tok)
            current_imp.append(float(importance[i]))
            current_key.append(float(key_attn[i]))
        flush()
        return words

    def compress(self, prompt: str, ratio: float = 0.5) -> dict:
        """
        Compress a prompt.

        Args:
            prompt: Input text to compress.
            ratio: Fraction of tokens to keep. 0.5 = keep 50% (50% compression).
                Range: 0.2 (aggressive) to 0.8 (conservative).

        Returns:
            {
                'compressed': str,        # The compressed prompt
                'original_tokens': int,   # Token count before
                'compressed_tokens': int, # Token count after
                'savings_pct': float,     # e.g. 54.2
                'word_scores': dict       # Per-word scores for debugging
            }
        """
        words = prompt.split()
        if len(words) <= 4:
            return {'compressed': prompt, 'original_tokens': len(words),
                    'compressed_tokens': len(words), 'savings_pct': 0.0,
                    'word_scores': {}}

        entities = self._get_entities(prompt)
        inputs = self.tokenizer(
            prompt, return_tensors='pt', max_length=512, truncation=True
        )
        token_ids = inputs['input_ids'][0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Stack attention: [n_layers, 1, n_heads, seq, seq]
        attn_stack = torch.stack(outputs.attentions)

        # Weighted average across layers, then average across heads → [seq, seq]
        weighted = sum(
            self._layer_weights[i] * attn_stack[i, 0].mean(dim=0)
            for i in range(attn_stack.shape[0])
        )

        # Token importance: average attention received
        importance = weighted.mean(dim=0)  # [seq]

        # Key token: highest importance among content tokens
        special_ids = {self.tokenizer.cls_token_id, self.tokenizer.sep_token_id,
                       self.tokenizer.pad_token_id}
        content = [i for i, tid in enumerate(token_ids) if tid not in special_ids]
        key_idx = max(content, key=lambda i: float(importance[i]))
        key_attn = weighted[key_idx]

        # Build word-level scores
        word_scores = self._tokens_to_words(token_ids, tokens, importance, key_attn)

        # Final scoring with entity/stopword adjustments
        for idx, d in word_scores.items():
            score = 0.6 * d['attn'] + 0.4 * d['key']
            if d['word'].lower() in entities:
                score *= 2.0
            elif d['word'].lower() in self.STOPWORDS:
                score *= 0.1
            word_scores[idx]['score'] = score

        # Select top-ratio% by score, then restore original order
        n_keep = max(3, int(len(word_scores) * ratio))
        top = sorted(word_scores, key=lambda i: word_scores[i]['score'], reverse=True)
        kept = sorted(top[:n_keep])
        compressed = ' '.join(word_scores[i]['word'] for i in kept)

        orig_toks = len(content)
        comp_toks = max(1, len(self.tokenizer.encode(compressed)) - 2)
        return {
            'compressed': compressed,
            'original_tokens': orig_toks,
            'compressed_tokens': comp_toks,
            'savings_pct': round((1 - comp_toks / orig_toks) * 100, 1),
            'word_scores': word_scores,
        }
# ── Quick demo ────────────────────────────────────────────────────────────────
if __name__ == '__main__':
    compressor = BERTPromptCompressor()
    prompts = [
        "Do you happen to have details about what countries are located near Egypt?",
        "Please analyze the following customer reviews and summarize the top three "
        "complaints about our new smartphone battery life, including specific quotes "
        "if possible.",
        "I would really appreciate it if you could explain to me in simple terms "
        "how the attention mechanism in transformer models actually works, "
        "especially the part about queries, keys, and values.",
    ]
    for prompt in prompts:
        result = compressor.compress(prompt, ratio=0.5)
        print(f"\nOriginal   ({result['original_tokens']:>3} tokens): {prompt}")
        print(f"Compressed ({result['compressed_tokens']:>3} tokens): {result['compressed']}")
        print(f"Savings: {result['savings_pct']}%")
        print("─" * 70)
Real Before & After Examples
Before / After · ratio=0.5
Benchmark Results: Honest Numbers
How We Compare Against LLMLingua-2
The honest conclusion: LLMLingua-2 wins on raw compression quality and ratio ceiling. Our approach wins on zero infrastructure overhead, latency, and transparency. For production apps where simplicity and explainability matter more than maximum compression, BERT attention is the right tool.
When to Use Which Method
Use BERT attention (this approach) when:
✅ Latency is critical (<50ms CPU is your budget)
✅ You need to explain or audit which tokens were dropped
✅ Your prompts are simple instruction or Q&A style
✅ You want to start fast with no extra infrastructure
Use LLMLingua-2 when:
✅ You're compressing complex reasoning prompts or long CoT sequences
✅ Quality at high compression ratios is the priority
✅ You already have a model-serving infrastructure
Use a fine-tuned compression LLM (soft/generative methods) when:
✅ You can tolerate 15–30% quality loss
✅ You can invest in fine-tuning a custom compression LLM
✅ You're doing research into soft/generative compression
Use AttnComp-style self-attention when:
✅ You want adaptive compression (different ratios per context)
✅ You want to use the inference LLM's own attention without a second model
⚠️ Note: AttnComp is designed for RAG document filtering specifically, not general prompt compression
Production Integration
FastAPI Endpoint
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import time

# BERTPromptCompressor is the class defined earlier in this article
app = FastAPI(title="BERT Prompt Compressor API")
compressor = BERTPromptCompressor()  # Load once at startup

class CompressRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=10000)
    ratio: float = Field(0.5, ge=0.1, le=0.9, description="Keep this fraction of tokens")

class CompressResponse(BaseModel):
    compressed: str
    original_tokens: int
    compressed_tokens: int
    savings_pct: float
    latency_ms: float

@app.post("/compress", response_model=CompressResponse)
async def compress_prompt(req: CompressRequest):
    start = time.perf_counter()
    try:
        result = compressor.compress(req.prompt, ratio=req.ratio)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return CompressResponse(
        compressed=result['compressed'],
        original_tokens=result['original_tokens'],
        compressed_tokens=result['compressed_tokens'],
        savings_pct=result['savings_pct'],
        latency_ms=round((time.perf_counter() - start) * 1000, 1),
    )

# Health check
@app.get("/health")
async def health():
    return {"status": "ok", "model": "bert-base-uncased"}
Using as Middleware (LangChain style)
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

compressor = BERTPromptCompressor()
llm = ChatAnthropic(model="claude-sonnet-4-5")

def compressed_chat(prompt: str, ratio: float = 0.5) -> str:
    """Drop-in wrapper that compresses before sending."""
    result = compressor.compress(prompt, ratio=ratio)
    print(f"[Compressor] {result['original_tokens']} → {result['compressed_tokens']} tokens "
          f"({result['savings_pct']}% saved)")
    response = llm.invoke([HumanMessage(content=result['compressed'])])
    return response.content

# Use exactly like normal LLM calls
answer = compressed_chat(
    "Could you please explain in detail what a transformer neural network is "
    "and how the self-attention mechanism makes it so powerful compared to RNNs?",
    ratio=0.5,
)
Caching for Repeated Prompts
import hashlib

class CachedCompressor(BERTPromptCompressor):
    """Adds result caching for repeated identical prompts."""

    def __init__(self, *args, cache_size: int = 1000, **kwargs):
        super().__init__(*args, **kwargs)
        self._cache = {}
        self._cache_size = cache_size

    def compress(self, prompt: str, ratio: float = 0.5) -> dict:
        key = hashlib.md5(f"{prompt}:{ratio}".encode()).hexdigest()
        if key not in self._cache:
            if len(self._cache) >= self._cache_size:
                # Evict the oldest entry (dicts preserve insertion order)
                oldest = next(iter(self._cache))
                del self._cache[oldest]
            self._cache[key] = super().compress(prompt, ratio)
        return self._cache[key]
Using DistilBERT for 2x Faster Compression
For production workloads where every millisecond matters, swap BERT-base for DistilBERT. DistilBERT (Sanh et al., Hugging Face 2019) has:
- 40% fewer parameters (66M vs 110M)
- 60% faster inference on CPU
- Retains 97% of BERT-base performance on GLUE benchmarks
The swap is one line, with one caveat: DistilBERT has 6 layers, not BERT-base's 12, so the layer weighting must be sized from `model.config.num_hidden_layers` rather than hardcoded, and loading through the transformers Auto classes (AutoTokenizer/AutoModel) lets a single class name handle all of these checkpoints.
# Instead of:
compressor = BERTPromptCompressor('bert-base-uncased')
# Use:
compressor = BERTPromptCompressor('distilbert-base-uncased')
# For multilingual prompts (Arabic, Chinese, French, etc.):
compressor = BERTPromptCompressor('bert-base-multilingual-cased')
Limitations & Honest Assessment
Things This Approach Doesn't Do Well
English-only by default. The workaround is bert-base-multilingual-cased: slightly slower, but it covers 104 languages.
What We're Working on Next
- Auto-select the compression ratio based on task type (Q&A vs reasoning vs creative) using a lightweight classifier trained on prompt categories.
- Add a fast cosine similarity check between original and compressed prompt embeddings. If similarity drops below a threshold, fall back to a higher ratio automatically.
- When the prompt contains a question, cross-attend the question against the context to score tokens by their relevance to what's actually being asked, not just general importance.
- Package this as a drop-in library (pip install bert-compress) with CLI, FastAPI server, and LangChain middleware in a single package. Goal: anyone can add it to their pipeline in under 5 minutes.
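The cosine-similarity quality guard mentioned above can be sketched without any model. Given embeddings of the original and compressed prompts (for instance mean-pooled BERT vectors; the embedding step is assumed here), back off to a gentler ratio when similarity drops too low:

```python
import math

def cosine(a, b):
    """Plain cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def safe_ratio(orig_vec, comp_vec, ratio, threshold=0.85):
    """Back off to a gentler compression ratio if meaning drifted too far."""
    if cosine(orig_vec, comp_vec) < threshold:
        return min(ratio + 0.2, 0.9)  # keep more tokens and retry
    return ratio
```

The 0.85 threshold and 0.2 step are illustrative placeholders; the right values depend on your embedding model and domain.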
Try It On Your Prompts
The code above is complete and runnable. Paste your longest prompt into the demo at the end of the script, run it, and see the before/after with per-word scores printed. If you have a prompt that compresses poorly or surprisingly well — share it in the comments.
Long prompts don't have to slow you down or empty your wallet. A mile of words is optional. A few key tokens, chosen by BERT attention, really can do the job.
Related Reading
Tokens: Why Your Language Costs More Than English When You Use AI
How tokenization works, why Arabic prompts cost 50% more than English, and the 3 concrete techniques to reduce your token bill right now — no compression model required.