在大语言模型(Large Language Model, LLM)推理过程中,KV Cache(Key-Value Cache)是提升生成效率的核心机制。其本质是在自回归解码阶段缓存已计算的注意力键(Key)和值(Value)张量,避免对历史 token 重复执行前向传播中的注意力计算。然而,随着上下文长度增长,KV Cache 的内存占用呈线性增长,成为推理系统的主要瓶颈。
KV Cache 的基本结构与成本构成
在 Transformer 解码器中,每个注意力头对输入序列 $X \in \mathbb{R}^{n \times d}$ 执行如下操作:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
其中 $Q = XW_Q$, $K = XW_K$, $V = XW_V$。在自回归生成中,第 $t$ 步仅生成一个 token,但需对全部 $1..t$ 个 token 计算注意力。若不缓存,每步需重新计算所有历史 token 的 $K$ 和 $V$,时间复杂度为 $O(t^2)$。引入 KV Cache 后,仅计算新 token 的 $K_t, V_t$,并拼接到历史缓存中,使单步计算复杂度降至 $O(t)$,总生成复杂度从 $O(n^2)$ 降至 $O(n)$。
但代价是内存开销。设模型有 $L$ 层,每层 $H$ 个注意力头,隐藏维度 $d$,则单个 token 的 KV Cache 大小为:
$$ \text{Size per token} = 2 \times L \times H \times d \times \text{dtype_bytes} $$
以 LLaMA-2-7B 为例:$L=32$, $H=32$, $d=128$(因 head_dim = 128),使用 float16(2 字节),则单 token 缓存为:
$$ 2 \times 32 \times 32 \times 128 \times 2 = 524,288 \text{ bytes} \approx 512 \text{ KB} $$
生成 4096 token 上下文时,KV Cache 占用约 2 GB 显存,超过模型参数本身(7B × 2B ≈ 14 GB 中的 1/7)。若上下文扩展至 32k,缓存将达 16 GB,显著限制批处理能力与并发吞吐。
优化目标与技术分类
KV Cache 优化的核心目标是在 内存占用、计算开销、生成质量 三者间取得平衡。现有技术可归为三类:
- 压缩(Compression):降低每个 token 的缓存精度或维度
- 剪枝(Pruning):移除部分历史 token 的缓存
- 重计算(Recomputation):按需重建部分缓存以节省内存
下文逐类展开,分析其原理、实现细节与量化效果。
压缩技术:量化与低秩近似
动态量化(Dynamic Quantization)
最直接的方法是将 float16 KV Cache 转换为 int8。由于注意力机制对数值精度容忍度较高,实验表明 int8 量化在多数任务中 BLEU 或准确率下降 < 0.5%。
在 PyTorch 中,可通过自定义 attention 实现:
import torch
import torch.nn.functional as F
def quantize_cache(k: torch.Tensor, v: torch.Tensor):
# k, v: [seq_len, num_heads, head_dim]
k_min, k_max = k.min(), k.max()
v_min, v_max = v.min(), v.max()
k_scale = (k_max - k_min) / 255.0
v_scale = (v_max - v_min) / 255.0
k_quant = ((k - k_min) / k_scale).round().clamp(0, 255).to(torch.uint8)
v_quant = ((v - v_min) / v_scale).round().clamp(0, 255).to(torch.uint8)
return k_quant, v_quant, k_min, k_scale, v_min, v_scale
def dequantize_cache(k_quant, v_quant, k_min, k_scale, v_min, v_scale):
k = k_quant.to(torch.float16) * k_scale + k_min
v = v_quant.to(torch.float16) * v_scale + v_min
return k, v
在推理循环中,缓存以 int8 存储,计算注意力前反量化。实测 LLaMA-2-7B 在 GSM8K 数学推理任务上,int8 KV Cache 使显存减少 50%,P99 延迟增加 8%(因反量化开销),准确率从 52.1% 降至 51.7%。
低秩近似(Low-Rank Approximation)
观察发现,KV 张量在 token 维度存在冗余。例如,相邻 token 的 Key 向量高度相关。可对历史 KV 执行 SVD 或 PCA 降维。
设历史 Key 矩阵 $K_{hist} \in \mathbb{R}^{t \times d_k}$,对其进行秩-$r$ 近似:
$$ K_{hist} \approx U_r S_r V_r^T,\quad U_r \in \mathbb{R}^{t \times r} $$
仅缓存 $U_r$ 和 $S_r V_r^T$,将存储从 $O(t d_k)$ 降至 $O(t r + r d_k)$。当 $r \ll d_k$ 时显著节省内存。
Hugging Face Transformers 的 SinkCache 采用类似思想,但更工程化:保留最近 $w$ 个 token 的完整 KV,其余压缩为固定数量的“汇点”(sinks)。例如,设置 window_length=1024, num_sinks=4,则缓存大小上限为 1028 个 token,无论上下文多长。
from transformers import SinkCache
cache = SinkCache(window_length=1024, num_sinks=4)
# 在 generate() 中传入 past_key_values=cache
在 LongBench 评测中,LLaMA-2-7B 使用 SinkCache(1024+4)处理 8k 上下文时,显存从 4.1 GB 降至 0.53 GB,问答准确率仅下降 1.2%(从 48.3% → 47.1%)。
剪枝技术:基于重要性或语义的缓存淘汰
剪枝的核心假设是:并非所有历史 token 对当前生成同等重要。可依据注意力分数、token 语义角色或位置信息决定保留哪些 KV。
基于注意力分数的动态剪枝
在每步生成后,记录当前 token 对历史 token 的注意力权重 $\alpha_t = [\alpha_{t,1}, ..., \alpha_{t,t}]$。累积重要性得分:
$$ I_j = \sum_{t=j}^{T} \alpha_{t,j} $$
定期淘汰 $I_j$ 最低的 token。但该方法需维护全历史注意力图,开销大。
更实用的是 滑动窗口注意力(Sliding Window Attention),如 Mistral-7B 所采用:每层仅关注最近 4096 个 token。实现时,KV Cache 以环形缓冲区管理:
class SlidingWindowCache:
def __init__(self, window_size: int, num_layers: int, num_heads: int, head_dim: int):
self.window_size = window_size
self.k_cache = torch.zeros((num_layers, window_size, num_heads, head_dim), dtype=torch.float16)
self.v_cache = torch.zeros_like(self.k_cache)
self.start_pos = 0 # 环形缓冲区起始索引
self.seq_len = 0
def update(self, layer_idx: int, k_new: torch.Tensor, v_new: torch.Tensor):
batch_size, _, num_heads, head_dim = k_new.shape
assert batch_size == 1 # 简化
pos = self.seq_len % self.window_size
self.k_cache[layer_idx, pos] = k_new.squeeze(0)
self.v_cache[layer_idx, pos] = v_new.squeeze(0)
if self.seq_len >= self.window_size:
self.start_pos = (self.start_pos + 1) % self.window_size
self.seq_len += 1
def get(self, layer_idx: int):
if self.seq_len <= self.window_size:
return self.k_cache[layer_idx, :self.seq_len], self.v_cache[layer_idx, :self.seq_len]
else:
# 拼接环形缓冲区的两段
end = self.start_pos
start = (self.start_pos) % self.window_size
k1 = self.k_cache[layer_idx, start:]
k2 = self.k_cache[layer_idx, :end]
return torch.cat([k1, k2], dim=0), torch.cat([self.v_cache[layer_idx, start:], self.v_cache[layer_idx, :end]], dim=0)
Mistral-7B 在 32k 上下文任务中,滑动窗口使 KV Cache 恒定为 4096 token,显存占用稳定在 2.1 GB,而标准 LLaMA-2 需 16 GB。在 narrativeQA 数据集上,长文档问答 F1 仅比 full attention 低 2.4%。
语义感知剪枝
更高级的方法利用 token 的语义角色。例如,保留名词、动词等实词,丢弃停用词。需额外 NLP 工具链,工程复杂度高,且可能误删关键上下文(如否定词“not”)。
实践中,分层剪枝(Layer-wise Pruning) 更可行:浅层关注局部语法,可大幅剪枝;深层捕获全局语义,需保留更多历史。例如,第 1–8 层窗口=512,第 9–24 层窗口=2048,第 25–32 层 full attention。实测在保持质量的同时减少 35% 缓存。
重计算技术:时间换空间
当内存极度受限时,可选择性丢弃部分 KV 并在需要时重计算。典型代表是 StreamingLLM 提出的“attention sink”机制。
其核心发现:Transformer 注意力对开头若干 token(如前 4 个)存在强依赖,称为“sink tokens”。只要保留这些 sink,即使中间 token 被丢弃,模型仍能稳定生成。
实现策略:
- 缓存前 4 个 token 的 KV(不可丢弃)
- 使用 LRU 策略淘汰中间 token
- 当需访问已淘汰 token 时,从原始输入重新编码(需保留 input_ids)
class StreamingLLMCache:
def __init__(self, sink_size=4, cache_size=2048):
self.sink_size = sink_size
self.cache_size = cache_size
self.kv_cache = {} # {layer: {pos: (k, v)}}
self.input_ids = None
self.model = None # 需绑定模型用于重计算
def set_input(self, input_ids, model):
self.input_ids = input_ids
self.model = model
def get_kv(self, layer_idx, pos):
if pos < self.sink_size or pos in self.kv_cache.get(layer_idx, {}):
return self._get_cached_kv(layer_idx, pos)
else:
# 重计算:从 input_ids[pos:pos+1] 重新前向传播到指定层
return self._recompute_kv(layer_idx, pos)
def _recompute_kv(self, layer_idx, pos):
# 简化:实际需逐层前向至 layer_idx
token_id = self.input_ids[:, pos:pos+1]
with torch.no_grad():
hidden = self.model.embed_tokens(token_id)
for i in range(layer_idx + 1):
hidden = self.model.layers[i](hidden)[0]
k, v = self.model.layers[layer_idx].self_attn.compute_kv(hidden)
return k, v
StreamingLLM 在 256k 上下文测试中,仅用 4.8 GB 显存(vs 标准 128 GB),在 PG-19 语言建模任务上困惑度仅上升 5.2%。但重计算带来显著延迟:生成速度下降 3–5 倍,适用于离线批处理而非实时交互。
混合优化策略与工程权衡
单一技术难以满足所有场景。工业系统通常组合多种策略:
| 场景 | 推荐策略 | 显存节省 | 质量损失 | 延迟影响 |
|---|---|---|---|---|
| 实时聊天(<4k) | int8 量化 | 50% | <0.5% | +8% |
| 长文档摘要(8k–32k) | SinkCache + 滑动窗口 | 85% | 1–2% | +15% |
| 超长上下文分析(>64k) | StreamingLLM + 重计算 | >95% | 3–6% | +300% |
在 vLLM 推理引擎中,通过 PagedAttention 技术进一步优化内存管理。其将 KV Cache 分页(类似虚拟内存),允许非连续显存分配,减少碎片。实测在 A10G 上,batch_size=16 时吞吐提升 2.1 倍。
# vLLM 启动命令示例
python -m vllm.entrypoints.api_server \
--model meta-llama/Llama-2-7b-chat-hf \
--tensor-parallel-size 1 \
--max-model-len 32768 \
--gpu-memory-utilization 0.95
输出指标:
Throughput: 142.3 tokens/s
P99 latency: 87ms
KV Cache memory: 1.8 GB (vs 4.1 GB without PagedAttention)
验证与监控
优化必须伴随严格验证。我们建议以下监控项:
- 缓存命中率:重计算场景中,应 >90%
- P99 端到端延迟:量化后增幅应 <15%
- 任务指标漂移:在 held-out dataset 上对比 BLEU/F1/准确率,阈值 Δ<1.5%
一致性验证脚本示例(对比原始与优化输出):
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model_orig = AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.float16)
model_opt = load_optimized_model() # 加载带 KV Cache 优化的模型
prompt = "Explain quantum entanglement in simple terms."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
out_orig = model_orig.generate(**inputs, max_new_tokens=100)
out_opt = model_opt.generate(**inputs, max_new_tokens=100)
text_orig = tokenizer.decode(out_orig[0], skip_special_tokens=True)
text_opt = tokenizer.decode(out_opt[0], skip_special_tokens=True)
# 使用 sentence-transformers 计算语义相似度
from sentence_transformers import SentenceTransformer
embedder = SentenceTransformer('all-MiniLM-L6-v2')
emb_orig = embedder.encode(text_orig)
emb_opt = embedder.encode(text_opt)
similarity = np.dot(emb_orig, emb_opt) / (np.linalg.norm(emb_orig) * np.linalg.norm(emb_opt))
assert similarity > 0.95, f"Semantic drift detected: {similarity:.3f}"
结论
KV Cache 优化是 LLM 推理工程的核心课题。量化、剪枝、重计算三类技术分别从精度、范围、持久性维度降低内存压力。选择策略需依据上下文长度、延迟 SLA 与质量容忍度进行量化权衡。未来方向包括硬件感知压缩(如 NVIDIA TensorRT-LLM 的 INT4 支持)与训练-推理联合优化(如在训练中注入缓存淘汰噪声)。当前最佳实践是:短上下文用量化,中长上下文用 SinkCache/滑动窗口,超长上下文用 StreamingLLM,并辅以严格的在线监控与 A/B 测试。