LLM train vs infer

LLM train vs infer

A example for LLM train and infer

Transformer块在训练和推理阶段的流程对比

模型配置

  • 层数 (L): 32
  • 隐藏维度 (H): 4096
  • 注意力头数 (N_heads): 32
  • 每个头的维度 (d_k): 128 (H = N_heads × d_k = 32 × 128 = 4096)
  • 数据类型: BF16 (2字节)
  • 批次大小 (B): 1
  • 输入序列长度 (S): 100 tokens

一、训练阶段(Training)

输入形状

  • 输入: [B, S, H] = [1, 100, 4096]
  • 所有100个token同时输入模型

单层Transformer块的流程

1. 自注意力机制(Self-Attention)

输入 X: [1, 100, 4096]
     ↓
LayerNorm: [1, 100, 4096]
     ↓
线性投影 (Q, K, V)
├─ Q = X @ W_q: [1, 100, 4096]
├─ K = X @ W_k: [1, 100, 4096]  
└─ V = X @ W_v: [1, 100, 4096]
     ↓
重塑为多头形式
├─ Q: [1, 32, 100, 128]  # [B, N_heads, S, d_k]
├─ K: [1, 32, 100, 128]
└─ V: [1, 32, 100, 128]
     ↓
计算注意力分数
Scores = Q @ K^T: [1, 32, 100, 100]  # 注意:100×100的注意力矩阵
     ↓
应用因果掩码(Causal Mask)
# 上三角部分设为-inf,确保token只能看到之前的token
Masked_Scores: [1, 32, 100, 100]
     ↓
Softmax(Scores / √128): [1, 32, 100, 100]
     ↓
加权求和
Attention_output = Softmax @ V: [1, 32, 100, 128]
     ↓
重塑回原形状: [1, 100, 4096]
     ↓
输出投影 @ W_o: [1, 100, 4096]
     ↓
残差连接 + 原始输入

2. 前馈网络(FFN)

     ↓
LayerNorm: [1, 100, 4096]
     ↓
FFN (通常是 4H 的中间维度)
├─ 线性1: [1, 100, 4096] → [1, 100, 16384]
├─ 激活函数 (如 GeLU)
└─ 线性2: [1, 100, 16384] → [1, 100, 4096]
     ↓
残差连接

训练阶段的关键特征

  • 并行计算:所有100个token同时处理
  • 完整注意力矩阵:[100, 100] 的注意力分数矩阵
  • 因果掩码:通过掩码确保自回归特性(第i个token只能看到前i个token)
  • KV Cache不使用(因为是一次性前向传播)
  • 计算效率:高度并行化,充分利用GPU

二、推理阶段(Inference)- 自回归生成

推理阶段分为两个子阶段:PrefillDecode

阶段1:Prefill(预填充)- 处理输入prompt

输入 prompt: [1, 100, 4096]  # 与训练类似

这一步与训练阶段几乎相同: - 100个token同时处理 - 计算完整的 [100, 100] 注意力矩阵 - 重要:将所有100个token的 K 和 V 缓存起来

KV Cache 初始化:
├─ K_cache[layer_0]: [1, 32, 100, 128]
└─ V_cache[layer_0]: [1, 32, 100, 128]
    ... (对所有32层)

输出:第101个token的logits → 采样得到第101个token


阶段2:Decode(解码)- 逐个生成新token

第一步生成(token 101)

输入新token: [1, 1, 4096]  # 只有1个token!
     ↓
LayerNorm: [1, 1, 4096]
     ↓
线性投影
├─ Q_new = X @ W_q: [1, 1, 4096] → [1, 32, 1, 128]
├─ K_new = X @ W_k: [1, 1, 4096] → [1, 32, 1, 128]  
└─ V_new = X @ W_v: [1, 1, 4096] → [1, 32, 1, 128]
     ↓
将新的 K, V 拼接到缓存
├─ K_cache = concat(K_cache, K_new): [1, 32, 101, 128]
└─ V_cache = concat(V_cache, V_new): [1, 32, 101, 128]
     ↓
计算注意力(关键!)
Scores = Q_new @ K_cache^T: [1, 32, 1, 101]  # 注意形状!
     ↓
Softmax: [1, 32, 1, 101]
     ↓
Attention_output = Softmax @ V_cache: [1, 32, 1, 128]
     ↓
后续处理同训练

第二步生成(token 102)

输入: [1, 1, 4096]  # 刚生成的token 101
K_cache: [1, 32, 102, 128]  # 又增加了1个
V_cache: [1, 32, 102, 128]

Scores = Q_new @ K_cache^T: [1, 32, 1, 102]
...

以此类推,每次生成一个新token,KV Cache长度+1。


三、核心差异对比表

维度 训练阶段 推理阶段(Prefill) 推理阶段(Decode)
输入形状 [1, 100, 4096] [1, 100, 4096] [1, 1, 4096]
Q矩阵形状 [1, 32, 100, 128] [1, 32, 100, 128] [1, 32, 1, 128]
K矩阵形状 [1, 32, 100, 128] [1, 32, 100, 128] [1, 32, seq_len, 128](累积)
V矩阵形状 [1, 32, 100, 128] [1, 32, 100, 128] [1, 32, seq_len, 128](累积)
注意力分数形状 [1, 32, 100, 100] [1, 32, 100, 100] [1, 32, 1, seq_len]
因果掩码 ✅ 上三角掩码 ✅ 上三角掩码 ❌ 不需要(只看历史)
KV Cache ❌ 不使用 ✅ 初始化 ✅ 增量更新
并行度 高(所有token并行) 高(所有token并行) 低(逐个生成)
计算模式 批量并行计算 批量并行计算 串行自回归生成
每步处理token数 100个 100个 1个

四、内存占用对比

KV Cache内存计算(单层)

在Decode阶段,每生成到第 t 个token时:

单层 KV Cache 大小 = 2 × B × N_heads × t × d_k × 2字节
                    (K和V) × 批次 × 头数 × 序列长度 × 头维度 × BF16

                    = 2 × 1 × 32 × t × 128 × 2
                    = 16,384 × t 字节
                    ≈ 16 KB × t

32层总计 = 32 × 16 KB × t = 512 KB × t

不同序列长度的内存占用

生成到的位置 单层KV Cache 32层总KV Cache
第100个token(Prefill后) ~1.6 MB ~51 MB
第200个token ~3.2 MB ~100 MB
第500个token ~8 MB ~256 MB
第1000个token ~16 MB ~512 MB
第2048个token ~32 MB ~1 GB
第4096个token ~64 MB ~2 GB

训练阶段内存

  • 不需要KV Cache,但需要存储中间激活值用于反向传播
  • 梯度内存:需要存储所有参数的梯度
  • 优化器状态:Adam需要存储momentum和variance(2倍参数量)
  • 总体内存占用通常远大于推理阶段

五、计算复杂度对比

注意力计算复杂度

训练阶段 / Prefill阶段

  • 复杂度: O(S²·H) = O(100² × 4096)
  • 需要计算完整的 100×100 注意力矩阵
  • 所有token的Q与所有token的K进行点积
  • 高度并行化,但计算量大

Decode阶段(第t步)

  • 复杂度: O(t·H)
  • 只计算 1×t 的注意力分数
  • 单个新token的Q与历史t个token的K进行点积
  • 随着生成的进行,t不断增大,计算量线性增长

复杂度对比示例

阶段 序列长度 注意力计算FLOPs(近似) 相对比例
Prefill 100 100² × H 基准
Decode (第101步) 101 101 × H ~1%
Decode (第200步) 200 200 × H ~2%
Decode (第1000步) 1000 1000 × H ~10%

六、关键优化技术

推理阶段优化

  1. KV Cache
  2. 避免重复计算历史token的K和V
  3. 时间换空间的经典权衡
  4. 是自回归生成的核心优化

  5. Flash Attention

  6. 优化注意力计算的内存访问模式
  7. 减少HBM访问,提高计算效率
  8. 适用于训练和推理

  9. Paged Attention(vLLM)

  10. 动态管理KV Cache内存
  11. 类似操作系统的虚拟内存分页
  12. 提高内存利用率和批处理效率

  13. Continuous Batching

  14. 动态批处理不同进度的请求
  15. 提高GPU利用率
  16. 减少平均延迟

  17. Speculative Decoding

  18. 使用小模型预测多个token
  19. 大模型并行验证
  20. 加速生成过程

训练阶段优化

  1. 梯度检查点(Gradient Checkpointing)
  2. 重计算代替存储中间激活值
  3. 显著减少内存占用
  4. 增加约20-30%计算时间

  5. 混合精度训练

  6. BF16/FP16计算 + FP32累加
  7. 减少内存和提高速度
  8. 保持训练稳定性

  9. ZeRO优化器(DeepSpeed)

  10. 分布式训练内存优化
  11. 分割优化器状态、梯度、参数
  12. 可训练更大模型

  13. 张量并行 / 流水线并行

  14. 模型并行策略
  15. 突破单卡内存限制
  16. 提高训练吞吐量

七、实际应用场景对比

训练场景特点

  • 目标: 学习模型参数
  • 输入: 固定长度的训练序列
  • 输出: 损失函数和梯度
  • 关注点: 训练速度、收敛性、模型质量
  • 成本: 计算密集型,通常使用A100/H100等高端GPU集群

推理场景特点

Prefill阶段

  • 目标: 理解用户输入prompt
  • 输入: 用户的完整问题或上下文
  • 输出: 首个生成token
  • 特点: 类似训练的并行计算,但只需前向传播
  • 延迟: 取决于prompt长度,通常几十到几百毫秒

Decode阶段

  • 目标: 自回归生成响应
  • 输入: 每次一个新生成的token
  • 输出: 下一个token
  • 特点: 串行生成,受内存带宽限制
  • 延迟: 每个token约10-50ms(取决于模型大小和硬件)

八、关键要点总结

最核心的区别

  1. 输入维度差异
  2. 训练/Prefill: [B, S, H] - 完整序列
  3. Decode: [B, 1, H] - 单个token

  4. 注意力计算模式

  5. 训练/Prefill: 完整的S×S注意力矩阵,需要因果掩码
  6. Decode: 1×seq_len注意力向量,利用KV Cache增量计算

  7. 内存使用策略

  8. 训练: 存储激活值用于反向传播
  9. 推理: 存储KV Cache用于加速生成

  10. 计算瓶颈

  11. 训练/Prefill: 计算密集型(Compute-bound)
  12. Decode: 内存带宽密集型(Memory-bound)

为什么需要KV Cache?

如果没有KV Cache,在Decode阶段每生成一个新token,都需要重新计算所有历史token的K和V矩阵,这会导致:

  • 计算冗余: 历史token的K和V在每一步都被重复计算
  • 时间复杂度: O(t²) 而不是 O(t)
  • 实际影响: 生成1000个token可能需要几十秒而不是几秒

KV Cache通过空间换时间,使得自回归生成变得实用。


九、可视化对比

训练阶段注意力矩阵

Token:  1   2   3   4   5  ...  100
     ┌─────────────────────────────┐
   1                ...    
   2                ...    
   3                ...    
   4                ...    
 ...               ...            
 100                ...    
     └─────────────────────────────┘

 = 可见(有效注意力)
 = 被掩码(-inf

Decode阶段注意力向量(第101个token)

新token 101 对历史的注意力:

Token 101: [✓   ✓   ✓   ✓   ✓  ...  ✓  ✓]
            ↑   ↑   ↑   ↑   ↑       ↑  ↑
          tok1 tok2 tok3 tok4 tok5 ... tok100

所有历史token都可见,计算1×100的注意力权重

参考资料

  • Attention Is All You Need (Vaswani et al., 2017)
  • Flash Attention (Dao et al., 2022)
  • vLLM: Easy, Fast, and Cheap LLM Serving (Kwon et al., 2023)
  • Transformer Inference Optimization (多个来源)

文档版本: v1.0
创建日期: 2025-10-23
适用模型: GPT系列、LLaMA系列、Claude系列等基于Transformer的大语言模型

Thanks for Reading

If this article was helpful to you, feel free to connect with me!