涵盖了 Attention 机制原理、PyTorch Profiling 性能分析、显存优化 以及 数值调试技巧 。这四个模块并非孤立存在,它们共同构成了 "LLM 底层系统工程" 的基石。要从"学会知识点"进阶到"构建体系化能力",需要将这四块内容重组为一条 "正确性 → 可观测性 → 效率上限" 的工程闭环。以下是循序渐进认知与实践框架:
认知重构:建立三层递进模型
不要平行地看待四个主题,应按以下优先级建立层级认知:
| 层级 | 核心主题 | 认知目标 | 对应今日内容 |
|---|---|---|---|
| L1: 正确性基座 | Attention + Debugging | 确保算得对、不崩溃。这是所有优化的前提。 | SDPA 手写、NaN 排查、梯度健康检查 |
| L2: 可观测性 | Profiling | 看清黑盒内部。没有 Profiling 的优化都是盲人摸象。 | torch.profiler、TensorBoard、CPU/GPU 时间分布 |
| L3: 效率上限 | Memory Optimization | 在硬件约束下逼近理论极限。 | 显存分析、梯度累积、AMP、Checkpointing |
核心心法 :先 L1 再 L2,有 L2 才做 L3。
很多初学者跳过 Profiling 直接做显存优化,结果花了大量时间优化了一个根本不是瓶颈的模块。正确的顺序永远是:代码跑通 (L1) → Profile 找瓶颈 (L2) → 针对性优化 (L3) → 再次 Profile 验证 (L2)。
高频实践:必须形成肌肉记忆的内容
以下内容不是"了解即可",而是需要在未来每次写训练代码时条件反射般使用的:
- 数值安全三件套
masked_fill(~mask, float('-inf'))而非-1e9- Loss 记录永远用
.item()而非 append tensor optimizer.zero_grad(set_to_none=True)
- 显存监控习惯
- 训练启动后前 3 步必查
torch.cuda.max_memory_allocated() - 遇到 OOM 先看
reserved - allocated判断是碎片还是真不足
- 训练启动后前 3 步必查
- Profiling 标准动作
- 任何性能测试必须先 Warmup ≥3 步
- GPU 计时只用
cuda.Event,禁用time.time() - 报告必看
Self CPU %vsCUDA time判断 CPU/GPU 谁在等谁
- Attention 实现规范
- Scale 必须在 Softmax 之前
- 生产环境优先
F.scaled_dot_product_attention(is_causal=True) - 手写版本仅用于学习和对齐验证
深入研究:从"会用"到"理解本质"
当基础操作熟练后,需要对以下问题进行深度追问,这决定了能走多远:
Attention 的数学直觉
- 为什么是 D*D* 而不是 D*D* ? 推导点积方差与维度的关系,理解 scale 的本质是方差归一化。
- Causal Mask 的替代方案有哪些? 研究 ALiBi、RoPE 如何将位置信息注入 Q/K 而非依赖 mask,理解它们对长度外推的影响。
- Flash Attention 为什么快? 深入理解 Tiling + Online Softmax 算法,明白它如何避免物化 T×TT ×T 矩阵,以及 IO-aware 设计哲学。
显存的物理本质
- PyTorch 缓存分配器的工作原理 :理解
allocatedvsreserved的区别,研究expandable_segments如何解决碎片化。 - Activation Checkpointing 的最优粒度:不是越细越好。研究如何根据每层的 FLOPs/Activation 比值选择性地 checkpoint,平衡重计算与显存节省。
- 混合精度的数值边界:FP16 vs BF16 的动态范围差异,GradScaler 的缩放因子自适应逻辑,哪些算子必须保持 FP32(如 LayerNorm、Softmax)。
Profiling 的信号解读
- Kernel Launch Overhead 的量化:多少 μs/op 算异常?小算子融合的阈值在哪里?
- 异步执行的时序图阅读:能从 Chrome Trace 中识别出 data loading gap、sync barrier、kernel pipeline bubble。
- 内存时间线分析:区分临时 buffer、持久 activation、optimizer state 的生命周期。
体系化实践路线图
| 阶段 | 周期 | 实践任务 | 验收标准 |
|---|---|---|---|
| Week 1: 正确性验证 | 第 1 周 | 手写 SDPA 并与官方实现对齐;故意注入 NaN/Inf 并用调试工具定位 | 单元测试 pass;能在 5 分钟内定位人为注入的数值 bug |
| Week 2: 性能画像 | 第 2 周 | 对一个真实模型做完整 Profiling,输出瓶颈分析报告 | 能明确说出 Top3 瓶颈及原因(CPU bound / GPU bound / Memory bound) |
| Week 3: 显存压榨 | 第 3 周 | 对同一模型依次应用 AMP→Checkpoint→GradAccum,记录显存/速度变化曲线 | 产出优化对比表;显存降低 ≥40% 且吞吐下降 <20% |
| Week 4: 综合实战 | 第 4 周 | 从零搭建一个 mini-LLM 训练循环,集成所有最佳实践 | 代码无 NaN、显存利用率 >80%、Profile 无明显 bubble |
- 建立个人 Debug Checklist:将今天学到的排查流程固化为文档,每次遇到 NaN/OOM 时按清单逐项检查,而非凭感觉试错。
- 保留基准测试代码 :将
benchmark_batch_sizes、measure_peak_memory等工具函数封装为可复用库,后续每个新项目都直接调用。 - 关注 PyTorch 版本演进 :今天的很多"最佳实践"(如手动 AMP、手写 checkpoint)在 PyTorch 2.x+ 中已被
torch.compile和原生 API 取代。理解原理是为了更好地使用高层抽象,而非永远手写底层。 - 从单卡思维走向系统思维:今天的知识主要面向单卡。下一步应自然延伸到多卡场景:ZeRO/FSDP 如何分片显存、DDP 通信如何与计算重叠、跨 rank NaN 如何同步检测。
终极检验标准:当看到一个 LLM 训练报错或性能问题时,能在 30 秒内判断它属于 L1/L2/L3 哪个层级,并准确调用对应的工具和知识储备------这就是体系化能力的体现。
1. generate_position_ids(seq_len)。 目标:掌握列表推导式(List Comprehension),这是处理 Token 序列最基础的操作。
python
def generate_position_ids(seq_len: int) -> list[int]:
"""
生成位置编码 ID,用于 Transformer 的位置嵌入层。
Args:
seq_len: 序列长度
Returns:
从 0 开始的连续整数列表 [0, 1, 2, ..., seq_len-1]
"""
return [i for i in range(seq_len)]
虽然
list(range(seq_len))也能实现。在实际 LLM 工程中,当需要对 position_ids 做条件过滤(如跳过 padding token)时,列表推导的灵活性远高于list(range())。
2. merge_model_config(base_config, overrides)。目标 :理解字典合并与浅拷贝陷阱。这是模型配置热更新的核心操作。
python
def merge_model_config(base_config: dict, overrides: dict) -> dict:
"""
安全合并模型配置,overrides 中的键值对覆盖 base_config。不修改原始 base_config
Args:
base_config: 基础配置字典
overrides: 需要覆盖的配置项
Returns:
合并后的新字典
"""
# 使用浅拷贝避免污染原对象
merged = base_config.copy()
merged.update(overrides)
return merged
- 错误写法 :
base_config.update(overrides); return base_config→ 这会原地修改传入的 base_config,导致后续调用产生难以排查的 bug。- 浅拷贝 vs 深拷贝 :此处用
.copy()即可,因为 config 通常是扁平的 key-value 结构。如果 config 包含嵌套字典/列表且需要独立修改,则应使用copy.deepcopy()。- Python 3.9+ 替代写法 :
return {**base_config, **overrides}同样安全且更简洁,但.copy() + .update()语义更清晰,适合调试打印。
3. count_token_frequency(tokens)。目标:掌握频率统计,这是分词器分析、数据质量检查的基础技能。
python
from collections import Counter
def count_token_frequency(tokens: list[str]) -> dict[str, int]:
"""
统计 token 出现频率。
Args:
tokens: token 字符串列表
Returns:
{token: count} 字典,按出现次数降序排列
"""
freq = Counter(tokens)
# 返回普通 dict,按频率降序,方便直接打印调试
return dict(freq.most_common())
Counter是 Python 标准库中专为频率统计设计的工具,底层 C 实现,比手写for + dict快一个数量级。在分析百万级 token 语料时差异显著。.most_common()返回的结果天然适合日志输出和可视化。
4. ModelSpec.summary()。目标 :学会用类封装模型元信息,并让输出可直接用于调试。
python
class ModelSpec:
"""轻量级模型规格描述类"""
def __init__(self, name: str, vocab_size: int, hidden_dim: int, num_layers: int):
self.name = name
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.num_layers = num_layers
def summary(self) -> str:
"""
返回格式化的模型摘要字符串,可直接 print() 用于调试。
"""
lines = [
f"=== Model Spec: {self.name} ===",
f" Vocab Size : {self.vocab_size:,}",
f" Hidden Dim : {self.hidden_dim:,}",
f" Num Layers : {self.num_layers}",
f" Est. Params: ~{self._estimate_params():,}",
]
return "\n".join(lines)
def _estimate_params(self) -> int:
"""粗略参数量估算(仅用于调试展示)"""
return self.vocab_size * self.hidden_dim + \
self.num_layers * (4 * self.hidden_dim ** 2)
def __repr__(self) -> str:
return f"ModelSpec(name={self.name!r}, layers={self.num_layers})"
summary()返回 str 而非直接 print,这样既能print(spec.summary())也能写入日志文件,符合"输出可直接用于调试和打印"的要求。- 数字使用
{:,}千分位格式化,大幅提升可读性。- 额外实现
__repr__,使得在 Jupyter Notebook 或 pdb 调试中直接显示对象时有意义,而不是<__main__.ModelSpec object at 0x...>。_estimate_params()作为私有方法,体现封装思想:对外只暴露summary(),内部计算细节对使用者透明。
| 检查项 | 通过标准 |
|---|---|
| 列表推导 | generate_position_ids 不使用 list(range()) |
| 字典安全 | merge_model_config 不修改 base_config 原对象 |
| 频率统计 | 使用 Counter,结果按频次降序 |
| 类封装 | summary() 返回 str,含千分位格式化 |
| 代码风格 | 有 type hint + docstring,无魔法数字 |
5. build_causal_mask(seq_len)。目标 :掌握布尔掩码生成与广播,这是 Decoder-only 模型防止"未来信息泄露"的关键。einsum 和广播机制是理解 Attention、LayerNorm 等算子的"通用语言"。
python
import numpy as np
def build_causal_mask(seq_len: int) -> np.ndarray:
"""
生成因果掩码 (Causal Mask)。
True/1 表示允许 attend,False/0 表示需要被 mask 掉。
Shape: (seq_len, seq_len)
示例 (seq_len=3):
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
# np.tril 直接生成下三角矩阵,简洁且高效
mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_))
return mask
- 为什么用
bool_而非float? 布尔掩码在后续与 attention scores 结合时,可通过np.where(mask, scores, -np.inf)安全地将未来位置设为负无穷。若直接用 float 0/1 做乘法,softmax 后 0 位置仍会有非零概率,导致信息泄露。- 广播预备 :实际 Attention 中 mask 需要广播到
(batch, num_heads, seq_len, seq_len),此处返回 2D 形状正是为了后续自动广播对齐。
6. matmul_einsum(a, b)。目标:用 einsum 表达最基础的矩阵乘法,建立"下标即语义"的思维。
python
def matmul_einsum(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""
用 einsum 实现标准矩阵乘法 C = A @ B
Args:
a: shape (M, K)
b: shape (K, N)
Returns:
shape (M, N)
"""
# i: M维度, k: 收缩维度(求和), j: N维度
return np.einsum('ik,kj->ij', a, b)
Einsum 阅读口诀:
- 逗号左边:输入张量的轴标签
- 箭头右边:输出张量保留的轴标签
- 未出现在右边的标签 (如
k):沿该轴求和(收缩)'ik,kj->ij'等价于 Cij=∑kAik⋅BkjC_{ij}=∑kA{ik}⋅B_{kj}Cij=∑kAik⋅Bkj ,这正是矩阵乘法的数学定义。- 调试技巧:写 einsum 前先在纸上画出每个张量的 shape 并标注轴名,确认收缩维度大小一致后再写代码。
7. batch_attention_scores(q, k)。目标:用 einsum 表达带 batch 维度的多维张量运算,这是 Multi-Head Attention 的核心步骤。
python
def batch_attention_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
"""
计算 batch 下的 Attention 分数 (未缩放、未 mask)。
Args:
q: shape (batch, seq_len_q, d_model)
k: shape (batch, seq_len_k, d_model)
Returns:
scores: shape (batch, seq_len_q, seq_len_k)
"""
# b: batch, i: seq_q, j: seq_k, d: d_model (收缩)
scores = np.einsum('bid,bjd->bij', q, k)
return scores
b同时出现在 q 和 k 中:表示 batch 维度是对齐的,einsum 会逐 batch 独立计算,不会跨 batch 混合。d只出现在左边:表示沿 feature 维度做点积求和,这正是 attention score 的语义。- 为什么不先 reshape 再 matmul? 虽然
q @ k.transpose(0,2,1)也能实现,但 einsum 的优势在于公式即文档 ------看到'bid,bjd->bij'就能立刻理解运算逻辑,无需脑内模拟 transpose 和维度对齐。这在推导更复杂的 GQA、RoPE 时优势巨大。- 实际工程中 :还需除以
sqrt(d_model)并应用 causal mask,本题仅聚焦 einsum 表达。
8. rms_normalize(x, eps)。目标 :用 NumPy 实现 RMSNorm,掌握轴向归约 + 广播回扩这一 LLM 中最频繁的张量操作模式。
python
def rms_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
"""
Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)
Args:
x: shape (..., d_model),支持任意前导维度
eps: 数值稳定性小量
Returns:
归一化后的张量,shape 与 x 相同
"""
# 1. 沿最后一个轴计算均方根 keepdims=True 是关键!保持维度以便广播
rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
# 2. 广播除法:(..., 1) 自动广播到 (..., d_model)
return x / rms
为什么
keepdims=True至关重要?
写法 mean 输出 shape 能否广播回 x axis=-1(无 keepdims)(...)维度丢失,需手动 reshape axis=-1, keepdims=True(..., 1)自动广播,代码简洁
- RMSNorm vs LayerNorm:RMSNorm 不减均值、不学 affine 参数(或单独乘 weight),计算更快,是 LLaMA/Qwen 等主流模型的标配。
eps的位置 :放在sqrt内部而非外部,避免除零的同时保证梯度稳定。...通配符:函数签名支持任意前导维度(batch、seq_len、num_heads 等),这正是 LLM 工程中归一化函数的通用设计模式。
| 检查项 | 通过标准 |
|---|---|
| Causal Mask | 使用 bool 类型,下三角为 True |
| Einsum 基础 | 'ik,kj->ij' 结果与 a @ b 完全一致 |
| Batch Attention | 'bid,bjd->bij',batch 维度不参与收缩 |
| RMSNorm | keepdims=True,支持任意前导维度 |
| Shape 意识 | 每个函数都有明确的 shape 注释 |
完成实现后,强烈建议用以下代码做数值正确性校验:
python
# 验证 einsum matmul 与原生 matmul 一致性
a = np.random.randn(4, 8)
b = np.random.randn(8, 6)
assert np.allclose(matmul_einsum(a, b), a @ b)
# 验证 RMSNorm 输出尺度
x = np.random.randn(2, 10, 64) * 10
y = rms_normalize(x)
# RMSNorm 后每个样本的 RMS 应接近 1
rms_after = np.sqrt(np.mean(y**2, axis=-1))
assert np.allclose(rms_after, 1.0, atol=1e-4)
掌握这四个函数后,将具备直接阅读和手写 Multi-Head Attention、SwiGLU、RoPE 等 Chapter 2+ 核心算子的数学表达能力。记住:先标 shape,再写下标,最后验数值。
Tensor vs NumPy Array
| 特性 | NumPy Array | PyTorch Tensor |
|---|---|---|
| 设备 | 仅 CPU | CPU 或 GPU |
| 自动求导 | ❌ | ✅ |
| 深度学习优化 | ❌ | ✅ |
| 互操作性 | - | 可与 NumPy 互转 |
创建 Causal Mask
-
题目提示中写的是
torch.triu()(上三角),但根据文档示例和自回归生成的语义(位置 i 只能看到 ≤i 的位置),我们需要的是下三角矩阵 。使用triu会导致模型只看未来不看过去,完全违背因果律。
python
import torch
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""
创建 Causal Mask (下三角布尔矩阵)
Args:
seq_len: 序列长度
Returns:
mask: shape (seq_len, seq_len), dtype=torch.bool
mask[i, j] = True 表示位置 i 可以 attend 到位置 j
"""
# 使用 tril (下三角) 而非 triu
# dtype=bool 节省显存,且与后续 masked_fill / attention_mask 兼容
mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
return mask
# 验证
mask = create_causal_mask(4)
print(mask)
# tensor([[ True, False, False, False],
# [ True, True, False, False],
# [ True, True, True, False],
# [ True, True, True, True]])
-
工程实践要点:
- 为什么用
bool而非float? 布尔 mask 在 GPU 上占用显存仅为 float32 的 1/4。在后续 Attention 计算中,通过scores.masked_fill(~mask, float('-inf'))即可安全屏蔽未来位置。若用 float 0/1 做乘法,softmax 后被 mask 的位置仍会有非零概率。 - 广播友好性 :返回 2D shape
(seq_len, seq_len)而非 4D,PyTorch 会自动广播到(batch, num_heads, seq_len, seq_len),避免不必要的内存分配。
- 为什么用
多头注意力的维度变换
这是 Multi-Head Attention 中最容易出错的步骤,核心是 reshape + permute 的顺序不能反。
python
def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
将 (batch, seq, hidden) 拆分为 (batch, num_heads, seq, head_dim)
Args:
x: shape (batch, seq, hidden)
num_heads: 注意力头数,hidden 必须能被其整除
Returns:
shape (batch, num_heads, seq, head_dim)
"""
batch, seq, hidden = x.shape
assert hidden % num_heads == 0, f"hidden({hidden}) must be divisible by num_heads({num_heads})"
head_dim = hidden // num_heads
# Step 1: reshape 拆分 hidden 维度 → (batch, seq, num_heads, head_dim)
# Step 2: permute 将 num_heads 移到 seq 前面 → (batch, num_heads, seq, head_dim)
x = x.reshape(batch, seq, num_heads, head_dim)
x = x.permute(0, 2, 1, 3)
return x
# 验证
x = torch.randn(2, 10, 512)
x_split = split_heads(x, num_heads=8)
print(x_split.shape) # torch.Size([2, 8, 10, 64])
- 为什么不能先 permute 再 reshape? 如果先
permute(0,2,1)得到(batch, hidden, seq),再 reshape 会把 hidden 和 seq 的元素混在一起,导致语义错乱(不同 token 的特征被拼接到了同一个 head 里)。必须先在原始维度上拆分 hidden,再移动维度顺序。reshapevsview:此处推荐用reshape。虽然输入通常是连续的,但在某些 pipeline 中 x 可能经过 transpose 变得不连续,view会直接报错,而reshape会自动处理(必要时拷贝),代码更健壮。- 逆操作 :Attention 输出后需要合并 heads,逆操作为
x.permute(0, 2, 1, 3).reshape(batch, seq, hidden),顺序同样不能反。
批量矩阵乘法
用于计算 Attention Scores ( QKTQK^TQKT) 的核心算子。
python
def batch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
批量矩阵乘法: C[b,i,j] = Σ_k A[b,i,k] * B[b,k,j]
Args:
a: shape (batch, n, m)
b: shape (batch, m, p)
Returns:
shape (batch, n, p)
"""
# 推荐使用 @ 运算符(等价于 torch.bmm,但更简洁且支持广播)
return a @ b
# 验证
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
c = batch_matmul(a, b)
print(c.shape) # torch.Size([2, 3, 5])
@vstorch.bmmvstorch.matmul选型指南:
操作符 严格 3D 支持广播 可读性 推荐场景 torch.bmm✅ 仅 3D ❌ 低 旧代码兼容 @/torch.matmul✅ ✅ 高 现代 PyTorch 首选 torch.einsum✅ ✅ 中 复杂收缩/非标准乘法
- 为什么优先用
@? 它底层调用 cuBLAS 的cublasSgemmBatched,性能与bmm完全一致,但支持广播(例如a: (batch, n, m)@b: (1, m, p)自动广播),在 GQA、KV Cache 等场景中无需手动 expand。- 数值精度注意 :在 FP16/BF16 混合精度训练中,
@默认使用 Tensor Core,可能有微小数值差异。如需严格复现 FP32 结果,可用torch.matmul(a.float(), b.float()).half()。
| 检查项 | 通过标准 | 常见错误 |
|---|---|---|
| Causal Mask | tril + bool dtype |
误用 triu;用 float 类型 |
| split_heads | 先 reshape 再 permute | 顺序颠倒导致语义错乱 |
| batch_matmul | 使用 @,shape 正确 |
手动 for 循环逐 batch 乘 |
| 内存连续性 | 理解 view 失败原因 | 对非连续 tensor 强行 view |
| 设备安全 | .to(device) 而非硬编码 cuda |
在无 GPU 环境崩溃 |
每次写完张量操作,养成三行验证法的习惯:
python
# 1. Shape 断言
assert x_split.shape == (2, 8, 10, 64), f"Expected (2,8,10,64), got {x_split.shape}"
# 2. 数值 sanity check(如 mask 对角线应全为 True)
assert mask.diagonal().all(), "Causal mask diagonal must be all True"
# 3. 梯度连通性(确保没有意外 detach)
x_test = torch.randn(2, 10, 512, requires_grad=True)
y = split_heads(x_test, 8)
loss = y.sum()
loss.backward()
assert x_test.grad is not None, "Gradient disconnected!"
- 掌握这三个练习后,具备了手写完整 Multi-Head Attention 模块的张量操作基础。建议接下来结合 Chapter 01 的 einsum 知识,尝试用纯 PyTorch 实现一个完整的 Decoder-only Attention 层。
PyTorch 的 Autograd 机制是 LLM 训练的引擎。理解它不仅能写出正确的训练循环,还能在遇到梯度爆炸、显存溢出或自定义算子时快速定位问题。
实现 Softmax 的自定义 autograd
Softmax 是 LLM 输出层和 Attention 的核心。数值稳定性 和反向传播的高效实现是本题的两个关键考点。
python
import torch
class MySoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
"""
数值稳定的 Softmax 前向传播
Args:
input: shape (batch, num_classes)
Returns:
output: shape (batch, num_classes), 每行和为 1
"""
# 减去最大值防止 exp() 溢出(数值稳定性核心)
x_max = input.max(dim=1, keepdim=True).values
exp_x = torch.exp(input - x_max)
output = exp_x / exp_x.sum(dim=1, keepdim=True)
# 保存 softmax 输出而非原始输入,反向传播可直接复用
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""
Softmax 反向传播的高效向量形式
Args:
grad_output: dL/dy, shape (batch, num_classes)
Returns:
grad_input: dL/dx, shape (batch, num_classes)
"""
output, = ctx.saved_tensors
# 向量化公式: dx_i = y_i * (dy_i - Σ_j(dy_j * y_j))
# 避免 O(n²) 的 Jacobian 矩阵计算
sum_dy_y = (grad_output * output).sum(dim=1, keepdim=True)
grad_input = output * (grad_output - sum_dy_y)
return grad_input
# 验证
softmax = MySoftmax.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = softmax(x)
# 1. 前向正确性:每行和为 1
print("Row sums:", y.sum(dim=1)) # tensor([1., 1.], dtype=torch.float64)
# 2. 反向正确性:使用 gradcheck 自动验证
from torch.autograd import gradcheck
assert gradcheck(MySoftmax.apply, x, eps=1e-6), "Gradient check failed!"
print("Gradient check passed!")
- 为什么 save
output而非input? Softmax 的梯度公式 ∂L∂xi=yi(∂L∂yi−∑j∂L∂yjyj)\frac{∂L}{∂xi}=yi(\frac{∂L}{∂yi}−∑_j\frac{∂L}{∂yj}yj)∂xi∂L=yi(∂yi∂L−∑j∂yj∂Lyj) 只依赖 y (softmax 输出),不依赖原始 x 。保存 y 避免了反向传播时重新计算 softmax,节省算力。- 为什么不用 Jacobian 矩阵? Softmax 的完整 Jacobian 是 n×n 矩阵,直接计算复杂度为 O(n2)。上述向量化公式将复杂度降至 O(n) ,这在 vocab_size=128K 的 LLM 中差距巨大。
dtype=torch.double:gradcheck要求双精度浮点,因为数值梯度对精度极其敏感。实际训练中用 float32/bfloat16 即可。
梯度累积模拟大 Batch
LLM 训练中显存往往不够放下目标 batch size,梯度累积是在不增加显存的前提下等效增大 batch size 的标准技术。
python
def train_with_gradient_accumulation(model, data_loader, optimizer,
accumulation_steps: int = 4):
"""
梯度累积训练循环
Args:
model: PyTorch 模型
data_loader: 数据加载器
optimizer: 优化器
accumulation_steps: 累积步数,等效 batch = micro_batch × accumulation_steps
"""
model.train()
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(data_loader):
# 1. 前向传播
outputs = model(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
# 2. 损失缩放:保证累积后的梯度与真实大 batch 等价
scaled_loss = loss / accumulation_steps
# 3. 反向传播(梯度自动累加到 .grad)
scaled_loss.backward()
# 4. 每 accumulation_steps 步更新一次参数
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# 处理末尾不足 accumulation_steps 的残余梯度
if len(data_loader) % accumulation_steps != 0:
optimizer.step()
optimizer.zero_grad()
- 为什么要除以
accumulation_steps? 真实大 batch 的 loss 是所有样本 loss 的平均值。如果不除,累积后的梯度会是真实值的 N 倍,等效于学习率放大了 N 倍,导致训练发散。zero_grad()的位置 :必须在optimizer.step()之后清零,而非之前。如果在 step 之前清零,刚累积的梯度就被丢弃了。- 残余批次处理 :当 dataloader 长度不能被
accumulation_steps整除时,最后几步的梯度不会被自动更新。生产代码中必须处理这个边界情况,否则最后一个不完整 batch 的训练信号丢失。- 与 AMP 配合 :若使用混合精度,
scaler.scale(scaled_loss).backward()中的 scale 因子会与/accumulation_steps叠加,需注意顺序。
实现梯度裁剪
LLM 训练(尤其是 RNN/Transformer 早期阶段)极易梯度爆炸。梯度裁剪是保障训练稳定性的安全阀。
python
def clip_gradients(model: torch.nn.Module, max_norm: float = 1.0) -> float:
"""
按全局范数裁剪梯度,返回裁剪前的总梯度范数(用于监控)
Args:
model: 模型
max_norm: 最大允许梯度范数
Returns:
total_norm: 裁剪前的梯度总范数
"""
# 直接使用 PyTorch 官方实现,底层 C++ 优化,支持分布式
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
return total_norm.item()
# --- 手动实现版本(仅供理解原理)---
def clip_gradients_manual(model: torch.nn.Module, max_norm: float = 1.0) -> float:
"""手动实现全局梯度裁剪"""
# 1. 收集所有有梯度的参数
grads = [p.grad for p in model.parameters() if p.grad is not None]
if not grads:
return 0.0
# 2. 计算全局 L2 范数(不是逐参数裁剪!)
total_norm = torch.sqrt(sum(g.data.norm(2).item() ** 2 for g in grads))
# 3. 仅当超限时按比例缩放
if total_norm > max_norm:
clip_coef = max_norm / (total_norm + 1e-6)
for g in grads:
g.data.mul_(clip_coef)
return total_norm
关键区分:全局裁剪 vs 逐参数裁剪
方法 API 语义 LLM 推荐 全局范数裁剪 clip_grad_norm_将所有参数梯度视为一个向量,统一缩放 ✅ 标准做法 逐值裁剪 clip_grad_value_每个梯度分量独立 clamp 到 -v, v ❌ 破坏梯度方向 逐参数范数裁剪 手动循环 norm 每个参数独立缩放 ❌ 失去全局比例关系
- 为什么返回
total_norm? 在 LLM 训练监控面板(如 WandB/TensorBoard)中记录梯度范数是诊断训练健康度的核心指标。范数突然飙升通常预示 loss spike 或数据异常。clip_grad_norm_的下划线 :表示原地操作(in-place),直接修改.grad,无额外内存分配。- 分布式训练注意 :在 FSDP/DDP 下,
clip_grad_norm_会自动跨 GPU 聚合梯度范数再裁剪,手动实现需要额外调用dist.all_reduce。
| 检查项 | 通过标准 | 常见错误 |
|---|---|---|
| 自定义 Function | gradcheck 通过;forward 数值稳定 |
忘记 save_for_backward;反向公式用 Jacobian |
| 梯度累积 | loss 除以 N;step 后 zero_grad;处理残余 | 不除 N 导致梯度放大 N 倍 |
| 梯度裁剪 | 用 clip_grad_norm_;记录 total_norm |
误用 clip_grad_value_ |
| 推理模式 | model.eval() + torch.no_grad() |
推理时仍追踪梯度导致 OOM |
| detach 使用 | 明确知道何时切断计算图 | 滥用 detach 导致梯度断裂 |
当梯度行为异常时,使用以下诊断手段:
python
# 1. 检测 NaN/Inf 梯度
for name, param in model.named_parameters():
if param.grad is not None and not torch.isfinite(param.grad).all():
print(f"Non-finite gradient in {name}")
# 2. 监控各层梯度范数分布
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_norm={param.grad.norm().item():.4f}")
# 3. 验证自定义 Function 的梯度(开发阶段必做)
# 始终用 double 类型测试
x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(MySoftmax.apply, x, eps=1e-6, atol=1e-4)
掌握这三个练习后,具备调试 LLM 训练过程中梯度相关问题的能力。建议下一步结合 Chapter 03 的 Tensor 操作,尝试手写一个带梯度检查的完整 Transformer Block 前向+反向传播。
nn.Module 是 PyTorch 构建所有神经网络的基石。在 LLM 工程中,从简单的 MLP 到复杂的 Transformer Block,都依赖于正确的模块封装、参数注册与组合。
SimpleLinear: 最小线性层封装 。目标 :理解 nn.Parameter 的注册机制以及 forward() 的纯粹性。
python
import torch
import torch.nn as nn
class SimpleLinear(nn.Module):
"""
手动实现的最小线性层 y = xW^T + b
用于理解 nn.Linear 底层原理及参数注册机制
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
# 必须用 nn.Parameter 包装,否则不会被 state_dict 收集
# 使用 Kaiming 初始化模拟真实线性层行为
self.weight = nn.Parameter(torch.empty(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features))
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# forward 只做纯张量运算,不包含优化器、loss 等训练逻辑
return x @ self.weight.t() + self.bias
- 为什么必须用
nn.Parameter? 直接赋值self.weight = torch.randn(...)只是一个普通属性,PyTorch 无法感知它。只有被nn.Parameter包装(或作为子nn.Module的属性),它才会被自动加入parameters()迭代器和state_dict()。forward()的纯粹原则 :forward只定义"数据如何流动"。绝不要在forward里写optimizer.step()、loss.backward()或设备转移.to(device)。这保证了模块的可复用性和可组合性。- 权重形状约定 :PyTorch 官方
nn.Linear的 weight shape 是(out, in),计算时用x @ W^T。遵循此约定便于后续与预训练权重对齐。
TwoLayerMLP : 模块组合与激活函数**。目标:掌握通过子模块注册构建复合网络,理解 LLM 中 FFN 的基础原型。
python
class TwoLayerMLP(nn.Module):
"""
两层 MLP: Linear -> ReLU -> Linear
LLM 中 Feed-Forward Network (FFN) 的最简原型
"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
# 子模块赋值给 self.xxx 即自动注册
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.act = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
- 子模块自动注册 :只要将
nn.Module实例赋值给self的属性(如self.fc1),PyTorch 就会递归地将其参数纳入管理。不需要 手动调用add_module()(除非动态构建模块列表)。- 激活函数也应是模块 :推荐
self.act = nn.ReLU()而非在 forward 中写F.relu()。前者使网络结构在print(model)时完整可见,且方便后续替换为 GELU/SwiGLU(LLM 标配)而无需修改 forward。- LLM 扩展提示 :现代 LLM 的 FFN 通常是 SwiGLU 变体,需要三个线性层(gate, up, down)。掌握了
TwoLayerMLP的组合模式后,扩展到三层只需增加一个self.gate_proj并在 forward 中加入逐元素乘法。
count_parameters(module) : 参数统计与过滤**。目标:区分可训练/不可训练参数,这是 LLM 微调(LoRA/冻结主干)时的必备监控工具。
python
def count_parameters(module: nn.Module, only_trainable: bool = True) -> int:
"""
统计模型参数数量
Args:
module: PyTorch 模块
only_trainable: True=仅统计 requires_grad=True 的参数
False=统计全部参数
Returns:
参数总数(标量整数)
"""
if only_trainable:
return sum(p.numel() for p in module.parameters() if p.requires_grad)
else:
return sum(p.numel() for p in module.parameters())
# 使用示例
model = TwoLayerMLP(768, 2048, 768)
total = count_parameters(model, only_trainable=False)
trainable = count_parameters(model, only_trainable=True)
print(f"Total: {total:,} | Trainable: {trainable:,}")
parameters()vsnamed_parameters():parameters()返回参数张量迭代器,适合计数和优化器传入;named_parameters()返回(name, tensor)元组,适合按名称过滤(如 LoRA 只解冻lora_前缀参数)。numel()vssize():numel()返回标量元素总数,size()返回形状元组。计数必须用numel()。- LLM 微调场景 :在 Full Fine-tuning 中
only_trainable=True等于总参数;但在 LoRA/QLoRA 中,可训练参数可能仅占总参数的 0.1%~1%。打印两者对比是验证冻结策略是否生效的第一步调试手段。- 格式化输出 :使用
{:,}千分位分隔符。LLM 参数量动辄数十亿,1,234,567,890比1234567890可读性高一个数量级。
| 检查项 | 通过标准 | 常见错误 |
|---|---|---|
| 参数注册 | 所有权重用 nn.Parameter 或子模块 |
直接用 torch.Tensor 赋值 |
| forward 纯粹性 | 仅含张量运算,无训练逻辑 | 在 forward 中调用 backward/step |
| 模块组合 | 子模块通过 self.xxx 注册 |
在 forward 中临时创建模块 |
| state_dict | model.state_dict() 包含所有参数 |
参数未注册导致保存缺失 |
| 参数统计 | 区分 trainable/total,用 numel() |
用 len(parameters()) 误算为层数 |
确保 Module 是正确的,完成实现后,务必执行以下三项验证:
python
model = TwoLayerMLP(768, 2048, 768)
# 1. 结构可视化:确认子模块和激活函数都正确注册
print(model)
# 2. state_dict 完整性:确认所有参数都可序列化
sd = model.state_dict()
assert 'fc1.weight' in sd and 'fc2.bias' in sd, "Missing keys in state_dict!"
# 3. 梯度连通性:确认前向→反向链路无断裂
x = torch.randn(2, 10, 768)
loss = model(x).sum()
loss.backward()
for name, p in model.named_parameters():
assert p.grad is not None, f"Gradient disconnected at {name}"
掌握这三个练习后,具备了构建任意复杂度 LLM 组件(Attention、FFN、Embedding、RMSNorm)的模块化能力。下一步建议将这些基础模块组合成一个完整的 Transformer Decoder Layer,并验证其 state_dict 能与 HuggingFace Transformers 的键名对齐。
优化器与损失函数是连接"模型输出"与"参数更新"的桥梁。在 LLM 训练中,CrossEntropy 是绝对核心,而 AdamW 则是事实标准。
mse_loss(pred, target) : 回归损失基础**。目标:理解均方误差的数学形式及其梯度特性,这是所有损失函数的起点。
python
import torch
import torch.nn.functional as F
def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
手写均方误差 (Mean Squared Error)
Args:
pred: 预测值, shape (*)
target: 真实值, shape (*),必须与 pred 形状一致
Returns:
标量 loss
"""
# 等价于 F.mse_loss(pred, target, reduction='mean')
# 先求差 → 平方 → 取均值
return ((pred - target) ** 2).mean()
- 为什么用
.mean()而非.sum()? MSE 的梯度为 2N(pred−target)\frac{2}{N}(pred−target)N2(pred−target) 。使用 mean 使梯度大小与 batch size 无关,切换 batch size 时无需调整学习率。若用 sum,梯度会随 N 线性增长,导致训练不稳定。- LLM 中的 MSE 场景:虽然语言建模主用 CrossEntropy,但 MSE 仍用于:① Value Head(RLHF/PPO);② Knowledge Distillation 中的 logits 对齐;③ Embedding 模型的对比学习辅助损失。
- 数值安全 :
(pred - target) ** 2在 FP16/BF16 下通常安全,但若差值极大可能溢出。生产代码可考虑F.mse_loss,其内部有额外的数值保护。
cross_entropy_loss(logits, target) : LLM 的核心损失**。目标:掌握分类交叉熵的正确调用方式,理解 logits vs probabilities 的关键区别。
python
def cross_entropy_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
分类交叉熵损失
Args:
logits: 原始模型输出 (未过 softmax!), shape (batch, num_classes)
target: 类别索引, shape (batch,), dtype=torch.long
或 soft labels, shape (batch, num_classes), dtype=torch.float
Returns:
标量 loss
"""
# F.cross_entropy 内部自动执行 log_softmax,数值更稳定
# 切勿先手动 softmax 再传入!会导致双重 softmax + 数值不稳定
return F.cross_entropy(logits, target)
# 验证示例
logits = torch.randn(4, 10) # batch=4, vocab=10
target = torch.randint(0, 10, (4,)) # 硬标签
loss = cross_entropy_loss(logits, target)
print(f"CE Loss: {loss.item():.4f}")
做法 正确性 原因 F.cross_entropy(logits, target)✅ 内部 fused log_softmax,数值稳定且快 F.nll_loss(F.log_softmax(logits), target)✅ 等价写法,略慢 F.cross_entropy(F.softmax(logits), target)❌ 双重 softmax,梯度错误 -torch.log(softmax(logits)[..., target])❌ 手动实现易溢出,不支持 label smoothing
- Logits ≠ Probabilities :LLM 输出的最后一层线性层结果叫 logits,范围 (−∞,+∞)(−∞,+∞) 。永远不要对 logits 做 softmax 后再传给
F.cross_entropy。- Label Smoothing :LLM 训练几乎必用 label smoothing(通常 0.1)。
F.cross_entropy(logits, target, label_smoothing=0.1)一行搞定,防止模型过度自信,提升泛化。- Ignore Index :处理 padding token 时用
ignore_index=-100,被忽略位置的 loss 和梯度均为 0,避免无效 token 污染训练信号。
train_one_step(...): 最小训练步模板 。目标:将前向、反向、更新串联为原子操作,建立正确的梯度管理肌肉记忆。
python
def train_one_step(
model: torch.nn.Module,
x: torch.Tensor,
target: torch.Tensor,
optimizer: torch.optim.Optimizer
) -> float:
"""
最小完整训练步
Args:
model: 模型(已处于 train() 模式)
x: 输入数据
target: 标签
optimizer: 优化器
Returns:
本步 loss 值(Python float,脱离计算图)
"""
# Step 0: 清零梯度(必须在 backward 之前!)
optimizer.zero_grad()
# Step 1: 前向传播
logits = model(x)
loss = cross_entropy_loss(logits, target)
# Step 2: 反向传播(梯度自动累积到 .grad)
loss.backward()
# Step 3: 参数更新
optimizer.step()
# 返回 .item() 脱离计算图,防止内存泄漏
return loss.item()
zero_grad → forward → backward → step 正确 forward → backward → zero_grad → step 梯度被清空,step 无效 forward → backward → step → zero_grad 也正确(PyTorch 官方推荐)PyTorch 官方文档现在推荐 step → zero_grad 的顺序,因为某些优化器(如 AdamW)在 step 时会利用当前梯度做 weight decay,先 zero_grad 不影响 step,但能让下一次 forward 开始时梯度缓冲区干净。两种顺序都正确,但绝不能把 zero_grad 放在 backward 和 step 之间。
loss.item()的重要性 :直接返回loss张量会保留整个计算图的引用,导致显存持续增长直至 OOM。记录 loss 曲线时必须 用.item()。- LLM 训练增强版 :实际 LLM 训练循环还需加入:①
scaler.scale(loss).backward()(AMP);②clip_grad_norm_(梯度裁剪);③scheduler.step()(学习率调度)。本题是最小骨架,这些组件在此基础上叠加。
| 检查项 | 通过标准 | 常见错误 |
|---|---|---|
| MSE | 用 .mean(),梯度与 BS 无关 |
用 .sum() 导致梯度爆炸 |
| CE | 传入 raw logits,不手动 softmax | 先 softmax 再传 CE |
| 训练步 | zero_grad 不在 backward/step 之间 | 忘记 .item() 导致 OOM |
| 设备一致性 | model/x/target/optimizer 同设备 | CPU/GPU 混用报错 |
| Loss 下降 | 小样本上 10 步内 loss 明显降低 | 学习率过大/过小或 bug |
**必做验证:Loss 是否真的在下降?**完成三个函数后,必须用以下代码验证端到端训练有效性:
python
# 构造一个可被完美拟合的小数据集
torch.manual_seed(42)
model = torch.nn.Linear(8, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
x = torch.randn(16, 8)
target = torch.randint(0, 5, (16,))
losses = []
for step in range(50):
loss_val = train_one_step(model, x, target, optimizer)
losses.append(loss_val)
# 验证:最终 loss 应显著低于初始 loss
assert losses[-1] < losses[0] * 0.5, f"Loss not decreasing! {losses[0]:.4f} → {losses[-1]:.4f}"
print(f"Training verified: {losses[0]:.4f} → {losses[-1]:.4f}")
如果 loss 不下降,按此顺序排查:① 学习率是否合理(AdamW 通常 1e-4 ~ 3e-4);② target dtype 是否为 long;③ 模型是否有 requires_grad=True 的参数;④ 是否意外在 no_grad() 上下文中训练。
掌握这三个练习后,已具备 LLM 训练循环的最小可行基础。下一步建议在此基础上叠加 AMP + 梯度累积 + LR Scheduler,构建完整的预训练/微调 pipeline。
一个生产级的训练循环不仅仅是 for 循环的嵌套,它是状态管理、资源调度和可观测性的综合体。在 LLM 训练中,一次中断可能意味着数百万美元的算力浪费,因此"断点续训"、"指标监控"和"防过拟合"是工程底线。
完整训练流程:三合一工程模板
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import os
# ==================== 1. 早停机制 ====================
class EarlyStopping:
"""验证指标停滞时自动终止训练"""
def __init__(self, patience=5, min_delta=0.0, mode='min'):
self.patience = patience
self.min_delta = min_delta
self.mode = mode
self.counter = 0
self.best_score = None
self.early_stop = False
def __call__(self, score):
if self.best_score is None:
self.best_score = score
return True
improved = (score < self.best_score - self.min_delta) if self.mode == 'min' \
else (score > self.best_score + self.min_delta)
if improved:
self.best_score = score
self.counter = 0
return True
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
return False
# ==================== 2. 核心训练/验证函数 ====================
def train_one_epoch(model, loader, criterion, optimizer, device, writer=None, epoch=None):
model.train()
total_loss, correct, total = 0.0, 0, 0
pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]", leave=False)
for inputs, targets in pbar:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})
avg_loss = total_loss / len(loader)
accuracy = 100.0 * correct / total
if writer and epoch is not None:
writer.add_scalar('Loss/train', avg_loss, epoch)
writer.add_scalar('Acc/train', accuracy, epoch)
return avg_loss, accuracy
@torch.no_grad()
def validate(model, loader, criterion, device, writer=None, epoch=None):
model.eval()
total_loss, correct, total = 0.0, 0, 0
for inputs, targets in tqdm(loader, desc="Validating", leave=False):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
total_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
avg_loss = total_loss / len(loader)
accuracy = 100.0 * correct / total
if writer and epoch is not None:
writer.add_scalar('Loss/val', avg_loss, epoch)
writer.add_scalar('Acc/val', accuracy, epoch)
return avg_loss, accuracy
# ==================== 3. 主训练入口 ====================
def main():
# --- 配置 ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_EPOCHS = 100
SAVE_DIR = './checkpoints'
os.makedirs(SAVE_DIR, exist_ok=True)
# --- 数据 & 模型 ---
train_loader, val_loader = create_dummy_dataset() # 使用题目提供的函数
model = SimpleClassifier().to(DEVICE) # 使用题目提供的模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# --- 调度器 & 早停 & 日志 ---
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
early_stopping = EarlyStopping(patience=7, mode='min')
writer = SummaryWriter('runs/exp_001')
best_val_loss = float('inf')
for epoch in range(NUM_EPOCHS):
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE, writer, epoch)
val_loss, val_acc = validate(model, val_loader, criterion, DEVICE, writer, epoch)
current_lr = optimizer.param_groups[0]['lr']
writer.add_scalar('LR', current_lr, epoch)
print(f"Epoch {epoch}: Train Loss={train_loss:.4f} Acc={train_acc:.2f}% | "
f"Val Loss={val_loss:.4f} Acc={val_acc:.2f}% | LR={current_lr:.2e}")
# 学习率调度(基于验证损失)
scheduler.step(val_loss)
# 保存最佳模型 + 完整 Checkpoint
is_best = early_stopping(val_loss)
if is_best:
best_val_loss = val_loss
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_val_loss': best_val_loss,
}
torch.save(checkpoint, os.path.join(SAVE_DIR, 'best_checkpoint.pt'))
print(f" 💾 Best model saved (val_loss={val_loss:.4f})")
# 早停检查
if early_stopping.early_stop:
print(f"\n Early stopping triggered at epoch {epoch}")
break
writer.close()
print(f"\n Training complete. Best Val Loss: {best_val_loss:.4f}")
核心组件深度解析与避坑指南
model.train() vs model.eval():不是装饰器,是行为开关**
| 层类型 | train()模式 |
eval()模式 |
忘记切换的后果 |
|---|---|---|---|
| Dropout | 随机置零,输出乘以 1/(1−p)1/(1−p) | 恒等映射 | 推理时输出幅度偏移,精度暴跌 |
| BatchNorm | 用当前 batch 统计量,更新 running stats | 用冻结的 running_mean/var | 推理结果随 batch size 变化,不可复现 |
| Linear/Conv | 无区别 | 无区别 | 无影响 |
致命陷阱 :
model.eval()不会 禁用梯度计算!必须配合with torch.no_grad():使用。仅用eval()做验证,显存仍会被计算图占满,导致 OOM。
Checkpoint 保存策略:为什么不能只存 state_dict**?**在生产环境中,永远保存完整 checkpoint 而非仅模型权重:
python
# 仅存权重:无法恢复训练,优化器动量丢失
torch.save(model.state_dict(), 'model.pt')
# 完整 checkpoint:支持断点续训,实验完全可复现
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), # Adam 的 m/v 状态
'scheduler_state_dict': scheduler.state_dict(), # LR 调度进度
'best_val_loss': best_val_loss,
'config': {'lr': 1e-3, 'batch_size': 32}, # 超参数快照
}
torch.save(checkpoint, 'checkpoint.pt')
- Optimizer State 的重要性 :Adam/AdamW 维护每个参数的一阶矩(m)和二阶矩(v)。如果只恢复模型权重而重置优化器,相当于从第 0 步重新开始自适应学习率调整,前几十步的训练效果会严重退化。
- 加载时的设备安全 :
torch.load(path, map_location=device)防止在无 GPU 机器上加载 CUDA checkpoint 时报错。
学习率调度器选择指南
| 调度器 | 适用场景 | LLM 推荐度 | 注意事项 |
|---|---|---|---|
StepLR |
传统 CV,固定衰减 | ⭐ | 过于粗糙,LLM 几乎不用 |
ReduceLROnPlateau |
通用微调,验证驱动 | ⭐⭐⭐ | 本题推荐;patience 通常 3-5 |
CosineAnnealingLR |
LLM 预训练标配 | ⭐⭐⭐⭐⭐ | 配合 warmup 使用 |
OneCycleLR |
快速收敛,超参搜索 | ⭐⭐ | 对峰值 LR 敏感 |
LLM 预训练补充 :实际 LLM 预训练使用 Warmup + Cosine Decay ,PyTorch 原生不支持 warmup,需用
transformers.get_cosine_schedule_with_warmup或自定义 LambdaLR。本题的微调场景用ReduceLROnPlateau是最稳妥的选择。
日志与可观测性
- tqdm :提供实时反馈,
set_postfix显示关键指标。注意在分布式训练中只在 rank 0 启用 tqdm,否则终端输出混乱。 - TensorBoard :记录标量(loss/acc/lr)、直方图(梯度/权重分布)、文本(生成样本)。务必记录学习率曲线,它是诊断训练异常的第一线索。
- WandB(进阶):LLM 项目事实标准,支持超参搜索、团队协作、artifact 版本管理。API 与 TensorBoard 类似,迁移成本低。
| 检查项 | 通过标准 | 常见错误 |
|---|---|---|
| 模式切换 | train/eval 成对出现,eval 配 no_grad | 验证时未切 eval,BN/Dropout 污染结果 |
| 梯度清零 | zero_grad() 在 backward() 前 |
放在 step 后也可,但绝不能夹在中间 |
| Checkpoint | 包含 optimizer + scheduler state | 只存 model weights,断点续训失效 |
| 早停 | 基于验证指标,有 patience 缓冲 | 基于训练 loss 早停(永远不触发或过早触发) |
| 设备安全 | 所有 tensor/model/scheduler 同设备 | CPU/GPU 混用;load 未指定 map_location |
| Loss 记录 | 使用 .item() 脱离计算图 |
直接累加 tensor 导致显存泄漏 |
从练习到生产的下一步
掌握此模板后,向 LLM 训练进阶需叠加以下组件:
- 混合精度 (AMP) :
torch.cuda.amp.autocast+GradScaler,FP16/BF16 训练提速 2-3x - 梯度累积:等效大 batch,见 Autograd 章节练习 2
- 分布式训练:DDP/FSDP,多卡/多机扩展
- Gradient Checkpointing:以计算换显存,支撑更长序列
- Flash Attention:IO-aware 的注意力实现,LLM 训练必备
此模板是所有这些高级技术的骨架。确保骨架正确,再逐步叠加血肉,是避免 LLM 训练中"玄学 bug"的最佳实践。
激活函数是神经网络中引入非线性的核心组件。在 LLM 时代,ReLU 已逐渐退出主干网络,GELU 和 SiLU(Swish)成为 Transformer 架构的标准配置。理解它们的数学本质与工程实现,是阅读和修改现代 LLM 代码的前提。
relu: 经典基线
python
import torch
import torch.nn.functional as F
import math
def relu(x: torch.Tensor) -> torch.Tensor:
"""
Rectified Linear Unit
公式: max(0, x)
"""
# 推荐:调用官方算子,底层 C++/CUDA 优化,支持 autograd
return F.relu(x)
# 不推荐手写:torch.maximum(x, torch.zeros_like(x))
# 原因:产生额外张量分配,autograd 图更复杂,性能差
- Shape 不变性:ReLU 是逐元素运算,输入输出 shape 完全一致。这是所有激活函数的基本契约。
- LLM 中的地位 :ReLU 在现代 LLM 主干中已被淘汰,原因是其在 x=0x =0 处不可导且负半轴梯度为零(dying ReLU 问题)。但在以下场景仍可见到:① 早期模型(GPT-1);② 某些 MLP 的 gate 分支变体;③ 作为教学基线。生产代码中优先使用 GELU/SiLU。
- 为什么不用手写?
F.relu在 CUDA 上是单个 fused kernel,而手写maximum会触发多个 kernel launch + 临时显存分配。在 LLM 训练中,激活函数被调用数十亿次,微小开销会被放大。
gelu_exact: LLM 的事实标准
python
def gelu_exact(x: torch.Tensor) -> torch.Tensor:
"""
Gaussian Error Linear Unit (精确版本)
公式: x * Φ(x),其中 Φ(x) 是标准正态分布的累积分布函数
注意:PyTorch F.gelu(x) 默认即为精确版本(approximate='none')
"""
# 精确 GELU:使用 erf 函数计算 CDF
# Φ(x) = 0.5 * (1 + erf(x / sqrt(2)))
cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
return x * cdf
# 等价官方调用(验证用):
# return F.gelu(x, approximate='none')
精确 vs 近似:必须区分的两个版本
版本 公式 PyTorch API 使用场景 Exact x⋅Φ(x) F.gelu(x, approximate='none')BERT, GPT-2, RoBERTa Tanh Approx 0.5x(1+tanh(2/π(x+0.044715x3)))0.5x(1+tanh(\sqrt{2/π}(x+0.044715x^3)))0.5x(1+tanh(2/π (x+0.044715x3))) F.gelu(x, approximate='tanh')GPT-J, LLaMA-1, StarCoder
- 为什么有近似版? 精确版需要计算
erf,在早期 GPU 上较慢。Tanh 近似仅用基础算术运算,速度快约 5-10%,且最大绝对误差 < 10−410−4 。- LLM 选型指南 :加载预训练权重时,必须匹配原始论文使用的版本 。用错版本会导致 logits 分布偏移,下游任务性能下降 1-3%。HuggingFace Transformers 的配置文件中
hidden_act字段明确指定了版本。- 数值安全 :
erf在 FP16 下对大值饱和为 ±1,不会溢出。BF16 下精度更好,推荐 LLM 训练使用 BF16。
silu: Swish 的现代命名
python
def silu(x: torch.Tensor) -> torch.Tensor:
"""
Sigmoid Linear Unit (aka Swish)
公式: x * σ(x) = x / (1 + exp(-x))
PyTorch 中 F.silu 即为此函数
"""
# 官方实现,内部 fused sigmoid + multiply
return F.silu(x)
# 手动等价实现(仅供理解,勿用于生产):
# return x * torch.sigmoid(x)
LLM 中的关键角色:
- SwiGLU 的核心 :现代 LLM(LLaMA-2/3, Qwen, Mistral, Gemma)的 FFN 几乎全部采用 SwiGLU 结构: SwiGLU(x)=SiLU(xWgate)⊗(xWup)SwiGLU(x)=SiLU(xWgate)⊗(xWup)SwiGLU(x)=SiLU(xWgate)⊗(xWup)。SiLU 作为 gate 激活函数,其平滑的非线性特性使梯度流更稳定,训练 loss 比 GELU 低 0.1-0.3。
- vs GELU 对比:SiLU 在负半轴有非零输出(类似 leaky),信息保留更多;GELU 负半轴衰减更快。实证表明 SiLU 在大模型规模下略优,但差距随 scale 增大而缩小。
- Fused Kernel :在实际 LLM 框架(如 Megatron-LM, vLLM)中,SiLU + element-wise multiply 通常被融合为单个 CUDA kernel(
silu_and_mul),避免中间张量的显存读写。写自定义算子时务必考虑这种融合机会。
activation_summary: 可复用的激活工厂
python
from typing import Callable, Dict
def activation_summary() -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
"""
返回可用激活函数的注册表
用于模型配置驱动的动态模块构建
"""
return {
"relu": relu,
"gelu": gelu_exact, # 精确版
"gelu_tanh": lambda x: F.gelu(x, approximate="tanh"), # 近似版
"silu": silu,
"swish": silu, # swish 是 silu 的别名
}
# 使用示例:配置驱动的 MLP 构建
def get_activation(name: str) -> Callable:
registry = activation_summary()
if name not in registry:
raise ValueError(f"Unknown activation '{name}'. Available: {list(registry.keys())}")
return registry[name]
# 在 nn.Module 中使用
class ConfigurableMLP(torch.nn.Module):
def __init__(self, d_in, d_hidden, d_out, act_name="silu"):
super().__init__()
self.fc1 = torch.nn.Linear(d_in, d_hidden)
self.act = get_activation(act_name) # 动态选择
self.fc2 = torch.nn.Linear(d_hidden, d_out)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
- 配置驱动 :LLM 实验中频繁切换激活函数做 ablation study。硬编码
F.gelu需要改源码;通过 registry + config YAML,无需触碰模型代码即可切换。- HuggingFace 对齐 :Transformers 库的
ACT2FN字典就是此模式的工业实现。自建 registry 时保持相同键名("gelu","silu","gelu_pytorch_tanh"),便于权重迁移。- 可扩展性:新增自定义激活(如 QuickGELU、ReGLU)只需在 registry 中添加一行,符合开闭原则。
每个自定义激活函数必须通过以下验证:
python
def test_activations():
"""确保手写实现与 PyTorch 官方数值一致"""
x = torch.randn(4, 8, dtype=torch.float64) # double 精度测试
# ReLU
assert torch.allclose(relu(x), F.relu(x), atol=1e-10)
# GELU exact
assert torch.allclose(gelu_exact(x), F.gelu(x, approximate='none'), atol=1e-10)
# SiLU
assert torch.allclose(silu(x), F.silu(x), atol=1e-10)
# Shape 守恒
for fn in [relu, gelu_exact, silu]:
assert fn(x).shape == x.shape, f"{fn.__name__} changed shape!"
# Autograd 连通性
for fn in [relu, gelu_exact, silu]:
x_grad = torch.randn(4, 8, requires_grad=True)
fn(x_grad).sum().backward()
assert x_grad.grad is not None, f"{fn.__name__} gradient disconnected!"
print("All activation tests passed!")
test_activations()
激活函数选型速查表
| 激活函数 | 公式 | LLM 代表模型 | 推荐度 | 备注 |
|---|---|---|---|---|
| SiLU/Swish | xσ(x) | LLaMA, Qwen, Mistral | ⭐⭐⭐⭐⭐ | 当前主流,SwiGLU 标配 |
| GELU (tanh) | tanh 近似 | GPT-J, LLaMA-1 | ⭐⭐⭐⭐ | 兼容性好,速度略快 |
| GELU (exact) | xΦ(x) | BERT, GPT-2 | ⭐⭐⭐ | 老模型权重加载必需 |
| ReLU | max(0,x) | GPT-1 | ⭐ | 仅作基线,新模型避免 |
| QuickGELU | xσ(1.702x) | OpenCLIP | ⭐⭐ | GELU 的快速近似 |
掌握这四个练习后,已具备阅读和修改任何 LLM FFN/Attention 代码中激活函数部分的能力。下一步建议结合 SwiGLU 结构,实现一个完整的 LLM-style FeedForward 模块,并用 activation_summary 使其支持配置化切换。
归一化技术是深度学习中稳定训练、加速收敛的基石。在 LLM 时代,BatchNorm 因序列长度动态变化和因果掩码兼容性问题已被淘汰,LayerNorm(及其变体 RMSNorm)成为 Transformer 的唯一标准。但理解 BatchNorm 的训练/推理双态机制,是掌握所有归一化技术的必经之路。
batch_norm_train: 训练态手写实现
python
import torch
import math
def batch_norm_train(
x: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
训练态 BatchNorm(沿 batch 维统计)
Args:
x: shape (N, C, *) --- N=batch, C=channels/features
gamma: scale 参数, shape (C,)
beta: shift 参数, shape (C,)
eps: 数值稳定性常数
Returns:
y: 归一化输出, shape 同 x
mean: 当前 batch 均值, shape (C,) --- 用于更新 running stats
var: 当前 batch 方差, shape (C,) --- 用于更新 running stats
"""
# 沿 batch 维度(dim=0)计算统计量,保留 channel 维度
# keepdim=True 保证广播正确性
reduce_dims = [0] + list(range(2, x.dim())) # (0, 2, 3, ...) for 4D input
mean = x.mean(dim=reduce_dims, keepdim=True)
var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
# 归一化 + 仿射变换
x_hat = (x - mean) / torch.sqrt(var + eps)
y = gamma * x_hat + beta
return y, mean.squeeze(), var.squeeze()
unbiased=False:BatchNorm 使用总体方差 (除以 N),而非样本方差(除以 N-1)。这与 PyTorchnn.BatchNorm默认行为一致。用错会导致推理时数值偏移。- 返回值设计 :训练态必须返回当前 batch 的 mean/var,供
update_running_stats使用。这是训练/推理分离的核心接口。- Shape 约定 :BatchNorm 对
(N, C, H, W)和(N, C)都适用,reduce_dims的动态计算使其通用。
batch_norm_eval: 推理态手写实现
python
def batch_norm_eval(
x: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
running_mean: torch.Tensor,
running_var: torch.Tensor,
eps: float = 1e-5
) -> torch.Tensor:
"""
推理态 BatchNorm(使用累积统计量)
Args:
x: shape (N, C, *)
running_mean: 训练期间累积的均值, shape (C,)
running_var: 训练期间累积的方差, shape (C,)
"""
# 推理时完全不依赖当前 batch 的统计量
# reshape 为 (1, C, 1, ...) 以支持广播
shape = [1, -1] + [1] * (x.dim() - 2)
rm = running_mean.view(shape)
rv = running_var.view(shape)
x_hat = (x - rm) / torch.sqrt(rv + eps)
return gamma * x_hat + beta
为什么推理不能用当前 batch 统计?
场景 使用当前 batch stats 使用 running stats 单样本推理 方差=0,除零错误 正常 小 batch 推理 统计量噪声大,结果不稳定 稳定 批量推理 结果随 batch 内容变化 确定性输出
- 确定性要求:生产环境中,相同输入必须产生相同输出。若用当前 batch stats,同一张图片在不同 batch 中会得到不同结果,这在部署时是不可接受的 bug。
model.eval()的本质 :就是切换到此函数路径。忘记调用eval()是 LLM 推理性能下降的最常见原因之一。
layer_norm_last_dim: Transformer 的标准归一化
python
def layer_norm_last_dim(
x: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
eps: float = 1e-5
) -> torch.Tensor:
"""
LayerNorm 沿最后一个维度归一化
等价于 nn.LayerNorm(normalized_shape=x.shape[-1])
Args:
x: shape (*, D) --- D 是特征维度
gamma, beta: shape (D,)
"""
# 仅沿最后一维统计,每个样本独立归一化
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)
return gamma * x_hat + beta
为什么 LLM 只用 LayerNorm?
序列长度无关:LayerNorm 对每个 token 独立归一化,不依赖 batch 内其他样本。变长序列、padding、因果掩码都不影响统计量。
训练/推理一致:无 running stats,无需区分 train/eval 模式。这简化了分布式训练和推理部署。
Pre-Norm vs Post-Norm :现代 LLM 几乎全部采用 Pre-Norm(归一化在 Attention/FFN 之前)。Post-Norm(原始 Transformer)梯度流不稳定,深层网络难以训练。
Pre-Norm: x → LN → Sublayer → Add(x, ·) LLaMA/GPT-3/Qwen
Post-Norm: x → Sublayer → Add(x, ·) → LN 仅原始 Transformer
update_running_stats: 指数移动平均
python
def update_running_stats(
running_mean: torch.Tensor,
running_var: torch.Tensor,
batch_mean: torch.Tensor,
batch_var: torch.Tensor,
momentum: float = 0.1
) -> tuple[torch.Tensor, torch.Tensor]:
"""
指数移动平均更新累积统计量
Args:
momentum: EMA 系数,PyTorch 默认 0.1
new = (1-momentum)*old + momentum*batch
"""
# 原地更新避免额外内存分配(生产代码推荐 inplace)
new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
new_var = (1.0 - momentum) * running_var + momentum * batch_var
return new_mean, new_var
- momentum 语义反转 :PyTorch 的
momentum=0.1表示新值权重为 0.1(即旧值权重 0.9)。这与优化器中 momentum=0.9 的含义相反。混用会导致 running stats 要么震荡不收敛,要么更新过慢。- 初始化 :
running_mean初始化为 0,running_var初始化为 1。错误的初始值(如 var=0)会导致前几个 batch 的推理结果异常。- LLM 中的替代方案:由于 LLM 不使用 BatchNorm,此函数主要用于理解原理。实际 LLM 训练中,RMSNorm 完全不需要 running stats,这也是其被广泛采用的原因之一------消除了状态同步开销,对分布式训练更友好。
python
def test_normalizations():
torch.manual_seed(42)
N, C, H, W = 8, 16, 4, 4
# === BatchNorm 测试 ===
x_bn = torch.randn(N, C, H, W, dtype=torch.float64)
gamma_bn = torch.ones(C, dtype=torch.float64)
beta_bn = torch.zeros(C, dtype=torch.float64)
# 训练态对齐
y_custom, bm, bv = batch_norm_train(x_bn, gamma_bn, beta_bn)
bn_layer = torch.nn.BatchNorm2d(C).double()
bn_layer.train()
y_official = bn_layer(x_bn)
assert torch.allclose(y_custom, y_official, atol=1e-8), "BN train mismatch!"
# 推理态对齐
bn_layer.eval()
y_eval_official = bn_layer(x_bn)
y_eval_custom = batch_norm_eval(x_bn, gamma_bn, beta_bn,
bn_layer.running_mean, bn_layer.running_var)
assert torch.allclose(y_eval_custom, y_eval_official, atol=1e-8), "BN eval mismatch!"
# === LayerNorm 测试 ===
D = 768
x_ln = torch.randn(4, 10, D, dtype=torch.float64)
gamma_ln = torch.ones(D, dtype=torch.float64)
beta_ln = torch.zeros(D, dtype=torch.float64)
y_ln_custom = layer_norm_last_dim(x_ln, gamma_ln, beta_ln)
ln_layer = torch.nn.LayerNorm(D).double()
y_ln_official = ln_layer(x_ln)
assert torch.allclose(y_ln_custom, y_ln_official, atol=1e-8), "LN mismatch!"
print("All normalization tests passed!")
test_normalizations()
归一化技术选型速查表
| 技术 | 统计维度 | Running Stats | LLM 适用性 | 代表模型 |
|---|---|---|---|---|
| BatchNorm | Batch (dim=0) | 需要 | 不适用 | ResNet, CNN |
| LayerNorm | Feature (last dim) | 不需要 | 标准 | GPT, BERT |
| RMSNorm | Feature (last dim) | 不需要 | 主流 | LLaMA, Qwen, Gemma |
| GroupNorm | Group of channels | 不需要 | 特殊场景 | Diffusion, ConvNeXt |
LLM 进阶:从 LayerNorm 到 RMSNorm 。掌握本题后,下一步应实现 RMSNorm (Root Mean Square Normalization):RMSNorm(x)=xmean(x2)+ϵ⋅γRMSNorm(x)=\frac{x}{mean(x2)+ϵ}⋅γRMSNorm(x)=mean(x2)+ϵx⋅γ。相比 LayerNorm,RMSNorm 去掉了均值中心化(subtract mean),仅做缩放。实证表明这对 LLM 性能无损,但减少了约 50% 的计算量和一次全局 reduction 操作,在大模型训练中节省可观算力。LLaMA-2/3、Qwen-2、Gemma 等主流开源模型均采用 RMSNorm + Pre-Norm 组合。
Attention 机制是 Transformer 和所有现代 LLM 的基石。手写 Scaled Dot-Product Attention 不仅是理解论文公式的最佳途径,更是后续实现 Multi-Head Attention、Flash Attention 以及 KV Cache 等高级优化的前置条件。
build_causal_mask: 构建因果掩码
python
import torch
import torch.nn.functional as F
import math
def build_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
"""
构建下三角因果掩码(Causal Mask)
Args:
seq_len: 序列长度 T
device: 目标设备
Returns:
mask: shape (T, T), dtype=torch.bool
True 表示该位置【允许】被 attend,False 表示被屏蔽
"""
# 使用 tril 生成下三角矩阵
# diagonal=0 包含对角线(当前 token 可以 attend 自己)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
return mask
- 为什么需要 Causal Mask? 自回归语言模型在训练时使用 Teacher Forcing(并行处理整个序列),但推理时是逐 token 生成的。Mask 确保第 tt 个 token 只能看到 0,t 的信息,防止"偷看未来",保证训练与推理的行为一致性。
- Bool vs Float Mask :推荐使用
torch.bool类型。PyTorch 的masked_fill和 Flash Attention 原生支持 bool mask,比 float mask(0/-inf)更省显存且语义清晰。仅在需要与旧版代码兼容时才用 float mask。- Shape 扩展 :实际 MHA 中 mask 需广播到
(B, num_heads, T, T)。返回(T, T)作为基础模板,由上层函数按需 unsqueeze,保持接口纯净。
masked_softmax: 数值稳定的带掩码 Softmax
python
def masked_softmax(
scores: torch.Tensor,
mask: torch.Tensor = None,
dim: int = -1
) -> torch.Tensor:
"""
带掩码的数值稳定 softmax
Args:
scores: attention scores, shape (*, T)
mask: bool mask, True=保留, False=屏蔽; 或 None
dim: softmax 维度
"""
if mask is not None:
# ✅ 将被屏蔽位置设为 -inf,softmax 后精确为 0
# 使用 float('-inf') 而非 -1e9,避免大值 logits 下泄漏
scores = scores.masked_fill(~mask, float('-inf'))
# ✅ PyTorch softmax 内部已做 max 减法,数值稳定
# ❌ 不要手写 exp(x)/sum(exp(x)),极易溢出
return F.softmax(scores, dim=dim)
做法 正确性 原因 masked_fill(~mask, float('-inf'))✅ softmax(-inf) = 0,精确屏蔽 masked_fill(~mask, -1e9)⚠️ FP16/BF16 下 -1e9 可能被截断;大 logits 时泄漏非零概率 scores * mask.float()❌ 乘法后未屏蔽位置仍参与 softmax 分母计算,梯度错误 手动 exp(x-max)/sum(...)❌ 冗余实现,无性能收益且易出错
- NaN 陷阱 :当某行全部被 mask (如 padding 行全为 False)时,
softmax([-inf, -inf, ...])会产生 NaN。生产代码需在 softmax 后将全 mask 行的输出置零:attn_weights = attn_weights.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)。- FP16/BF16 安全 :
float('-inf')在 FP16/BF16 下均可正确表示。但-1e9在 FP16 中会被 clamp 到 -65504,导致屏蔽失效。LLM 训练中永远使用-inf。
attention_weights: 计算注意力权重矩阵
python
def attention_weights(
q: torch.Tensor,
k: torch.Tensor,
scale: float = None,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
计算归一化注意力权重
Args:
q: Query, shape (B, T_q, D)
k: Key, shape (B, T_k, D)
scale: 缩放因子,默认 1/sqrt(D)
mask: causal mask or padding mask, shape broadcastable to (B, T_q, T_k)
Returns:
weights: shape (B, T_q, T_k),每行和为 1
"""
D = q.size(-1)
if scale is None:
scale = 1.0 / math.sqrt(D)
# QK^T 点积 + 缩放
# matmul 自动处理 batch 维广播
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# 带掩码的数值稳定 softmax
weights = masked_softmax(scores, mask=mask, dim=-1)
return weights
为什么要 Scale(除以 D)?
- 假设 Q,K的元素独立同分布、均值 0、方差 1,则 QKT 每个元素的方差为 D。
- 当 D 很大(LLM 通常 4096+)时,点积值的绝对值极大,softmax 进入梯度饱和区(输出接近 one-hot),梯度几乎为零,训练停滞。
- 除以 D 将方差重新归一化为 1,使 softmax 工作在梯度敏感区间。
- 注意 :scale 应在 softmax 之前应用。先 softmax 再 scale 等价于改变 temperature,语义完全不同。
scaled_dot_product_attention: 完整 SDPA
python
def scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
dropout_p: float = 0.0,
training: bool = True
) -> torch.Tensor:
"""
完整的 Scaled Dot-Product Attention
Args:
q: (B, T_q, D)
k: (B, T_k, D)
v: (B, T_k, D_v)
mask: broadcastable to (B, T_q, T_k)
dropout_p: attention dropout 概率
training: 是否处于训练模式
Returns:
output: (B, T_q, D_v)
"""
# Step 1: 计算注意力权重
attn_weights = attention_weights(q, k, mask=mask)
# Step 2: Attention Dropout(仅在训练时生效)
if dropout_p > 0.0 and training:
attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)
# Step 3: 加权求和 V
output = torch.matmul(attn_weights, v)
return output
Attention Dropout ≠ Regular Dropout:Attention dropout 作用于权重矩阵(哪些连接被断开),而非特征值。它正则化的是"注意力模式"本身,防止模型过度依赖特定 token pair。LLM 中常用 0.0~0.1。
生产环境请用
F.scaled_dot_product_attention:PyTorch 2.0+ 的原生 SDPA 会自动选择最优后端(Flash Attention / Memory-Efficient / Math),比手写快 2-5x 且省显存。手写版本仅用于学习和调试。
python# 生产代码一行搞定 out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=True)is_causal=True 优化 :当传入
is_causal=True时,Flash Attention 后端会使用专门的 causal kernel,跳过上三角区域的计算,进一步提速 ~30%。永远优先使用此参数而非手动构建 mask tensor。
python
def test_sdpa():
torch.manual_seed(42)
B, T, D = 2, 8, 64
q = torch.randn(B, T, D, dtype=torch.float64)
k = torch.randn(B, T, D, dtype=torch.float64)
v = torch.randn(B, T, D, dtype=torch.float64)
# Causal mask
mask = build_causal_mask(T)
# 自定义实现
out_custom = scaled_dot_product_attention(q, k, v, mask=mask, dropout_p=0.0)
# PyTorch 官方 SDPA(math backend 确保精确对比)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
out_official = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
assert torch.allclose(out_custom, out_official, atol=1e-8), \
f"SDPA mismatch! Max diff: {(out_custom - out_official).abs().max()}"
# 验证因果性:上三角应为 0(通过梯度间接验证)
# 直接验证:检查 weights 的上三角
weights = attention_weights(q, k, mask=mask)
upper_tri = torch.triu(weights, diagonal=1)
assert torch.allclose(upper_tri, torch.zeros_like(upper_tri), atol=1e-10), \
"Causal mask leaked future tokens!"
print("All SDPA tests passed!")
test_sdpa()
Attention 变体速查表
| 变体 | 公式差异 | LLM 代表 | 备注 |
|---|---|---|---|
| Standard SDPA | softmax(QKTD)Vsoftmax(\frac{QK^T}{\sqrt D})Vsoftmax(D QKT)V | GPT-2, BERT | 本题实现 |
| Multi-Head | 拆分 D→H×d,并行 SDPA,concat | 所有 Transformer | 下一章内容 |
| GQA/MQA | K/V head 数 < Q head 数 | LLaMA-3, Qwen-2 | 减少 KV Cache 显存 |
| RoPE | 旋转位置编码注入 Q/K | LLaMA, Qwen, Mistral | 替代绝对位置编码 |
| ALiBi | 线性偏置加到 scores 上 | BLOOM, MPT | 无需位置嵌入 |
从本题到 LLM 的进阶路径,掌握 SDPA 后,按以下顺序进阶:
- Multi-Head Attention :理解 head 拆分/合并、
viewvsreshape的内存连续性陷阱 - KV Cache :推理时缓存历史 K/V,将 O(T2)O (T 2) 降为 O(T)O (T) ,是自回归解码的核心
- Flash Attention :IO-aware 的分块算法,理解为什么它不存储完整 T×TT ×T 矩阵
- RMSNorm + RoPE + GQA:组装现代 LLM 的标准 Attention 模块
这四个练习构成了上述所有高级技术的原子操作。确保每一步的 shape、dtype、数值稳定性都经过严格验证,后续叠加复杂度时才能快速定位问题。
PyTorch Profiling 是从"代码能跑"到"代码跑得快"的关键跨越。在 LLM 训练中,性能优化直接等同于节省数十万美元的算力成本。Profiler 不仅是测量工具,更是理解 CPU-GPU 异步执行模型 、内存分配模式 和算子融合机会的诊断器。
分析模型性能瓶颈(完整解答)
python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 8 * 8, 512)
self.fc2 = nn.Linear(512, 10)
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
def analyze_model():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)
# 使用真实 batch size,小 batch 无法触发 GPU 并行优势
inputs = torch.randn(32, 3, 32, 32, device=device)
# Warmup:前几步 GPU kernel 编译/缓存未命中,数据不具代表性
for _ in range(3):
_ = model(inputs)
if device.type == 'cuda':
torch.cuda.synchronize()
activities = [ProfilerActivity.CPU]
if device.type == 'cuda':
activities.append(ProfilerActivity.CUDA)
with profile(
activities=activities,
record_shapes=True, # 按 shape 分组,区分不同尺寸的同类算子
profile_memory=True, # 追踪显存峰值与分配
with_stack=True, # 定位源码行号
on_trace_ready=tensorboard_trace_handler('./log/model_profile')
) as prof:
for _ in range(5): # 多步采样取平均
output = model(inputs)
prof.step()
# 最耗时操作
sort_key = "cuda_time_total" if device.type == 'cuda' else "cpu_time_total"
print("=" * 80)
print("TOP 10 HOTTEST OPERATORS")
print(prof.key_averages(group_by_stack_n=5).table(
sort_by=sort_key, row_limit=10
))
# CPU vs GPU 时间比
cpu_total = sum(e.cpu_time_total for e in prof.key_averages())
gpu_total = sum(e.cuda_time_total for e in prof.key_averages()) if device.type == 'cuda' else 0
ratio = gpu_total / max(cpu_total, 1)
print(f"\n⏱️ CPU total: {cpu_total/1e3:.2f}ms | GPU total: {gpu_total/1e3:.2f}ms | Ratio: {ratio:.2f}")
if ratio < 0.5:
print(" GPU underutilized! Possible CPU bottleneck or excessive sync.")
# 内存分析
print("\n MEMORY SUMMARY")
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=5))
print(f"\n TensorBoard: tensorboard --logdir=./log/model_profile")
analyze_model()
Self CPU %vsCPU total %:Self是算子自身开销(如 kernel launch),total包含子调用。若aten::linear的 Self 很高但aten::addmm的 CUDA 时间很低,说明 kernel launch overhead 主导,应考虑算子融合或增大 batch。group_by_stack_n=5:将同名算子按调用栈分组。例如两个aten::relu分别来自 conv1 和 fc1,不加此参数会合并显示,无法定位具体层。- GPU 利用率判断 :Ratio > 2 通常健康;< 0.5 意味着 GPU 大量空闲等待 CPU。常见原因:DataLoader 慢、频繁
.item()同步、Python 循环过重。
Batch Size 性能对比(含吞吐量计算)
python
import time
def benchmark_batch_sizes(model_class, input_shape_no_batch, batch_sizes, device='cuda'):
"""系统化 batch size 扫描"""
results = []
for bs in batch_sizes:
model = model_class().to(device)
x = torch.randn(bs, *input_shape_no_batch, device=device)
# Warmup
for _ in range(5):
_ = model(x)
if device == 'cuda':
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True) if device == 'cuda' else None
end_event = torch.cuda.Event(enable_timing=True) if device == 'cuda' else None
if device == 'cuda':
start_event.record()
else:
t0 = time.perf_counter()
num_runs = 50
for _ in range(num_runs):
_ = model(x)
if device == 'cuda':
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event) / num_runs
else:
elapsed_ms = (time.perf_counter() - t0) / num_runs * 1000
throughput = bs / (elapsed_ms / 1000) # samples/sec
# 显存占用
mem_mb = torch.cuda.max_memory_allocated(device) / 1024**2 if device == 'cuda' else 0
results.append({
'batch_size': bs,
'latency_ms': round(elapsed_ms, 3),
'throughput_sps': round(throughput, 1),
'gpu_mem_mb': round(mem_mb, 1)
})
print(f"BS={bs:>4d} | Latency={elapsed_ms:>7.3f}ms | "
f"Throughput={throughput:>8.1f} s/s | Mem={mem_mb:>7.1f}MB")
torch.cuda.empty_cache()
return results
# 使用
results = benchmark_batch_sizes(
MyModel,
input_shape_no_batch=(3, 32, 32),
batch_sizes=[1, 8, 32, 64, 128, 256]
)
错误做法 正确做法 原因 time.time()测 GPUcuda.Event或synchronizeGPU 异步执行, time.time()只测 launch 时间跳过 warmup 至少 3-5 步预热 CUDA kernel JIT、cuDNN autotune、cache 填充 单次测量 ≥50 次取均值 消除 OS 调度、热节流噪声 忽略 OOM 恢复 try/except+empty_cache一个 BS 失败不应中断整个扫描
- 吞吐量 vs 延迟的权衡 :BS=1 延迟最低但吞吐最差;BS 增大吞吐提升但边际递减。最优 BS 通常在 GPU 显存占满 80-90% 时达到。超过后可能因 swap 或碎片化反而下降。
- LLM 特殊性 :LLM 的 batch size 受序列长度影响极大。应固定
tokens_per_batch = bs × seq_len做等 token 量对比,而非固定 bs。
DataLoader num_workers 优化
python
from torch.utils.data import DataLoader, TensorDataset
import time
def benchmark_dataloader(dataset, batch_size=64, worker_counts=[0, 2, 4, 8, 16]):
"""找到数据加载的最优 worker 数"""
for nw in worker_counts:
loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=nw,
pin_memory=True, # 加速 CPU→GPU 传输
persistent_workers=True if nw > 0 else False # 避免每 epoch 重建进程
)
# Warmup:首个 epoch worker 初始化开销大
for i, (x, y) in enumerate(loader):
if i >= 5: break
_ = x.to('cuda', non_blocking=True)
torch.cuda.synchronize()
# Measure full epoch
start = time.perf_counter()
count = 0
for x, y in loader:
_ = x.to('cuda', non_blocking=True)
count += x.size(0)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
throughput = count / elapsed
print(f"num_workers={nw:>2d} | Epoch={elapsed:.2f}s | "
f"Throughput={throughput:.0f} samples/s")
# 模拟大数据集(实际替换为真实数据集)
dataset = TensorDataset(torch.randn(10000, 3, 224, 224), torch.randint(0, 10, (10000,)))
benchmark_dataloader(dataset)
num_workers 调优经验公式:
- 起点 :
num_workers = min(os.cpu_count(), 4 * num_gpus)- 上限信号:当增加 worker 后 epoch 时间不再缩短甚至变长,说明 CPU 争抢或 IPC 序列化成为新瓶颈。
persistent_workers=True:必开选项。默认情况下每个 epoch 结束都会销毁并重建所有 worker 进程,对于大数据集这会产生数秒的固定开销。开启后 worker 跨 epoch 复用。pin_memory=True:GPU 训练必开 。锁页内存允许 DMA 异步传输,配合non_blocking=True实现数据传输与计算重叠。不开此项,.to('cuda')是同步阻塞的。- LLM 特殊注意 :Tokenization 通常是 DataLoader 瓶颈。若 profiler 显示
collate_fn或 tokenizer 占用大量 CPU 时间,应将 tokenization 预处理到磁盘(memory-mapped dataset),而非运行时动态 tokenize。
掌握基础 Profiler 后,以下是 LLM 工程中真正决定效率的认知升级:
理解异步执行的本质
CPU Timeline: [launch A][launch B][launch C][...idle...][sync]
GPU Timeline: [........][====A====][====B====][====C====]
- Profiler 中看到的 CPU 时间大多是 kernel launch overhead(~5-20μs/op),不是真正的计算。
- 当 GPU 有连续的大 kernel 时,CPU 的空闲是正常的。只有当 GPU 出现空隙(gap)而 CPU 忙碌时,才是真正的 CPU bound。
- 使用
torch.compile可将多个小 op 融合为单个 kernel,大幅减少 launch 次数。
显存分析比时间分析更重要,LLM 训练中,OOM 是第一优先级问题。Profiler 的内存视图关注:
- Activation Checkpointing 效果:对比开启前后的 activation memory peak
- KV Cache 增长:推理时 KV cache 应线性增长,若超线性说明有泄漏
- 碎片化 :
reserved - allocated差值过大说明碎片严重,需调整分配策略或使用PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
从 Profile 到优化的决策树
GPU 利用率低?
├── GPU timeline 有间隙 → CPU bound
│ ├── DataLoader 慢 → 增加 workers / 预处理 / FFCV/WebDataset
│ ├── Python 循环重 → torch.compile / Triton / C++ extension
│ └── 频繁同步 → 移除 .item() / 异步指标聚合
├── GPU timeline 连续但 kernel 小 → Launch overhead
│ └── torch.compile / 手动 fuse / Flash Attention
└── GPU kernel 本身慢 → Compute bound
├── 精度过高 → FP16/BF16 / INT8 quantization
├── 算法低效 → Flash Attention / PagedAttention
└── 硬件未打满 → 增大 batch / tensor parallelism
生产级 Profiling Checklist
| 检查项 | 标准 | 工具 |
|---|---|---|
| Warmup | ≥3 steps before profiling | 手动 / schedule(wait=N) |
| 多步采样 | active≥3, repeat≥2 | profiler.schedule |
| Shape 分组 | record_shapes=True | key_averages(group_by_input_shape=True) |
| 调用栈 | with_stack=True | group_by_stack_n=5 |
| 内存追踪 | profile_memory=True | Memory View in TensorBoard |
| 非侵入式 | 不改业务代码 | context manager / decorator |
| 可复现 | 固定 seed + deterministic | torch.use_deterministic_algorithms(True) |
掌握这些内容后,已具备独立诊断和优化 LLM 训练/推理性能的能力。下一步建议结合 torch.compile 和 Triton 自定义算子,将 Profiler 发现的瓶颈转化为实际的性能收益。
显存是 LLM 训练和推理中最稀缺的资源。理解显存的构成、掌握分析工具、熟练运用优化技术,是从"跑通代码"到"规模化训练"的分水岭。以下是对三个实战练习的工程级解答,以及超越基础文档的 LLM 显存管理心法。
分析模型显存占用(完整解答)
python
import torch
import torch.nn as nn
class LargeModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(2048, 2048) for _ in range(20)
])
def forward(self, x):
for layer in self.layers:
x = torch.relu(layer(x))
return x
def analyze_memory():
device = 'cuda'
model = LargeModel().to(device)
batch_size, seq_len = 32, 512
x = torch.randn(batch_size, seq_len, 2048, device=device)
# 模型参数显存
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
grad_bytes = sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad)
print(f" Parameters: {param_bytes / 1024**2:.1f} MB")
print(f" Gradients: {grad_bytes / 1024**2:.1f} MB")
# 前向传播峰值显存
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
with torch.no_grad(): # 隔离前向,不含梯度
y = model(x)
fwd_peak = torch.cuda.max_memory_allocated()
print(f"\n Forward Peak: {fwd_peak / 1024**2:.1f} MB")
# 反向传播峰值显存(含激活值 + 梯度)
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
y = model(x)
y.sum().backward()
bwd_peak = torch.cuda.max_memory_allocated()
print(f" Backward Peak: {bwd_peak / 1024**2:.1f} MB")
# 激活值显存估算
activation_bytes = bwd_peak - param_bytes - grad_bytes
print(f"\n Activations (est): {activation_bytes / 1024**2:.1f} MB")
print(f" 占总峰值比例: {activation_bytes / bwd_peak * 100:.1f}%")
# 理论验证:每层 Linear(2048,2048) 保存输入用于反向
# 激活 ≈ 20 layers × (32 × 512 × 2048 × 4 bytes) ≈ 2621 MB
theoretical_act = 20 * batch_size * seq_len * 2048 * 4
print(f" 理论激活值: {theoretical_act / 1024**2:.1f} MB")
analyze_memory()
显存构成公式:Total VRAM=P⏟params+G⏟gradients+O⏟optimizer states+A⏟activations+T⏟temp buffers
组件 FP32 FP16/BF16 Adam Optimizer 备注 Parameters 4B/param 2B/param --- 固定开销 Gradients 4B/param 2B/param --- 与 params 同精度 Optimizer States --- --- 8B/param (FP32) Adam: m + v 各 4B Activations 4B/elem 2B/elem --- 随 batch×seq 线性增长
- 关键洞察 :在 LLM 训练中,激活值通常占显存的 60-80%,远超参数本身。这就是为什么 Gradient Checkpointing 和序列并行比量化参数更重要。
- Adam 的隐藏成本:7B 模型的 Adam 优化器状态需要 ~56GB 显存(7B × 8B),比模型参数本身还大。这是 ZeRO / FSDP 分片的核心动机。
显存优化对比(完整实现 + 基准测试)
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
import time
# === 四种训练策略实现 ===
def train_baseline(model, x, target, optimizer, criterion):
"""基线:FP32 全量训练"""
optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()
return loss.item()
def train_grad_accum(model, x, target, optimizer, criterion, accum_steps=4):
"""梯度累积:模拟 4x batch,显存不变"""
optimizer.zero_grad()
# 将 input 切分为 accum_steps 份
chunks_x = x.chunk(accum_steps)
chunks_t = target.chunk(accum_steps)
total_loss = 0
for cx, ct in zip(chunks_x, chunks_t):
output = model(cx)
loss = criterion(output, ct) / accum_steps
loss.backward()
total_loss += loss.item()
optimizer.step()
return total_loss
def train_amp(model, x, target, optimizer, criterion, scaler):
"""混合精度:FP16 前向+反向,FP32 权重更新"""
optimizer.zero_grad()
with autocast(dtype=torch.float16):
output = model(x)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
return loss.item()
def train_checkpoint(model, x, target, optimizer, criterion):
"""梯度检查点:用 ~30% 额外计算换 50-80% 激活显存"""
optimizer.zero_grad()
output = model(x, use_checkpoint=True)
loss = criterion(output, target)
loss.backward()
optimizer.step()
return loss.item()
# === 修改模型支持 checkpoint ===
class CheckpointedModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(2048, 2048) for _ in range(20)])
def forward(self, x, use_checkpoint=False):
for layer in self.layers:
if use_checkpoint:
x = checkpoint(lambda inp, l=layer: torch.relu(l(inp)),
x, use_reentrant=False)
else:
x = torch.relu(layer(x))
return x
# === 统一基准测试 ===
def benchmark_strategies():
device = 'cuda'
bs, seq, dim = 32, 512, 2048
strategies = {}
# Baseline
model_base = LargeModel().to(device)
opt_base = torch.optim.Adam(model_base.parameters(), lr=1e-3)
crit = nn.CrossEntropyLoss()
strategies["Baseline"] = lambda: train_baseline(
model_base, torch.randn(bs, seq, dim, device=device),
torch.randint(0, 10, (bs,), device=device), opt_base, crit)
# Gradient Accumulation
model_ga = LargeModel().to(device)
opt_ga = torch.optim.Adam(model_ga.parameters(), lr=1e-3)
strategies["Grad Accum (4x)"] = lambda: train_grad_accum(
model_ga, torch.randn(bs, seq, dim, device=device),
torch.randint(0, 10, (bs,), device=device), opt_ga, crit)
# AMP
model_amp = LargeModel().to(device)
opt_amp = torch.optim.Adam(model_amp.parameters(), lr=1e-3)
scaler = GradScaler()
strategies["AMP (FP16)"] = lambda: train_amp(
model_amp, torch.randn(bs, seq, dim, device=device),
torch.randint(0, 10, (bs,), device=device), opt_amp, crit, scaler)
# Gradient Checkpointing
model_cp = CheckpointedModel().to(device)
opt_cp = torch.optim.Adam(model_cp.parameters(), lr=1e-3)
strategies["Grad Checkpoint"] = lambda: train_checkpoint(
model_cp, torch.randn(bs, seq, dim, device=device),
torch.randint(0, 10, (bs,), device=device), opt_cp, crit)
# 运行基准
print(f"{'Strategy':<22} {'Peak Mem (MB)':<16} {'Time (ms)':<12} {'Mem Saving':<12}")
print("-" * 65)
baseline_mem = None
for name, func in strategies.items():
# Warmup
for _ in range(3):
func()
torch.cuda.synchronize()
# Measure
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
start = time.perf_counter()
for _ in range(10):
func()
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) / 10 * 1000
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
if baseline_mem is None:
baseline_mem = peak_mb
saving = "---"
else:
saving = f"{(1 - peak_mb/baseline_mem)*100:.1f}%"
print(f"{name:<22} {peak_mb:<16.1f} {elapsed_ms:<12.1f} {saving:<12}")
benchmark_strategies()
优化策略选择决策树:
OOM? ├─ 仅差一点 → 减小 batch size / gradient accumulation ├─ 差很多但可接受减速 → Gradient Checkpointing ├─ 需要速度+显存双收益 → AMP (BF16 > FP16) └─ 单卡无论如何放不下 → ZeRO/FSDP 多卡分片
- AMP 优先用 BF16 :BF16 动态范围与 FP32 相同,无需 GradScaler,训练更稳定。FP16 在 LLM 尺度下容易溢出。
autocast(dtype=torch.bfloat16)- Checkpoint 粒度 :不要对每一层都做 checkpoint。只对激活值最大的模块(如 Attention、FFN)做,Embedding/LayerNorm 等轻量层保留缓存。过度 checkpoint 会导致重计算开销超过显存收益。
- 梯度累积的正确性 :loss 必须除以
accum_steps,否则等效学习率被放大 N 倍。BatchNorm 在梯度累积下统计量会偏移,LLM 使用 LayerNorm/RMSNorm 无此问题。
排查显存泄漏(修复 + 诊断工具)
原始代码的问题分析
python
def train_with_leak(model, train_loader, optimizer, criterion):
model.train()
losses = []
for epoch in range(10):
for inputs, targets in train_loader:
inputs, targets = inputs.cuda(), targets.cuda()
outputs = model(inputs)
loss = criterion(outputs, targets)
losses.append(loss) # BUG: 保留了完整计算图
loss.backward()
optimizer.step()
# BUG: 缺少 optimizer.zero_grad(),梯度无限累积
return losses
两个致命问题:
losses.append(loss)保留了整个计算图的引用。每个 loss tensor 持有从输入到输出的所有中间激活值的引用,导致激活值永远无法释放。10 个 epoch 后显存必然 OOM。- 缺少
optimizer.zero_grad(),梯度在参数上无限累加,不仅浪费显存(梯度 tensor 不会释放),还会导致训练完全错误。
修复后的代码
python
def train_fixed(model, train_loader, optimizer, criterion, device='cuda'):
model.train()
losses = []
for epoch in range(10):
epoch_losses = []
for inputs, targets in train_loader:
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
outputs = model(inputs)
loss = criterion(outputs, targets)
# 只保留标量值,断开计算图引用
epoch_losses.append(loss.item())
optimizer.zero_grad(set_to_none=True) # set_to_none 比 zero_ 更省显存
loss.backward()
optimizer.step()
# 每个 epoch 结束后记录平均 loss,而非保留所有 step 的 tensor
losses.append(sum(epoch_losses) / len(epoch_losses))
return losses
显存泄漏诊断工具箱
python
import gc
import torch
def diagnose_memory_leak():
"""系统化排查显存泄漏"""
# Step 1: 强制 GC + 清空缓存
gc.collect()
torch.cuda.empty_cache()
# Step 2: 快照当前所有 GPU tensor
before_tensors = set()
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) and obj.is_cuda:
before_tensors.add(id(obj))
except Exception:
pass
# === 执行可疑代码 ===
# ... your training loop ...
# Step 3: 再次 GC
gc.collect()
torch.cuda.empty_cache()
# Step 4: 找出新增的 tensor
leaked = []
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) and obj.is_cuda and id(obj) not in before_tensors:
leaked.append(obj)
except Exception:
pass
if leaked:
print(f" Found {len(leaked)} leaked tensors:")
for t in leaked[:20]: # 只显示前 20 个
print(f" shape={tuple(t.shape)}, dtype={t.dtype}, "
f"size={t.numel()*t.element_size()/1024**2:.2f}MB")
else:
print(" No memory leak detected.")
# Step 5: 监控 reserved vs allocated 差距(碎片化指标)
alloc = torch.cuda.memory_allocated() / 1024**2
resv = torch.cuda.memory_reserved() / 1024**2
frag = (resv - alloc) / resv * 100 if resv > 0 else 0
print(f"\n Allocated: {alloc:.1f}MB | Reserved: {resv:.1f}MB | Fragmentation: {frag:.1f}%")
if frag > 30:
print("High fragmentation! Consider PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")
高频泄漏模式速查:
模式 症状 修复 list.append(tensor)显存随 epoch 线性增长 .append(tensor.item())或.detach()缺少 zero_grad显存缓慢增长 + loss 异常 optimizer.zero_grad(set_to_none=True)闭包捕获 tensor 显存不释放但找不到引用 检查 lambda/functools.partial 全局变量缓存 显存只增不减 用 weakref或显式清理Dataloader worker 泄漏 CPU 内存增长 persistent_workers=True+ 检查 collate_fn
set_to_none=True:zero_grad(set_to_none=True)将梯度设为None而非零张量,节省一份梯度显存。PyTorch 2.x 默认行为,1.x 需显式指定。detach()vsitem():detach()返回新 tensor 但仍占显存;item()返回 Python float,完全释放。记录指标永远用item()。- 碎片化是隐形杀手 :即使 allocated 未达上限,reserved 中的碎片也可能导致 OOM。PyTorch 2.1+ 的
expandable_segments大幅缓解此问题,生产环境务必开启。
显存优化的优先级排序
1. 算法级:Flash Attention(消除 T² 激活)、RoPE(无需位置嵌入缓存)
2. 系统级:Gradient Checkpointing、Activation Recomputation
3. 精度级:BF16/FP16 混合精度、INT8/INT4 量化
4. 工程级:梯度累积、batch size 调优、inplace 操作
5. 分布式:ZeRO Stage 1/2/3、FSDP、Tensor Parallelism
永远从算法级开始优化。Flash Attention 一项就能节省 80%+ 的 Attention 激活显存,效果远超其他所有工程技巧的叠加。
LLM 特有的显存陷阱
- KV Cache:推理时 KV cache 随序列长度线性增长。7B 模型在 4K context 下 KV cache ≈ 2GB,128K context 下 ≈ 64GB。必须使用 PagedAttention (vLLM) 或 MQA/GQA 架构。
- Tokenizer Padding:变长序列 padding 到 max_length 会浪费大量激活显存。使用 packing(多条短序列拼接)或 varlen attention。
- Logits Tensor :vocab_size=128K 时,logits tensor
(B, T, 128K)在 FP32 下占用巨大。使用 fused cross-entropy 避免物化完整 logits。
生产级显存监控 Checklist
| 检查项 | 标准 | 工具 |
|---|---|---|
| 峰值显存 | < GPU 总显存 90% | max_memory_allocated() |
| 碎片率 | < 20% | reserved - allocated |
| 激活占比 | 已知且合理 | Profiler Memory View |
| 泄漏检测 | 多 epoch 显存稳定 | gc + tensor snapshot |
| OOM 恢复 | graceful fallback | try/except + reduce batch |
| 分布式一致性 | 各卡显存均衡 | nvidia-smi / DDP profiler |
掌握这些内容后,已具备独立诊断和解决 LLM 训练中各类显存问题的能力。下一步建议深入研究 Flash Attention 的实现原理和 ZeRO/FSDP 的分片策略,将显存优化从单卡扩展到千卡规模。
在深度学习(尤其是 LLM 训练)中,NaN/Inf 是最令人头疼的问题。它们具有传染性 (一个 NaN 会通过反向传播污染所有梯度)和隐蔽性(可能在 loss 爆炸前数十步就已产生)。掌握系统化的数值健康检查方法,是从"盲目试错"到"精准定位"的关键跨越。
调试工具函数实现
python
import torch
import math
from typing import Dict, List, Tuple, Optional
def tensor_health(x: torch.Tensor, name: str = "tensor") -> Dict:
"""
全面诊断张量的数值健康状态
Returns:
包含 min/max/mean/std/nan_count/inf_count 的字典
"""
with torch.no_grad():
# 转为 float32 统计,避免 FP16/BF16 下 inf/nan 判断失真
xf = x.detach().float()
is_nan = torch.isnan(xf)
is_inf = torch.isinf(xf)
finite_mask = torch.isfinite(xf)
stats = {
"name": name,
"shape": tuple(x.shape),
"dtype": str(x.dtype),
"nan_count": is_nan.sum().item(),
"inf_count": is_inf.sum().item(),
"finite_ratio": finite_mask.float().mean().item(),
}
# 仅在存在有限值时计算统计量,否则返回 NaN 会误导
if stats["finite_ratio"] > 0:
xf_finite = xf[finite_mask]
stats.update({
"min": xf_finite.min().item(),
"max": xf_finite.max().item(),
"mean": xf_finite.mean().item(),
"std": xf_finite.std().item(),
"abs_max": xf_finite.abs().max().item(),
})
else:
stats.update({"min": None, "max": None, "mean": None, "std": None, "abs_max": None})
return stats
def has_nonfinite(x: torch.Tensor) -> bool:
"""快速检查张量是否包含 NaN 或 Inf"""
# isfinite 比 isnan + isinf 更快(单次 kernel launch)
return not torch.isfinite(x).all().item()
def gradient_l2_norm(model: torch.nn.Module) -> float:
"""
计算模型所有梯度的全局 L2 范数
等价于 torch.nn.utils.clip_grad_norm_ 的内部计算(但不裁剪)
"""
total_norm_sq = 0.0
for p in model.parameters():
if p.grad is not None:
# 用 float32 累加,避免 FP16 梯度平方溢出
g = p.grad.detach().float()
total_norm_sq += g.norm(2).item() ** 2
return math.sqrt(total_norm_sq)
def find_nonfinite_grad_names(model: torch.nn.Module) -> List[str]:
"""
找出所有梯度中包含 NaN/Inf 的参数名
用于精确定位问题源头
"""
bad_names = []
for name, p in model.named_parameters():
if p.grad is not None and has_nonfinite(p.grad):
nan_c = torch.isnan(p.grad).sum().item()
inf_c = torch.isinf(p.grad).sum().item()
bad_names.append(f"{name} (nan={nan_c}, inf={inf_c})")
return bad_names
关键设计决策解析
| 设计点 | 错误做法 | 正确做法 | 原因 |
|---|---|---|---|
| 统计精度 | 直接在 FP16 上算 mean/std | .detach().float() 后统计 |
FP16 的 inf 阈值仅 65504,正常大值会被误判;BF16 虽好但 std 精度仍不足 |
| 非有限值处理 | x.mean() 含 NaN → 结果 NaN |
先 mask 再统计 | 一个 NaN 污染整个统计量,完全失去诊断价值 |
| L2 Norm 累加 | FP16 下直接累加平方和 | float32 累加 | FP16 平方极易溢出为 inf,导致 norm 永远为 inf |
| Non-finite 检测 | `isnan(x) | isinf(x)` | ~isfinite(x) |
| 梯度访问 | p.grad.data |
p.grad.detach() |
.data 绕过 autograd 追踪,可能导致隐性 bug;.detach() 是安全标准 |
LLM 训练中 NaN/Inf 的系统化排查流程
当训练中出现 NaN 时,按以下决策树逐层缩小范围:
Loss = NaN?
├── Step 0 就 NaN → 初始化/数据问题
│ ├── 检查输入数据:has_nonfinite(batch) / tensor_health(batch)
│ ├── 检查权重初始化:tensor_health(param) for param in model.parameters()
│ └── 检查 LR:过大导致第一步就发散
│
├── 训练中途突然 NaN → 数值不稳定
│ ├── 1. 定位首个 NaN step(二分法 / anomaly detection)
│ ├── 2. 该 step 前向检查:每层输出 tensor_health()
│ │ └── 找到首个 NaN 层 → 检查该层的输入/权重/操作
│ ├── 3. 该 step 反向检查:find_nonfinite_grad_names(model)
│ │ └── 梯度 NaN 的层 ≠ 前向 NaN 的层 → 反向传播中的数值问题
│ └── 4. 检查 loss 本身:log(0)、除零、softmax 全 mask
│
└── 逐渐增大后 NaN → 梯度爆炸
├── gradient_l2_norm(model) 监控趋势
├── 加 grad clipping(max_norm=1.0)
├── 降低 LR / 增加 warmup
└── 切换 FP16 → BF16(动态范围大 8 个数量级)
LLM 特有的高频 NaN 陷阱
Softmax 全 Mask 行
python
# 因果掩码 + padding 导致某行全为 -inf
scores = scores.masked_fill(~mask, float('-inf'))
attn = F.softmax(scores, dim=-1) # softmax([-inf,...,-inf]) = [NaN,...,NaN]
# 修复:softmax 后将全 mask 行置零
attn = F.softmax(scores, dim=-1)
attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
这是 LLM 训练中最常见的 NaN 来源,尤其在变长序列 packing 场景下。
Log Probabilities 中的 log(0)
python
# CrossEntropyLoss 内部 log(softmax),若 softmax 输出精确 0 → log(0) = -inf
# PyTorch CE 已做数值稳定,但自定义 loss 需注意
# 使用 log_softmax + nll_loss 组合,或直接 F.cross_entropy(fused 且稳定)
FP16 下的梯度平方溢出
FP16 最大值为 65504。当梯度 > 255 时,grad² > 65504 → overflow to inf → Adam 的 v 变为 inf → 更新量变为 NaN。BF16 最大值约 3.4×10³⁸,彻底消除这一问题。这也是为什么现代 LLM 训练几乎全部使用 BF16。
Gradient Checkpointing 的重计算不一致
如果 checkpoint 包裹的模块在前向和重计算时行为不同(如 dropout 未固定 seed、BN 在 train/eval 间切换),重计算的激活值与原始值不一致,导致梯度错误甚至 NaN。
python
# 确保 checkpoint 内的操作是确定性的
# use_reentrant=False(PyTorch 2.x 推荐)自动处理大部分情况
checkpoint(fn, *args, use_reentrant=False)
集成到训练循环的调试 Hook
python
class NaNDebugger:
"""插入训练循环的自动 NaN 检测器"""
def __init__(self, model: torch.nn.Module, check_every: int = 1):
self.model = model
self.check_every = check_every
self.step = 0
def after_backward(self, loss: torch.Tensor):
self.step += 1
if self.step % self.check_every != 0:
return
# 1. 检查 loss
if has_nonfinite(loss):
print(f" Step {self.step}: Loss is non-finite!")
print(f" Loss value: {loss.item()}")
bad = find_nonfinite_grad_names(self.model)
print(f" Bad grads: {bad[:5]}") # 只显示前 5 个
raise RuntimeError(f"NaN detected at step {self.step}")
# 2. 检查梯度范数趋势
gnorm = gradient_l2_norm(self.model)
if gnorm > 100.0:
print(f" Step {self.step}: Grad norm = {gnorm:.2f} (abnormally high)")
# 3. 可选:抽样检查中间激活(开销较大,仅在排查时开启)
# for name, buf in self.model.named_buffers():
# if has_nonfinite(buf):
# print(f" Buffer '{name}' is non-finite: {tensor_health(buf)}")
调试工具选型速查
| 场景 | 工具 | 开销 | 适用阶段 |
|---|---|---|---|
| 快速检查单个 tensor | has_nonfinite() |
极低 | 实时 / 生产 |
| 详细诊断异常 tensor | tensor_health() |
低 | 排查时 |
| 监控梯度整体健康 | gradient_l2_norm() |
中 | 每步 / 每 N 步 |
| 定位具体坏梯度参数 | find_nonfinite_grad_names() |
中 | NaN 发生时 |
| 自动捕获首次 NaN | torch.autograd.set_detect_anomaly(True) |
极高 (10-100x) | 仅小样本复现 |
| 大规模训练 NaN 定位 | 自定义 Hook + 日志 | 可控 | 生产训练 |
detect_anomaly警告 :此模式会为每个算子插入额外检查,使训练速度下降 10-100 倍。绝不要在全量训练中开启。正确用法:用小数据集 + 小模型复现 NaN → 开 anomaly → 获得精确堆栈 → 关闭 anomaly → 修复代码 → 全量验证。
掌握这四个函数和排查流程后,已具备独立应对 LLM 训练中各类数值异常的能力。下一步建议结合 torch.compile 的数值一致性检查和分布式训练中的跨 rank NaN 同步检测,将调试能力扩展到千卡规模。