LLM推理加速方法对比:Speculative Decoding、EAGLE与Medusa
一、核心框架对比
1.1 通用Speculative Decoding范式
所有方法都遵循三步骤:
1. Draft (草稿生成):快速生成候选tokens
2. Verify (验证):Target LLM并行验证所有候选
3. Accept (接受):选择正确的tokens
1.2 三种方法的差异
维度 | Traditional SpecDec | EAGLE | Medusa |
---|---|---|---|
Draft方法 | 独立的小模型 | Target LLM的特征 + 1层Decoder | Target LLM上的多个MLP头 |
Draft参数量 | 完整小模型(1B-7B) | 0.24B-1B | 0.1-0.3B per head |
是否需要额外模型 | ✓ 需要 | ✗ 不需要 | ✗ 不需要 |
Draft准确率 | 中等(0.5-0.7) | 高(0.8) | 较低(0.6) |
预测方式 | 自回归(sequential) | 自回归(sequential) | 并行(parallel) |
特征来源 | 独立模型 | 倒数第二层(Layer 79) | 最后一层(Layer 80) |
核心创新 | N/A | Shifted token消除不确定性 | 多头并行 + Tree attention |
加速比(7B) | 1.5-1.9x | 2.8-3.0x | 2.2-2.8x |
二、各方法详细机制
2.1 Traditional Speculative Decoding
Draft阶段:
# 使用独立的小模型
draft_tokens = small_model.generate(context, num_tokens=K)
# 例如: 用1B模型为70B模型生成draft
问题: - 需要维护两个模型(部署复杂) - 小模型需要单独训练/预训练 - 分布可能不匹配
2.2 EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)
核心思想:利用Target LLM的倒数第二层特征 + Shifted Token机制
架构:
Target LLM:
Input → Embedding → [Layer 1-79] → features₇₉ ──┐
↓
EAGLE Draft Model: ┌───────────┴─────────┐
│ Concat(features, token_emb)
│ FC Layer (2d → d)
│ 1× Decoder Layer ← 唯一可训练
│ LM Head (复用)
└──────────────────────┘
关键机制 - Shifted Token:
# 消除采样不确定性
def eagle_draft(features_t, token_t):
# token_t是已知的采样结果!
feature_t+1 = draft_model(
features=features_t,
shifted_token=token_t # 提前一步的token
)
# 明确知道预测"token_t之后"的内容
为什么准确率高(0.8): - 特征层面更规律(连续vs离散) - Shifted token消除了采样的不确定性 - 例子:知道采样了"love",预测"natural"比不知道采样什么更准确
成本: - Draft: 3次 × (1层Decoder) - Verify: 1次 × (80层Transformer,处理4个tokens)
2.3 Medusa (Multiple Decoding Heads)
核心思想:在Target LLM上添加多个并行预测头
架构:
Target LLM:
Input → [Layer 1-80] → last_hidden ──┬──→ Original LM Head (位置t+1)
├──→ Medusa Head 1 (位置t+2)
├──→ Medusa Head 2 (位置t+3)
└──→ Medusa Head 3 (位置t+4)
所有heads共享输入,并行预测!
Medusa Head结构:
class MedusaHead(nn.Module):
def forward(self, last_hidden):
x = self.w1(last_hidden) # Linear
x = SiLU(x) # 激活
x = x + last_hidden # Residual!
logits = self.w2(x) # 输出
return logits
树形候选构建:
假设3个Medusa heads,每个head取top-k:
- Head 0: top-4 → [t1a, t1b, t1c, t1d]
- Head 1: top-3 → [t2a, t2b, t2c]
- Head 2: top-3 → [t3a, t3b, t3c]
树结构:
root
/ | | \
t1a t1b t1c t1d (4个分支)
/ | \ ...
t2a t2b t2c (每个3个分支)
/ | \
t3a t3b t3c (每个3个分支)
总节点数: 1 + 4 + 12 + 36 = 53个
关键:只有4个heads,但树展开成53个节点!
成本: - Draft: 1次Transformer forward + 4次MLP - Verify: 1次Transformer forward,处理53个tokens(tree attention)
三、Transformer一次输出多个token的原理
3.1 基础概念
标准理解(误解):
# 看起来只输出1个token
logits = model([t1, t2, t3]) # [1, 3, vocab_size]
next_token = sample(logits[:, -1, :]) # 只取最后一个位置
真实情况:
# Transformer实际输出所有位置的预测
logits = model([t1, t2, t3]) # [1, 3, vocab_size]
# logits[0, 0, :]: 基于t1,预测t2的分布
# logits[0, 1, :]: 基于t1,t2,预测t3的分布
# logits[0, 2, :]: 基于t1,t2,t3,预测t4的分布
# 推理时只用最后一个,训练时用所有位置!
3.2 训练vs推理的区别
# 训练:并行计算所有位置的loss
def training():
input_ids = [t1, t2, t3, t4]
labels = [t2, t3, t4, t5]
logits = model(input_ids) # [1, 4, vocab_size]
# 4个位置的loss并行计算
loss = cross_entropy(logits[:, :-1], labels)
# 推理:只用最后一个位置
def inference():
logits = model([t1, t2, t3])
next_token = sample(logits[:, -1, :]) # 浪费了前面的预测!
3.3 Tree Attention的核心机制
问题:如何一次验证多个候选路径?
解决:特殊的attention mask
# 候选树示例:
"""
root
/ \
t1a t1b
/ \ / \
t2a t2b t2c t2d
"""
# Step 1: 展平树
flat_tokens = [root, t1a, t1b, t2a, t2b, t2c, t2d]
# [0, 1, 2, 3, 4, 5, 6]
# Step 2: 构建Tree Attention Mask
# 每个token只能attend到它的祖先路径
"""
Mask矩阵 (1=可见, 0=不可见):
0 1 2 3 4 5 6
0 [1 0 0 0 0 0 0] root
1 [1 1 0 0 0 0 0] t1a (只看root,t1a)
2 [1 0 1 0 0 0 0] t1b (只看root,t1b)
3 [1 1 0 1 0 0 0] t2a (看root→t1a→t2a)
4 [1 1 0 0 1 0 0] t2b (看root→t1a→t2b)
5 [1 0 1 0 0 1 0] t2c (看root→t1b→t2c)
6 [1 0 1 0 0 0 1] t2d (看root→t1b→t2d)
"""
# Step 3: 一次forward验证所有候选
logits = transformer_forward(
input_ids=flat_tokens, # [1, 7]
attention_mask=tree_mask # [1, 7, 7]
) # 输出: [1, 7, vocab_size]
# 每个位置的输出只基于它的祖先路径!
# t2a的预测 = f(root, t1a, t2a)
# t2c的预测 = f(root, t1b, t2c)
# 互不干扰,但共享root的计算!
关键优势: - root只计算一次,所有路径共享 - 一次forward验证整棵树 - 不是简单的batch(batch中每个序列独立计算root)
3.4 为什么Tree Attention高效?
Memory-Bandwidth Bound特性:
LLM推理的瓶颈:
- 计算能力:312 TFLOPS (A100)
- 内存带宽:1935 GB/s (A100)
- 模型大小:70B × 2字节 = 140GB
时间分析:
- 读取参数时间:140GB / 1935GB/s = 72ms
- 计算1个token:0.9ms
- 计算125个token:112ms
关键:参数只读一次,但处理了125个位置!
总时间:max(72ms, 112ms) + overhead ≈ 130ms
对比vanilla 125个token:125 × 72ms = 9000ms
加速:9000/130 ≈ 69x (理论上限)
实际:考虑acceptance rate,约2-3x
四、完整推理流程对比
生成4个tokens的成本
方法 | Transformer Forward | Draft计算 | 总时间(估算) | 加速比 |
---|---|---|---|---|
Vanilla | 4次 × 1 token | - | 288ms | 1.0x |
Traditional SpecDec | 2次 (1 draft + 1 verify) | 完整小模型 | ~200ms | 1.4x |
EAGLE | 2次 (1 draft + 1 verify) | 3次 × 1层Decoder | 164ms | 1.76x |
Medusa | 2次 (1 draft + 1 verify) | 4次 × 1层MLP | 204ms | 1.41x |
注:以上为理想情况(draft全部正确),实际效果取决于acceptance rate
五、关键洞察总结
5.1 为什么EAGLE准确率最高?
Medusa的困境:
预测位置t+2时,不知道位置t+1会采样到什么
→ 只能预测一个"平均"分布
→ 准确率低(0.6)
EAGLE的优势:
预测位置t+2时,已知位置t+1采样了什么
→ 目标明确
→ 准确率高(0.8)
5.2 为什么Medusa更简单?
- 不需要访问中间层特征
- 只在最后一层添加heads
- 部署侵入性最小
5.3 为什么都能加速?
核心原因:把O(N)的sequential步骤压缩到O(1) - Vanilla:N个tokens需要N次forward - Speculative:N个tokens需要约N/k次forward (k=平均接受长度) - 利用了Transformer"一次输出多个位置"的能力
5.4 Tree Attention的本质
不是batch并行,而是计算共享: - Batch: 多个独立序列并行,root计算N次 - Tree: 单个序列,root计算1次,所有路径共享
Memory-bandwidth优势: - 参数只读一次(72ms) - 但处理了100+个token位置 - 总时间 << 100次vanilla forward
六、选择建议
场景 | 推荐方法 | 理由 |
---|---|---|
追求最高加速比 | EAGLE | 准确率高,接受序列长 |
最小部署侵入 | Medusa | 只加MLP heads,易集成 |
训练资源有限 | Medusa-1 | 冻结backbone,低成本 |
需要独立draft | Traditional | 模型解耦,灵活性高 |
附录:术语表
- Draft: 快速生成候选tokens的过程
- Verify: 用target model验证候选的正确性
- Acceptance Rate: 平均每次接受的tokens比例
- Tree Attention: 特殊的attention mask,允许树形结构的并行验证
- Shifted Token: EAGLE的关键机制,使用已知采样结果作为输入
- Memory-Bandwidth Bound: 性能瓶颈在内存读取而非计算