Transformer块在训练和推理阶段的流程对比
模型配置
- 层数 (L): 32
- 隐藏维度 (H): 4096
- 注意力头数 (N_heads): 32
- 每个头的维度 (d_k): 128 (H = N_heads × d_k = 32 × 128 = 4096)
- 数据类型: BF16 (2字节)
- 批次大小 (B): 1
- 输入序列长度 (S): 100 tokens
一、训练阶段(Training)
输入形状
- 输入:
[B, S, H]=[1, 100, 4096] - 所有100个token同时输入模型
单层Transformer块的流程
1. 自注意力机制(Self-Attention)
输入 X: [1, 100, 4096]
↓
LayerNorm: [1, 100, 4096]
↓
线性投影 (Q, K, V)
├─ Q = X @ W_q: [1, 100, 4096]
├─ K = X @ W_k: [1, 100, 4096]
└─ V = X @ W_v: [1, 100, 4096]
↓
重塑为多头形式
├─ Q: [1, 32, 100, 128] # [B, N_heads, S, d_k]
├─ K: [1, 32, 100, 128]
└─ V: [1, 32, 100, 128]
↓
计算注意力分数
Scores = Q @ K^T: [1, 32, 100, 100] # 注意:100×100的注意力矩阵
↓
应用因果掩码(Causal Mask)
# 上三角部分设为-inf,确保token只能看到之前的token
Masked_Scores: [1, 32, 100, 100]
↓
Softmax(Scores / √128): [1, 32, 100, 100]
↓
加权求和
Attention_output = Softmax @ V: [1, 32, 100, 128]
↓
重塑回原形状: [1, 100, 4096]
↓
输出投影 @ W_o: [1, 100, 4096]
↓
残差连接 + 原始输入
2. 前馈网络(FFN)
↓
LayerNorm: [1, 100, 4096]
↓
FFN (通常是 4H 的中间维度)
├─ 线性1: [1, 100, 4096] → [1, 100, 16384]
├─ 激活函数 (如 GeLU)
└─ 线性2: [1, 100, 16384] → [1, 100, 4096]
↓
残差连接
训练阶段的关键特征
- ✅ 并行计算:所有100个token同时处理
- ✅ 完整注意力矩阵:[100, 100] 的注意力分数矩阵
- ✅ 因果掩码:通过掩码确保自回归特性(第i个token只能看到前i个token)
- ✅ KV Cache:不使用(因为是一次性前向传播)
- ✅ 计算效率:高度并行化,充分利用GPU
二、推理阶段(Inference)- 自回归生成
推理阶段分为两个子阶段:Prefill 和 Decode
阶段1:Prefill(预填充)- 处理输入prompt
输入 prompt: [1, 100, 4096] # 与训练类似
这一步与训练阶段几乎相同: - 100个token同时处理 - 计算完整的 [100, 100] 注意力矩阵 - 重要:将所有100个token的 K 和 V 缓存起来
KV Cache 初始化:
├─ K_cache[layer_0]: [1, 32, 100, 128]
└─ V_cache[layer_0]: [1, 32, 100, 128]
... (对所有32层)
输出:第101个token的logits → 采样得到第101个token
阶段2:Decode(解码)- 逐个生成新token
第一步生成(token 101)
输入新token: [1, 1, 4096] # 只有1个token!
↓
LayerNorm: [1, 1, 4096]
↓
线性投影
├─ Q_new = X @ W_q: [1, 1, 4096] → [1, 32, 1, 128]
├─ K_new = X @ W_k: [1, 1, 4096] → [1, 32, 1, 128]
└─ V_new = X @ W_v: [1, 1, 4096] → [1, 32, 1, 128]
↓
将新的 K, V 拼接到缓存
├─ K_cache = concat(K_cache, K_new): [1, 32, 101, 128]
└─ V_cache = concat(V_cache, V_new): [1, 32, 101, 128]
↓
计算注意力(关键!)
Scores = Q_new @ K_cache^T: [1, 32, 1, 101] # 注意形状!
↓
Softmax: [1, 32, 1, 101]
↓
Attention_output = Softmax @ V_cache: [1, 32, 1, 128]
↓
后续处理同训练
第二步生成(token 102)
输入: [1, 1, 4096] # 刚生成的token 101
K_cache: [1, 32, 102, 128] # 又增加了1个
V_cache: [1, 32, 102, 128]
Scores = Q_new @ K_cache^T: [1, 32, 1, 102]
...
以此类推,每次生成一个新token,KV Cache长度+1。
三、核心差异对比表
| 维度 | 训练阶段 | 推理阶段(Prefill) | 推理阶段(Decode) |
|---|---|---|---|
| 输入形状 | [1, 100, 4096] |
[1, 100, 4096] |
[1, 1, 4096] |
| Q矩阵形状 | [1, 32, 100, 128] |
[1, 32, 100, 128] |
[1, 32, 1, 128] |
| K矩阵形状 | [1, 32, 100, 128] |
[1, 32, 100, 128] |
[1, 32, seq_len, 128](累积) |
| V矩阵形状 | [1, 32, 100, 128] |
[1, 32, 100, 128] |
[1, 32, seq_len, 128](累积) |
| 注意力分数形状 | [1, 32, 100, 100] |
[1, 32, 100, 100] |
[1, 32, 1, seq_len] |
| 因果掩码 | ✅ 上三角掩码 | ✅ 上三角掩码 | ❌ 不需要(只看历史) |
| KV Cache | ❌ 不使用 | ✅ 初始化 | ✅ 增量更新 |
| 并行度 | 高(所有token并行) | 高(所有token并行) | 低(逐个生成) |
| 计算模式 | 批量并行计算 | 批量并行计算 | 串行自回归生成 |
| 每步处理token数 | 100个 | 100个 | 1个 |
四、内存占用对比
KV Cache内存计算(单层)
在Decode阶段,每生成到第 t 个token时:
单层 KV Cache 大小 = 2 × B × N_heads × t × d_k × 2字节
(K和V) × 批次 × 头数 × 序列长度 × 头维度 × BF16
= 2 × 1 × 32 × t × 128 × 2
= 16,384 × t 字节
≈ 16 KB × t
32层总计 = 32 × 16 KB × t = 512 KB × t
不同序列长度的内存占用
| 生成到的位置 | 单层KV Cache | 32层总KV Cache |
|---|---|---|
| 第100个token(Prefill后) | ~1.6 MB | ~51 MB |
| 第200个token | ~3.2 MB | ~100 MB |
| 第500个token | ~8 MB | ~256 MB |
| 第1000个token | ~16 MB | ~512 MB |
| 第2048个token | ~32 MB | ~1 GB |
| 第4096个token | ~64 MB | ~2 GB |
训练阶段内存
- 不需要KV Cache,但需要存储中间激活值用于反向传播
- 梯度内存:需要存储所有参数的梯度
- 优化器状态:Adam需要存储momentum和variance(2倍参数量)
- 总体内存占用通常远大于推理阶段
五、计算复杂度对比
注意力计算复杂度
训练阶段 / Prefill阶段
- 复杂度: O(S²·H) = O(100² × 4096)
- 需要计算完整的 100×100 注意力矩阵
- 所有token的Q与所有token的K进行点积
- 高度并行化,但计算量大
Decode阶段(第t步)
- 复杂度: O(t·H)
- 只计算 1×t 的注意力分数
- 单个新token的Q与历史t个token的K进行点积
- 随着生成的进行,t不断增大,计算量线性增长
复杂度对比示例
| 阶段 | 序列长度 | 注意力计算FLOPs(近似) | 相对比例 |
|---|---|---|---|
| Prefill | 100 | 100² × H | 基准 |
| Decode (第101步) | 101 | 101 × H | ~1% |
| Decode (第200步) | 200 | 200 × H | ~2% |
| Decode (第1000步) | 1000 | 1000 × H | ~10% |
六、关键优化技术
推理阶段优化
- KV Cache
- 避免重复计算历史token的K和V
- 时间换空间的经典权衡
-
是自回归生成的核心优化
-
Flash Attention
- 优化注意力计算的内存访问模式
- 减少HBM访问,提高计算效率
-
适用于训练和推理
-
Paged Attention(vLLM)
- 动态管理KV Cache内存
- 类似操作系统的虚拟内存分页
-
提高内存利用率和批处理效率
-
Continuous Batching
- 动态批处理不同进度的请求
- 提高GPU利用率
-
减少平均延迟
-
Speculative Decoding
- 使用小模型预测多个token
- 大模型并行验证
- 加速生成过程
训练阶段优化
- 梯度检查点(Gradient Checkpointing)
- 重计算代替存储中间激活值
- 显著减少内存占用
-
增加约20-30%计算时间
-
混合精度训练
- BF16/FP16计算 + FP32累加
- 减少内存和提高速度
-
保持训练稳定性
-
ZeRO优化器(DeepSpeed)
- 分布式训练内存优化
- 分割优化器状态、梯度、参数
-
可训练更大模型
-
张量并行 / 流水线并行
- 模型并行策略
- 突破单卡内存限制
- 提高训练吞吐量
七、实际应用场景对比
训练场景特点
- 目标: 学习模型参数
- 输入: 固定长度的训练序列
- 输出: 损失函数和梯度
- 关注点: 训练速度、收敛性、模型质量
- 成本: 计算密集型,通常使用A100/H100等高端GPU集群
推理场景特点
Prefill阶段
- 目标: 理解用户输入prompt
- 输入: 用户的完整问题或上下文
- 输出: 首个生成token
- 特点: 类似训练的并行计算,但只需前向传播
- 延迟: 取决于prompt长度,通常几十到几百毫秒
Decode阶段
- 目标: 自回归生成响应
- 输入: 每次一个新生成的token
- 输出: 下一个token
- 特点: 串行生成,受内存带宽限制
- 延迟: 每个token约10-50ms(取决于模型大小和硬件)
八、关键要点总结
最核心的区别
- 输入维度差异
- 训练/Prefill:
[B, S, H]- 完整序列 -
Decode:
[B, 1, H]- 单个token -
注意力计算模式
- 训练/Prefill: 完整的S×S注意力矩阵,需要因果掩码
-
Decode: 1×seq_len注意力向量,利用KV Cache增量计算
-
内存使用策略
- 训练: 存储激活值用于反向传播
-
推理: 存储KV Cache用于加速生成
-
计算瓶颈
- 训练/Prefill: 计算密集型(Compute-bound)
- Decode: 内存带宽密集型(Memory-bound)
为什么需要KV Cache?
如果没有KV Cache,在Decode阶段每生成一个新token,都需要重新计算所有历史token的K和V矩阵,这会导致:
- 计算冗余: 历史token的K和V在每一步都被重复计算
- 时间复杂度: O(t²) 而不是 O(t)
- 实际影响: 生成1000个token可能需要几十秒而不是几秒
KV Cache通过空间换时间,使得自回归生成变得实用。
九、可视化对比
训练阶段注意力矩阵
Token: 1 2 3 4 5 ... 100
┌─────────────────────────────┐
1 │ ✓ ✗ ✗ ✗ ✗ ... ✗ │
2 │ ✓ ✓ ✗ ✗ ✗ ... ✗ │
3 │ ✓ ✓ ✓ ✗ ✗ ... ✗ │
4 │ ✓ ✓ ✓ ✓ ✗ ... ✗ │
... │ ... │
100 │ ✓ ✓ ✓ ✓ ✓ ... ✓ │
└─────────────────────────────┘
✓ = 可见(有效注意力)
✗ = 被掩码(-inf)
Decode阶段注意力向量(第101个token)
新token 101 对历史的注意力:
Token 101: [✓ ✓ ✓ ✓ ✓ ... ✓ ✓]
↑ ↑ ↑ ↑ ↑ ↑ ↑
tok1 tok2 tok3 tok4 tok5 ... tok100
所有历史token都可见,计算1×100的注意力权重
参考资料
- Attention Is All You Need (Vaswani et al., 2017)
- Flash Attention (Dao et al., 2022)
- vLLM: Easy, Fast, and Cheap LLM Serving (Kwon et al., 2023)
- Transformer Inference Optimization (多个来源)
文档版本: v1.0
创建日期: 2025-10-23
适用模型: GPT系列、LLaMA系列、Claude系列等基于Transformer的大语言模型