LLM LM_HEAD解码加速

LLM LM_HEAD解码加速

深入解析LLM中的强化学习算法,包括PPO、DPO、GRPO的工作原理、数学推导、以及工业界应用案例

LLM推理加速方法对比:Speculative Decoding、EAGLE与Medusa

一、核心框架对比

1.1 通用Speculative Decoding范式

所有方法都遵循三步骤:

1. Draft (草稿生成)快速生成候选tokens
2. Verify (验证)Target LLM并行验证所有候选
3. Accept (接受)选择正确的tokens

1.2 三种方法的差异

维度 Traditional SpecDec EAGLE Medusa
Draft方法 独立的小模型 Target LLM的特征 + 1层Decoder Target LLM上的多个MLP头
Draft参数量 完整小模型(1B-7B) 0.24B-1B 0.1-0.3B per head
是否需要额外模型 ✓ 需要 ✗ 不需要 ✗ 不需要
Draft准确率 中等(0.5-0.7) 高(0.8) 较低(0.6)
预测方式 自回归(sequential) 自回归(sequential) 并行(parallel)
特征来源 独立模型 倒数第二层(Layer 79) 最后一层(Layer 80)
核心创新 N/A Shifted token消除不确定性 多头并行 + Tree attention
加速比(7B) 1.5-1.9x 2.8-3.0x 2.2-2.8x

二、各方法详细机制

2.1 Traditional Speculative Decoding

Draft阶段

# 使用独立的小模型
draft_tokens = small_model.generate(context, num_tokens=K)
# 例如: 用1B模型为70B模型生成draft

问题: - 需要维护两个模型(部署复杂) - 小模型需要单独训练/预训练 - 分布可能不匹配


2.2 EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)

核心思想:利用Target LLM的倒数第二层特征 + Shifted Token机制

架构

Target LLM:
Input → Embedding → [Layer 1-79] → features₇₉ ──┐
                                                   ↓
EAGLE Draft Model:                    ┌───────────┴─────────┐
                                      │ Concat(features, token_emb)
                                      │ FC Layer (2d → d)
                                      │ 1× Decoder Layer  ← 唯一可训练
                                      │ LM Head (复用)
                                      └──────────────────────┘

关键机制 - Shifted Token

# 消除采样不确定性
def eagle_draft(features_t, token_t):
    # token_t是已知的采样结果!
    feature_t+1 = draft_model(
        features=features_t,
        shifted_token=token_t  # 提前一步的token
    )
    # 明确知道预测"token_t之后"的内容

为什么准确率高(0.8): - 特征层面更规律(连续vs离散) - Shifted token消除了采样的不确定性 - 例子:知道采样了"love",预测"natural"比不知道采样什么更准确

成本: - Draft: 3次 × (1层Decoder) - Verify: 1次 × (80层Transformer,处理4个tokens)


2.3 Medusa (Multiple Decoding Heads)

核心思想:在Target LLM上添加多个并行预测头

架构

Target LLM:
Input → [Layer 1-80] → last_hidden ──┬──→ Original LM Head (位置t+1)
                                      ├──→ Medusa Head 1 (位置t+2)
                                      ├──→ Medusa Head 2 (位置t+3)
                                      └──→ Medusa Head 3 (位置t+4)

所有heads共享输入,并行预测!

Medusa Head结构

class MedusaHead(nn.Module):
    def forward(self, last_hidden):
        x = self.w1(last_hidden)        # Linear
        x = SiLU(x)                     # 激活
        x = x + last_hidden             # Residual!
        logits = self.w2(x)             # 输出
        return logits

树形候选构建

假设3个Medusa heads,每个head取top-k:
- Head 0: top-4 → [t1a, t1b, t1c, t1d]
- Head 1: top-3 → [t2a, t2b, t2c]
- Head 2: top-3 → [t3a, t3b, t3c]

树结构:
                    root
               /    |    |    \
             t1a   t1b  t1c  t1d      (4个分支)
            / | \   ...
          t2a t2b t2c                 (每个3个分支)
         / | \
       t3a t3b t3c                    (每个3个分支)

总节点数: 1 + 4 + 12 + 36 = 53个

关键:只有4个heads,但树展开成53个节点!

成本: - Draft: 1次Transformer forward + 4次MLP - Verify: 1次Transformer forward,处理53个tokens(tree attention)


三、Transformer一次输出多个token的原理

3.1 基础概念

标准理解(误解)

# 看起来只输出1个token
logits = model([t1, t2, t3])  # [1, 3, vocab_size]
next_token = sample(logits[:, -1, :])  # 只取最后一个位置

真实情况

# Transformer实际输出所有位置的预测
logits = model([t1, t2, t3])  # [1, 3, vocab_size]

# logits[0, 0, :]: 基于t1,预测t2的分布
# logits[0, 1, :]: 基于t1,t2,预测t3的分布
# logits[0, 2, :]: 基于t1,t2,t3,预测t4的分布

# 推理时只用最后一个,训练时用所有位置!

3.2 训练vs推理的区别

# 训练:并行计算所有位置的loss
def training():
    input_ids = [t1, t2, t3, t4]
    labels = [t2, t3, t4, t5]

    logits = model(input_ids)  # [1, 4, vocab_size]

    # 4个位置的loss并行计算
    loss = cross_entropy(logits[:, :-1], labels)

# 推理:只用最后一个位置
def inference():
    logits = model([t1, t2, t3])
    next_token = sample(logits[:, -1, :])  # 浪费了前面的预测!

3.3 Tree Attention的核心机制

问题:如何一次验证多个候选路径?

解决:特殊的attention mask

# 候选树示例:
"""
            root
          /      \
        t1a      t1b
       /  \      /  \
     t2a  t2b  t2c  t2d
"""

# Step 1: 展平树
flat_tokens = [root, t1a, t1b, t2a, t2b, t2c, t2d]
#              [0,    1,    2,   3,   4,   5,   6]

# Step 2: 构建Tree Attention Mask
# 每个token只能attend到它的祖先路径
"""
Mask矩阵 (1=可见, 0=不可见):
       0  1  2  3  4  5  6
    0 [1  0  0  0  0  0  0]  root
    1 [1  1  0  0  0  0  0]  t1a (只看root,t1a)
    2 [1  0  1  0  0  0  0]  t1b (只看root,t1b)
    3 [1  1  0  1  0  0  0]  t2a (看root→t1a→t2a)
    4 [1  1  0  0  1  0  0]  t2b (看root→t1a→t2b)
    5 [1  0  1  0  0  1  0]  t2c (看root→t1b→t2c)
    6 [1  0  1  0  0  0  1]  t2d (看root→t1b→t2d)
"""

# Step 3: 一次forward验证所有候选
logits = transformer_forward(
    input_ids=flat_tokens,        # [1, 7]
    attention_mask=tree_mask      # [1, 7, 7]
)  # 输出: [1, 7, vocab_size]

# 每个位置的输出只基于它的祖先路径!
# t2a的预测 = f(root, t1a, t2a)
# t2c的预测 = f(root, t1b, t2c)
# 互不干扰,但共享root的计算!

关键优势: - root只计算一次,所有路径共享 - 一次forward验证整棵树 - 不是简单的batch(batch中每个序列独立计算root)

3.4 为什么Tree Attention高效?

Memory-Bandwidth Bound特性

LLM推理的瓶颈:
- 计算能力:312 TFLOPS (A100)
- 内存带宽:1935 GB/s (A100)
- 模型大小:70B × 2字节 = 140GB

时间分析:
- 读取参数时间:140GB / 1935GB/s = 72ms
- 计算1个token:0.9ms
- 计算125个token:112ms

关键:参数只读一次,但处理了125个位置!
总时间:max(72ms, 112ms) + overhead ≈ 130ms
对比vanilla 125个token:125 × 72ms = 9000ms

加速:9000/130 ≈ 69x (理论上限)
实际:考虑acceptance rate,约2-3x

四、完整推理流程对比

生成4个tokens的成本

方法 Transformer Forward Draft计算 总时间(估算) 加速比
Vanilla 4次 × 1 token - 288ms 1.0x
Traditional SpecDec 2次 (1 draft + 1 verify) 完整小模型 ~200ms 1.4x
EAGLE 2次 (1 draft + 1 verify) 3次 × 1层Decoder 164ms 1.76x
Medusa 2次 (1 draft + 1 verify) 4次 × 1层MLP 204ms 1.41x

:以上为理想情况(draft全部正确),实际效果取决于acceptance rate

五、关键洞察总结

5.1 为什么EAGLE准确率最高?

Medusa的困境:
预测位置t+2时,不知道位置t+1会采样到什么
→ 只能预测一个"平均"分布
→ 准确率低(0.6)

EAGLE的优势:
预测位置t+2时,已知位置t+1采样了什么
→ 目标明确
→ 准确率高(0.8)

5.2 为什么Medusa更简单?

  • 不需要访问中间层特征
  • 只在最后一层添加heads
  • 部署侵入性最小

5.3 为什么都能加速?

核心原因:把O(N)的sequential步骤压缩到O(1) - Vanilla:N个tokens需要N次forward - Speculative:N个tokens需要约N/k次forward (k=平均接受长度) - 利用了Transformer"一次输出多个位置"的能力

5.4 Tree Attention的本质

不是batch并行,而是计算共享: - Batch: 多个独立序列并行,root计算N次 - Tree: 单个序列,root计算1次,所有路径共享

Memory-bandwidth优势: - 参数只读一次(72ms) - 但处理了100+个token位置 - 总时间 << 100次vanilla forward

六、选择建议

场景 推荐方法 理由
追求最高加速比 EAGLE 准确率高,接受序列长
最小部署侵入 Medusa 只加MLP heads,易集成
训练资源有限 Medusa-1 冻结backbone,低成本
需要独立draft Traditional 模型解耦,灵活性高

附录:术语表

  • Draft: 快速生成候选tokens的过程
  • Verify: 用target model验证候选的正确性
  • Acceptance Rate: 平均每次接受的tokens比例
  • Tree Attention: 特殊的attention mask,允许树形结构的并行验证
  • Shifted Token: EAGLE的关键机制,使用已知采样结果作为输入
  • Memory-Bandwidth Bound: 性能瓶颈在内存读取而非计算

Thanks for Reading

If this article was helpful to you, feel free to connect with me!