llm-algo-1

涵盖了 Attention 机制原理、PyTorch Profiling 性能分析、显存优化 以及 数值调试技巧 。这四个模块并非孤立存在,它们共同构成了 "LLM 底层系统工程" 的基石。要从"学会知识点"进阶到"构建体系化能力",需要将这四块内容重组为一条 "正确性 → 可观测性 → 效率上限" 的工程闭环。以下是循序渐进认知与实践框架:

认知重构:建立三层递进模型

不要平行地看待四个主题,应按以下优先级建立层级认知:

层级 核心主题 认知目标 对应今日内容
L1: 正确性基座 Attention + Debugging 确保算得对、不崩溃。这是所有优化的前提。 SDPA 手写、NaN 排查、梯度健康检查
L2: 可观测性 Profiling 看清黑盒内部。没有 Profiling 的优化都是盲人摸象。 torch.profiler、TensorBoard、CPU/GPU 时间分布
L3: 效率上限 Memory Optimization 在硬件约束下逼近理论极限。 显存分析、梯度累积、AMP、Checkpointing

核心心法先 L1 再 L2,有 L2 才做 L3。

很多初学者跳过 Profiling 直接做显存优化,结果花了大量时间优化了一个根本不是瓶颈的模块。正确的顺序永远是:代码跑通 (L1) → Profile 找瓶颈 (L2) → 针对性优化 (L3) → 再次 Profile 验证 (L2)。

高频实践:必须形成肌肉记忆的内容

以下内容不是"了解即可",而是需要在未来每次写训练代码时条件反射般使用的:

  1. 数值安全三件套
    • masked_fill(~mask, float('-inf')) 而非 -1e9
    • Loss 记录永远用 .item() 而非 append tensor
    • optimizer.zero_grad(set_to_none=True)
  2. 显存监控习惯
    • 训练启动后前 3 步必查 torch.cuda.max_memory_allocated()
    • 遇到 OOM 先看 reserved - allocated 判断是碎片还是真不足
  3. Profiling 标准动作
    • 任何性能测试必须先 Warmup ≥3 步
    • GPU 计时只用 cuda.Event,禁用 time.time()
    • 报告必看 Self CPU % vs CUDA time 判断 CPU/GPU 谁在等谁
  4. Attention 实现规范
    • Scale 必须在 Softmax 之前
    • 生产环境优先 F.scaled_dot_product_attention(is_causal=True)
    • 手写版本仅用于学习和对齐验证

深入研究:从"会用"到"理解本质"

当基础操作熟练后,需要对以下问题进行深度追问,这决定了能走多远:

Attention 的数学直觉

  • 为什么是 D*D* 而不是 D*D* ? 推导点积方差与维度的关系,理解 scale 的本质是方差归一化。
  • Causal Mask 的替代方案有哪些? 研究 ALiBi、RoPE 如何将位置信息注入 Q/K 而非依赖 mask,理解它们对长度外推的影响。
  • Flash Attention 为什么快? 深入理解 Tiling + Online Softmax 算法,明白它如何避免物化 T×TT ×T 矩阵,以及 IO-aware 设计哲学。

显存的物理本质

  • PyTorch 缓存分配器的工作原理 :理解 allocated vs reserved 的区别,研究 expandable_segments 如何解决碎片化。
  • Activation Checkpointing 的最优粒度:不是越细越好。研究如何根据每层的 FLOPs/Activation 比值选择性地 checkpoint,平衡重计算与显存节省。
  • 混合精度的数值边界:FP16 vs BF16 的动态范围差异,GradScaler 的缩放因子自适应逻辑,哪些算子必须保持 FP32(如 LayerNorm、Softmax)。

Profiling 的信号解读

  • Kernel Launch Overhead 的量化:多少 μs/op 算异常?小算子融合的阈值在哪里?
  • 异步执行的时序图阅读:能从 Chrome Trace 中识别出 data loading gap、sync barrier、kernel pipeline bubble。
  • 内存时间线分析:区分临时 buffer、持久 activation、optimizer state 的生命周期。

体系化实践路线图

阶段 周期 实践任务 验收标准
Week 1: 正确性验证 第 1 周 手写 SDPA 并与官方实现对齐;故意注入 NaN/Inf 并用调试工具定位 单元测试 pass;能在 5 分钟内定位人为注入的数值 bug
Week 2: 性能画像 第 2 周 对一个真实模型做完整 Profiling,输出瓶颈分析报告 能明确说出 Top3 瓶颈及原因(CPU bound / GPU bound / Memory bound)
Week 3: 显存压榨 第 3 周 对同一模型依次应用 AMP→Checkpoint→GradAccum,记录显存/速度变化曲线 产出优化对比表;显存降低 ≥40% 且吞吐下降 <20%
Week 4: 综合实战 第 4 周 从零搭建一个 mini-LLM 训练循环,集成所有最佳实践 代码无 NaN、显存利用率 >80%、Profile 无明显 bubble
  1. 建立个人 Debug Checklist:将今天学到的排查流程固化为文档,每次遇到 NaN/OOM 时按清单逐项检查,而非凭感觉试错。
  2. 保留基准测试代码 :将 benchmark_batch_sizesmeasure_peak_memory 等工具函数封装为可复用库,后续每个新项目都直接调用。
  3. 关注 PyTorch 版本演进 :今天的很多"最佳实践"(如手动 AMP、手写 checkpoint)在 PyTorch 2.x+ 中已被 torch.compile 和原生 API 取代。理解原理是为了更好地使用高层抽象,而非永远手写底层。
  4. 从单卡思维走向系统思维:今天的知识主要面向单卡。下一步应自然延伸到多卡场景:ZeRO/FSDP 如何分片显存、DDP 通信如何与计算重叠、跨 rank NaN 如何同步检测。

终极检验标准:当看到一个 LLM 训练报错或性能问题时,能在 30 秒内判断它属于 L1/L2/L3 哪个层级,并准确调用对应的工具和知识储备------这就是体系化能力的体现。

1. generate_position_ids(seq_len)目标:掌握列表推导式(List Comprehension),这是处理 Token 序列最基础的操作。

python 复制代码
def generate_position_ids(seq_len: int) -> list[int]:
    """
    生成位置编码 ID,用于 Transformer 的位置嵌入层。
    Args:
        seq_len: 序列长度
    Returns:
        从 0 开始的连续整数列表 [0, 1, 2, ..., seq_len-1]
    """
    return [i for i in range(seq_len)]

虽然 list(range(seq_len)) 也能实现。在实际 LLM 工程中,当需要对 position_ids 做条件过滤(如跳过 padding token)时,列表推导的灵活性远高于 list(range())


2. merge_model_config(base_config, overrides)目标 :理解字典合并与浅拷贝陷阱。这是模型配置热更新的核心操作。

python 复制代码
def merge_model_config(base_config: dict, overrides: dict) -> dict:
    """
    安全合并模型配置,overrides 中的键值对覆盖 base_config。不修改原始 base_config
    Args:
        base_config: 基础配置字典
        overrides: 需要覆盖的配置项
    Returns:
        合并后的新字典
    """
    # 使用浅拷贝避免污染原对象
    merged = base_config.copy()
    merged.update(overrides)
    return merged
  • 错误写法base_config.update(overrides); return base_config → 这会原地修改传入的 base_config,导致后续调用产生难以排查的 bug。
  • 浅拷贝 vs 深拷贝 :此处用 .copy() 即可,因为 config 通常是扁平的 key-value 结构。如果 config 包含嵌套字典/列表且需要独立修改,则应使用 copy.deepcopy()
  • Python 3.9+ 替代写法return {**base_config, **overrides} 同样安全且更简洁,但 .copy() + .update() 语义更清晰,适合调试打印。

3. count_token_frequency(tokens)目标:掌握频率统计,这是分词器分析、数据质量检查的基础技能。

python 复制代码
from collections import Counter
def count_token_frequency(tokens: list[str]) -> dict[str, int]:
    """
    统计 token 出现频率。
    
    Args:
        tokens: token 字符串列表
    Returns:
        {token: count} 字典,按出现次数降序排列
    """
    freq = Counter(tokens)
    # 返回普通 dict,按频率降序,方便直接打印调试
    return dict(freq.most_common())

Counter 是 Python 标准库中专为频率统计设计的工具,底层 C 实现,比手写 for + dict 快一个数量级。在分析百万级 token 语料时差异显著。.most_common() 返回的结果天然适合日志输出和可视化。


4. ModelSpec.summary()目标 :学会用类封装模型元信息,并让输出可直接用于调试

python 复制代码
class ModelSpec:
    """轻量级模型规格描述类"""
    def __init__(self, name: str, vocab_size: int, hidden_dim: int, num_layers: int):
        self.name = name
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
    
    def summary(self) -> str:
        """
        返回格式化的模型摘要字符串,可直接 print() 用于调试。
        """
        lines = [
            f"=== Model Spec: {self.name} ===",
            f"  Vocab Size : {self.vocab_size:,}",
            f"  Hidden Dim : {self.hidden_dim:,}",
            f"  Num Layers : {self.num_layers}",
            f"  Est. Params: ~{self._estimate_params():,}",
        ]
        return "\n".join(lines)
    
    def _estimate_params(self) -> int:
        """粗略参数量估算(仅用于调试展示)"""
        return self.vocab_size * self.hidden_dim + \
               self.num_layers * (4 * self.hidden_dim ** 2)
    
    def __repr__(self) -> str:
        return f"ModelSpec(name={self.name!r}, layers={self.num_layers})"
  • summary() 返回 str 而非直接 print,这样既能 print(spec.summary()) 也能写入日志文件,符合"输出可直接用于调试和打印"的要求。
  • 数字使用 {:,} 千分位格式化,大幅提升可读性。
  • 额外实现 __repr__,使得在 Jupyter Notebook 或 pdb 调试中直接显示对象时有意义,而不是 <__main__.ModelSpec object at 0x...>
  • _estimate_params() 作为私有方法,体现封装思想:对外只暴露 summary(),内部计算细节对使用者透明。

检查项 通过标准
列表推导 generate_position_ids 不使用 list(range())
字典安全 merge_model_config 不修改 base_config 原对象
频率统计 使用 Counter,结果按频次降序
类封装 summary() 返回 str,含千分位格式化
代码风格 有 type hint + docstring,无魔法数字

5. build_causal_mask(seq_len)目标 :掌握布尔掩码生成与广播,这是 Decoder-only 模型防止"未来信息泄露"的关键。einsum 和广播机制是理解 Attention、LayerNorm 等算子的"通用语言"。

python 复制代码
import numpy as np

def build_causal_mask(seq_len: int) -> np.ndarray:
    """
    生成因果掩码 (Causal Mask)。
    True/1 表示允许 attend,False/0 表示需要被 mask 掉。
    
    Shape: (seq_len, seq_len)
    示例 (seq_len=3):
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    # np.tril 直接生成下三角矩阵,简洁且高效
    mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_))
    return mask
  • 为什么用 bool_ 而非 float 布尔掩码在后续与 attention scores 结合时,可通过 np.where(mask, scores, -np.inf) 安全地将未来位置设为负无穷。若直接用 float 0/1 做乘法,softmax 后 0 位置仍会有非零概率,导致信息泄露。
  • 广播预备 :实际 Attention 中 mask 需要广播到 (batch, num_heads, seq_len, seq_len),此处返回 2D 形状正是为了后续自动广播对齐。

6. matmul_einsum(a, b)目标:用 einsum 表达最基础的矩阵乘法,建立"下标即语义"的思维。

python 复制代码
def matmul_einsum(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """
    用 einsum 实现标准矩阵乘法 C = A @ B
    Args:
        a: shape (M, K)
        b: shape (K, N)
    Returns:
        shape (M, N)
    """
    # i: M维度, k: 收缩维度(求和), j: N维度
    return np.einsum('ik,kj->ij', a, b)

Einsum 阅读口诀

  • 逗号左边:输入张量的轴标签
  • 箭头右边:输出张量保留的轴标签
  • 未出现在右边的标签 (如 k):沿该轴求和(收缩)
  • 'ik,kj->ij' 等价于 Cij=∑kAik⋅BkjC_{ij}=∑kA{ik}⋅B_{kj}Cij=∑kAik⋅Bkj ,这正是矩阵乘法的数学定义。
  • 调试技巧:写 einsum 前先在纸上画出每个张量的 shape 并标注轴名,确认收缩维度大小一致后再写代码。

7. batch_attention_scores(q, k)目标:用 einsum 表达带 batch 维度的多维张量运算,这是 Multi-Head Attention 的核心步骤。

python 复制代码
def batch_attention_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    计算 batch 下的 Attention 分数 (未缩放、未 mask)。
    
    Args:
        q: shape (batch, seq_len_q, d_model)
        k: shape (batch, seq_len_k, d_model)
    Returns:
        scores: shape (batch, seq_len_q, seq_len_k)
    """
    # b: batch, i: seq_q, j: seq_k, d: d_model (收缩)
    scores = np.einsum('bid,bjd->bij', q, k)
    return scores
  • b 同时出现在 q 和 k 中:表示 batch 维度是对齐的,einsum 会逐 batch 独立计算,不会跨 batch 混合。
  • d 只出现在左边:表示沿 feature 维度做点积求和,这正是 attention score 的语义。
  • 为什么不先 reshape 再 matmul? 虽然 q @ k.transpose(0,2,1) 也能实现,但 einsum 的优势在于公式即文档 ------看到 'bid,bjd->bij' 就能立刻理解运算逻辑,无需脑内模拟 transpose 和维度对齐。这在推导更复杂的 GQA、RoPE 时优势巨大。
  • 实际工程中 :还需除以 sqrt(d_model) 并应用 causal mask,本题仅聚焦 einsum 表达。

8. rms_normalize(x, eps)目标 :用 NumPy 实现 RMSNorm,掌握轴向归约 + 广播回扩这一 LLM 中最频繁的张量操作模式。

python 复制代码
def rms_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)

    Args:
        x: shape (..., d_model),支持任意前导维度
        eps: 数值稳定性小量
    Returns:
        归一化后的张量,shape 与 x 相同
    """
    # 1. 沿最后一个轴计算均方根   keepdims=True 是关键!保持维度以便广播
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    # 2. 广播除法:(..., 1) 自动广播到 (..., d_model)
    return x / rms

为什么 keepdims=True 至关重要?

写法 mean 输出 shape 能否广播回 x
axis=-1 (无 keepdims) (...) 维度丢失,需手动 reshape
axis=-1, keepdims=True (..., 1) 自动广播,代码简洁
  • RMSNorm vs LayerNorm:RMSNorm 不减均值、不学 affine 参数(或单独乘 weight),计算更快,是 LLaMA/Qwen 等主流模型的标配。
  • eps 的位置 :放在 sqrt 内部而非外部,避免除零的同时保证梯度稳定。
  • ... 通配符:函数签名支持任意前导维度(batch、seq_len、num_heads 等),这正是 LLM 工程中归一化函数的通用设计模式。

检查项 通过标准
Causal Mask 使用 bool 类型,下三角为 True
Einsum 基础 'ik,kj->ij' 结果与 a @ b 完全一致
Batch Attention 'bid,bjd->bij',batch 维度不参与收缩
RMSNorm keepdims=True,支持任意前导维度
Shape 意识 每个函数都有明确的 shape 注释

完成实现后,强烈建议用以下代码做数值正确性校验

python 复制代码
# 验证 einsum matmul 与原生 matmul 一致性
a = np.random.randn(4, 8)
b = np.random.randn(8, 6)
assert np.allclose(matmul_einsum(a, b), a @ b)

# 验证 RMSNorm 输出尺度
x = np.random.randn(2, 10, 64) * 10
y = rms_normalize(x)
# RMSNorm 后每个样本的 RMS 应接近 1
rms_after = np.sqrt(np.mean(y**2, axis=-1))
assert np.allclose(rms_after, 1.0, atol=1e-4)

掌握这四个函数后,将具备直接阅读和手写 Multi-Head Attention、SwiGLU、RoPE 等 Chapter 2+ 核心算子的数学表达能力。记住:先标 shape,再写下标,最后验数值。

Tensor vs NumPy Array

特性 NumPy Array PyTorch Tensor
设备 仅 CPU CPU 或 GPU
自动求导
深度学习优化
互操作性 - 可与 NumPy 互转

创建 Causal Mask

  • 题目提示中写的是 torch.triu()(上三角),但根据文档示例和自回归生成的语义(位置 i 只能看到 ≤i 的位置),我们需要的是下三角矩阵 。使用 triu 会导致模型只看未来不看过去,完全违背因果律。

python 复制代码
import torch

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    创建 Causal Mask (下三角布尔矩阵)
    
    Args:
        seq_len: 序列长度
    Returns:
        mask: shape (seq_len, seq_len), dtype=torch.bool
              mask[i, j] = True 表示位置 i 可以 attend 到位置 j
    """
    # 使用 tril (下三角) 而非 triu
    # dtype=bool 节省显存,且与后续 masked_fill / attention_mask 兼容
    mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    return mask

# 验证
mask = create_causal_mask(4)
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
  • 工程实践要点

    • 为什么用 bool 而非 float 布尔 mask 在 GPU 上占用显存仅为 float32 的 1/4。在后续 Attention 计算中,通过 scores.masked_fill(~mask, float('-inf')) 即可安全屏蔽未来位置。若用 float 0/1 做乘法,softmax 后被 mask 的位置仍会有非零概率。
    • 广播友好性 :返回 2D shape (seq_len, seq_len) 而非 4D,PyTorch 会自动广播到 (batch, num_heads, seq_len, seq_len),避免不必要的内存分配。

多头注意力的维度变换

这是 Multi-Head Attention 中最容易出错的步骤,核心是 reshape + permute 的顺序不能反

python 复制代码
def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    将 (batch, seq, hidden) 拆分为 (batch, num_heads, seq, head_dim)
    Args:
        x: shape (batch, seq, hidden)
        num_heads: 注意力头数,hidden 必须能被其整除
    Returns:
        shape (batch, num_heads, seq, head_dim)
    """
    batch, seq, hidden = x.shape
    assert hidden % num_heads == 0, f"hidden({hidden}) must be divisible by num_heads({num_heads})"
    head_dim = hidden // num_heads
    
    # Step 1: reshape 拆分 hidden 维度 → (batch, seq, num_heads, head_dim)
    # Step 2: permute 将 num_heads 移到 seq 前面 → (batch, num_heads, seq, head_dim)
    x = x.reshape(batch, seq, num_heads, head_dim)
    x = x.permute(0, 2, 1, 3)
    
    return x

# 验证
x = torch.randn(2, 10, 512)
x_split = split_heads(x, num_heads=8)
print(x_split.shape)  # torch.Size([2, 8, 10, 64])
  • 为什么不能先 permute 再 reshape? 如果先 permute(0,2,1) 得到 (batch, hidden, seq),再 reshape 会把 hidden 和 seq 的元素混在一起,导致语义错乱(不同 token 的特征被拼接到了同一个 head 里)。必须先在原始维度上拆分 hidden,再移动维度顺序。
  • reshape vs view :此处推荐用 reshape。虽然输入通常是连续的,但在某些 pipeline 中 x 可能经过 transpose 变得不连续,view 会直接报错,而 reshape 会自动处理(必要时拷贝),代码更健壮。
  • 逆操作 :Attention 输出后需要合并 heads,逆操作为 x.permute(0, 2, 1, 3).reshape(batch, seq, hidden),顺序同样不能反。

批量矩阵乘法

用于计算 Attention Scores ( QKTQK^TQKT) 的核心算子。

python 复制代码
def batch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    批量矩阵乘法: C[b,i,j] = Σ_k A[b,i,k] * B[b,k,j]
    
    Args:
        a: shape (batch, n, m)
        b: shape (batch, m, p)
    Returns:
        shape (batch, n, p)
    """
    # 推荐使用 @ 运算符(等价于 torch.bmm,但更简洁且支持广播)
    return a @ b

# 验证
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
c = batch_matmul(a, b)
print(c.shape)  # torch.Size([2, 3, 5])

@ vs torch.bmm vs torch.matmul 选型指南

操作符 严格 3D 支持广播 可读性 推荐场景
torch.bmm ✅ 仅 3D 旧代码兼容
@ / torch.matmul 现代 PyTorch 首选
torch.einsum 复杂收缩/非标准乘法
  • 为什么优先用 @ 它底层调用 cuBLAS 的 cublasSgemmBatched,性能与 bmm 完全一致,但支持广播(例如 a: (batch, n, m) @ b: (1, m, p) 自动广播),在 GQA、KV Cache 等场景中无需手动 expand。
  • 数值精度注意 :在 FP16/BF16 混合精度训练中,@ 默认使用 Tensor Core,可能有微小数值差异。如需严格复现 FP32 结果,可用 torch.matmul(a.float(), b.float()).half()

检查项 通过标准 常见错误
Causal Mask tril + bool dtype 误用 triu;用 float 类型
split_heads 先 reshape 再 permute 顺序颠倒导致语义错乱
batch_matmul 使用 @,shape 正确 手动 for 循环逐 batch 乘
内存连续性 理解 view 失败原因 对非连续 tensor 强行 view
设备安全 .to(device) 而非硬编码 cuda 在无 GPU 环境崩溃

每次写完张量操作,养成三行验证法的习惯:

python 复制代码
# 1. Shape 断言
assert x_split.shape == (2, 8, 10, 64), f"Expected (2,8,10,64), got {x_split.shape}"
# 2. 数值 sanity check(如 mask 对角线应全为 True)
assert mask.diagonal().all(), "Causal mask diagonal must be all True"
# 3. 梯度连通性(确保没有意外 detach)
x_test = torch.randn(2, 10, 512, requires_grad=True)
y = split_heads(x_test, 8)
loss = y.sum()
loss.backward()
assert x_test.grad is not None, "Gradient disconnected!"
  • 掌握这三个练习后,具备了手写完整 Multi-Head Attention 模块的张量操作基础。建议接下来结合 Chapter 01 的 einsum 知识,尝试用纯 PyTorch 实现一个完整的 Decoder-only Attention 层。

PyTorch 的 Autograd 机制是 LLM 训练的引擎。理解它不仅能写出正确的训练循环,还能在遇到梯度爆炸、显存溢出或自定义算子时快速定位问题。

实现 Softmax 的自定义 autograd

Softmax 是 LLM 输出层和 Attention 的核心。数值稳定性反向传播的高效实现是本题的两个关键考点。

python 复制代码
import torch

class MySoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        """
        数值稳定的 Softmax 前向传播
        Args:
            input: shape (batch, num_classes)
        Returns:
            output: shape (batch, num_classes), 每行和为 1
        """
        # 减去最大值防止 exp() 溢出(数值稳定性核心)
        x_max = input.max(dim=1, keepdim=True).values
        exp_x = torch.exp(input - x_max)
        output = exp_x / exp_x.sum(dim=1, keepdim=True)
        # 保存 softmax 输出而非原始输入,反向传播可直接复用
        ctx.save_for_backward(output)
        return output
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Softmax 反向传播的高效向量形式
        Args:
            grad_output: dL/dy, shape (batch, num_classes)
        Returns:
            grad_input: dL/dx, shape (batch, num_classes)
        """
        output, = ctx.saved_tensors
        
        # 向量化公式: dx_i = y_i * (dy_i - Σ_j(dy_j * y_j))
        # 避免 O(n²) 的 Jacobian 矩阵计算
        sum_dy_y = (grad_output * output).sum(dim=1, keepdim=True)
        grad_input = output * (grad_output - sum_dy_y)
        
        return grad_input

# 验证
softmax = MySoftmax.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = softmax(x)

# 1. 前向正确性:每行和为 1
print("Row sums:", y.sum(dim=1))  # tensor([1., 1.], dtype=torch.float64)

# 2. 反向正确性:使用 gradcheck 自动验证
from torch.autograd import gradcheck
assert gradcheck(MySoftmax.apply, x, eps=1e-6), "Gradient check failed!"
print("Gradient check passed!")
  • 为什么 save output 而非 input Softmax 的梯度公式 ∂L∂xi=yi(∂L∂yi−∑j∂L∂yjyj)\frac{∂L}{∂xi}=yi(\frac{∂L}{∂yi}−∑_j\frac{∂L}{∂yj}yj)∂xi∂L=yi(∂yi∂L−∑j∂yj∂Lyj) 只依赖 y (softmax 输出),不依赖原始 x 。保存 y 避免了反向传播时重新计算 softmax,节省算力。
  • 为什么不用 Jacobian 矩阵? Softmax 的完整 Jacobian 是 n×n 矩阵,直接计算复杂度为 O(n2)。上述向量化公式将复杂度降至 O(n) ,这在 vocab_size=128K 的 LLM 中差距巨大。
  • dtype=torch.doublegradcheck 要求双精度浮点,因为数值梯度对精度极其敏感。实际训练中用 float32/bfloat16 即可。

梯度累积模拟大 Batch

LLM 训练中显存往往不够放下目标 batch size,梯度累积是在不增加显存的前提下等效增大 batch size 的标准技术。

python 复制代码
def train_with_gradient_accumulation(model, data_loader, optimizer, 
                                      accumulation_steps: int = 4):
    """
    梯度累积训练循环
    
    Args:
        model: PyTorch 模型
        data_loader: 数据加载器
        optimizer: 优化器
        accumulation_steps: 累积步数,等效 batch = micro_batch × accumulation_steps
    """
    model.train()
    optimizer.zero_grad()
    
    for i, (inputs, targets) in enumerate(data_loader):
        # 1. 前向传播
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        
        # 2. 损失缩放:保证累积后的梯度与真实大 batch 等价
        scaled_loss = loss / accumulation_steps
        
        # 3. 反向传播(梯度自动累加到 .grad)
        scaled_loss.backward()
        
        # 4. 每 accumulation_steps 步更新一次参数
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    
    # 处理末尾不足 accumulation_steps 的残余梯度
    if len(data_loader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
  • 为什么要除以 accumulation_steps 真实大 batch 的 loss 是所有样本 loss 的平均值。如果不除,累积后的梯度会是真实值的 N 倍,等效于学习率放大了 N 倍,导致训练发散。
  • zero_grad() 的位置 :必须在 optimizer.step() 之后清零,而非之前。如果在 step 之前清零,刚累积的梯度就被丢弃了。
  • 残余批次处理 :当 dataloader 长度不能被 accumulation_steps 整除时,最后几步的梯度不会被自动更新。生产代码中必须处理这个边界情况,否则最后一个不完整 batch 的训练信号丢失。
  • 与 AMP 配合 :若使用混合精度,scaler.scale(scaled_loss).backward() 中的 scale 因子会与 /accumulation_steps 叠加,需注意顺序。

实现梯度裁剪

LLM 训练(尤其是 RNN/Transformer 早期阶段)极易梯度爆炸。梯度裁剪是保障训练稳定性的安全阀

python 复制代码
def clip_gradients(model: torch.nn.Module, max_norm: float = 1.0) -> float:
    """
    按全局范数裁剪梯度,返回裁剪前的总梯度范数(用于监控)
    
    Args:
        model: 模型
        max_norm: 最大允许梯度范数
    Returns:
        total_norm: 裁剪前的梯度总范数
    """
    # 直接使用 PyTorch 官方实现,底层 C++ 优化,支持分布式
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    
    return total_norm.item()

# --- 手动实现版本(仅供理解原理)---
def clip_gradients_manual(model: torch.nn.Module, max_norm: float = 1.0) -> float:
    """手动实现全局梯度裁剪"""
    # 1. 收集所有有梯度的参数
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    
    # 2. 计算全局 L2 范数(不是逐参数裁剪!)
    total_norm = torch.sqrt(sum(g.data.norm(2).item() ** 2 for g in grads))
    
    # 3. 仅当超限时按比例缩放
    if total_norm > max_norm:
        clip_coef = max_norm / (total_norm + 1e-6)
        for g in grads:
            g.data.mul_(clip_coef)
    
    return total_norm

关键区分:全局裁剪 vs 逐参数裁剪

方法 API 语义 LLM 推荐
全局范数裁剪 clip_grad_norm_ 将所有参数梯度视为一个向量,统一缩放 标准做法
逐值裁剪 clip_grad_value_ 每个梯度分量独立 clamp 到 -v, v ❌ 破坏梯度方向
逐参数范数裁剪 手动循环 norm 每个参数独立缩放 ❌ 失去全局比例关系
  • 为什么返回 total_norm 在 LLM 训练监控面板(如 WandB/TensorBoard)中记录梯度范数是诊断训练健康度的核心指标。范数突然飙升通常预示 loss spike 或数据异常。
  • clip_grad_norm_ 的下划线 :表示原地操作(in-place),直接修改 .grad,无额外内存分配。
  • 分布式训练注意 :在 FSDP/DDP 下,clip_grad_norm_ 会自动跨 GPU 聚合梯度范数再裁剪,手动实现需要额外调用 dist.all_reduce

检查项 通过标准 常见错误
自定义 Function gradcheck 通过;forward 数值稳定 忘记 save_for_backward;反向公式用 Jacobian
梯度累积 loss 除以 N;step 后 zero_grad;处理残余 不除 N 导致梯度放大 N 倍
梯度裁剪 clip_grad_norm_;记录 total_norm 误用 clip_grad_value_
推理模式 model.eval() + torch.no_grad() 推理时仍追踪梯度导致 OOM
detach 使用 明确知道何时切断计算图 滥用 detach 导致梯度断裂

当梯度行为异常时,使用以下诊断手段:

python 复制代码
# 1. 检测 NaN/Inf 梯度
for name, param in model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"Non-finite gradient in {name}")

# 2. 监控各层梯度范数分布
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_norm={param.grad.norm().item():.4f}")

# 3. 验证自定义 Function 的梯度(开发阶段必做)
# 始终用 double 类型测试
x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(MySoftmax.apply, x, eps=1e-6, atol=1e-4)

掌握这三个练习后,具备调试 LLM 训练过程中梯度相关问题的能力。建议下一步结合 Chapter 03 的 Tensor 操作,尝试手写一个带梯度检查的完整 Transformer Block 前向+反向传播。

nn.Module 是 PyTorch 构建所有神经网络的基石。在 LLM 工程中,从简单的 MLP 到复杂的 Transformer Block,都依赖于正确的模块封装、参数注册与组合。

SimpleLinear: 最小线性层封装目标 :理解 nn.Parameter 的注册机制以及 forward() 的纯粹性。

python 复制代码
import torch
import torch.nn as nn

class SimpleLinear(nn.Module):
    """
    手动实现的最小线性层 y = xW^T + b
    用于理解 nn.Linear 底层原理及参数注册机制
    """
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # 必须用 nn.Parameter 包装,否则不会被 state_dict 收集
        # 使用 Kaiming 初始化模拟真实线性层行为
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward 只做纯张量运算,不包含优化器、loss 等训练逻辑
        return x @ self.weight.t() + self.bias
  • 为什么必须用 nn.Parameter 直接赋值 self.weight = torch.randn(...) 只是一个普通属性,PyTorch 无法感知它。只有被 nn.Parameter 包装(或作为子 nn.Module 的属性),它才会被自动加入 parameters() 迭代器和 state_dict()
  • forward() 的纯粹原则forward 只定义"数据如何流动"。绝不要在 forward 里写 optimizer.step()loss.backward() 或设备转移 .to(device)。这保证了模块的可复用性和可组合性。
  • 权重形状约定 :PyTorch 官方 nn.Linear 的 weight shape 是 (out, in),计算时用 x @ W^T。遵循此约定便于后续与预训练权重对齐。

TwoLayerMLP : 模块组合与激活函数**。目标:掌握通过子模块注册构建复合网络,理解 LLM 中 FFN 的基础原型。

python 复制代码
class TwoLayerMLP(nn.Module):
    """
    两层 MLP: Linear -> ReLU -> Linear
    LLM 中 Feed-Forward Network (FFN) 的最简原型
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # 子模块赋值给 self.xxx 即自动注册
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
  • 子模块自动注册 :只要将 nn.Module 实例赋值给 self 的属性(如 self.fc1),PyTorch 就会递归地将其参数纳入管理。不需要 手动调用 add_module()(除非动态构建模块列表)。
  • 激活函数也应是模块 :推荐 self.act = nn.ReLU() 而非在 forward 中写 F.relu()。前者使网络结构在 print(model) 时完整可见,且方便后续替换为 GELU/SwiGLU(LLM 标配)而无需修改 forward。
  • LLM 扩展提示 :现代 LLM 的 FFN 通常是 SwiGLU 变体,需要三个线性层(gate, up, down)。掌握了 TwoLayerMLP 的组合模式后,扩展到三层只需增加一个 self.gate_proj 并在 forward 中加入逐元素乘法。

count_parameters(module) : 参数统计与过滤**。目标:区分可训练/不可训练参数,这是 LLM 微调(LoRA/冻结主干)时的必备监控工具。

python 复制代码
def count_parameters(module: nn.Module, only_trainable: bool = True) -> int:
    """
    统计模型参数数量
    
    Args:
        module: PyTorch 模块
        only_trainable: True=仅统计 requires_grad=True 的参数
                       False=统计全部参数
    Returns:
        参数总数(标量整数)
    """
    if only_trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())

# 使用示例
model = TwoLayerMLP(768, 2048, 768)
total = count_parameters(model, only_trainable=False)
trainable = count_parameters(model, only_trainable=True)
print(f"Total: {total:,} | Trainable: {trainable:,}")
  • parameters() vs named_parameters()parameters() 返回参数张量迭代器,适合计数和优化器传入;named_parameters() 返回 (name, tensor) 元组,适合按名称过滤(如 LoRA 只解冻 lora_ 前缀参数)。
  • numel() vs size()numel() 返回标量元素总数,size() 返回形状元组。计数必须用 numel()
  • LLM 微调场景 :在 Full Fine-tuning 中 only_trainable=True 等于总参数;但在 LoRA/QLoRA 中,可训练参数可能仅占总参数的 0.1%~1%。打印两者对比是验证冻结策略是否生效的第一步调试手段
  • 格式化输出 :使用 {:,} 千分位分隔符。LLM 参数量动辄数十亿,1,234,567,8901234567890 可读性高一个数量级。

检查项 通过标准 常见错误
参数注册 所有权重用 nn.Parameter 或子模块 直接用 torch.Tensor 赋值
forward 纯粹性 仅含张量运算,无训练逻辑 在 forward 中调用 backward/step
模块组合 子模块通过 self.xxx 注册 在 forward 中临时创建模块
state_dict model.state_dict() 包含所有参数 参数未注册导致保存缺失
参数统计 区分 trainable/total,用 numel() len(parameters()) 误算为层数

确保 Module 是正确的,完成实现后,务必执行以下三项验证:

python 复制代码
model = TwoLayerMLP(768, 2048, 768)

# 1. 结构可视化:确认子模块和激活函数都正确注册
print(model)

# 2. state_dict 完整性:确认所有参数都可序列化
sd = model.state_dict()
assert 'fc1.weight' in sd and 'fc2.bias' in sd, "Missing keys in state_dict!"

# 3. 梯度连通性:确认前向→反向链路无断裂
x = torch.randn(2, 10, 768)
loss = model(x).sum()
loss.backward()
for name, p in model.named_parameters():
    assert p.grad is not None, f"Gradient disconnected at {name}"

掌握这三个练习后,具备了构建任意复杂度 LLM 组件(Attention、FFN、Embedding、RMSNorm)的模块化能力。下一步建议将这些基础模块组合成一个完整的 Transformer Decoder Layer,并验证其 state_dict 能与 HuggingFace Transformers 的键名对齐。

优化器与损失函数是连接"模型输出"与"参数更新"的桥梁。在 LLM 训练中,CrossEntropy 是绝对核心,而 AdamW 则是事实标准。

mse_loss(pred, target) : 回归损失基础**。目标:理解均方误差的数学形式及其梯度特性,这是所有损失函数的起点。

python 复制代码
import torch
import torch.nn.functional as F

def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    手写均方误差 (Mean Squared Error)
    Args:
        pred: 预测值, shape (*)
        target: 真实值, shape (*),必须与 pred 形状一致
    Returns:
        标量 loss
    """
    # 等价于 F.mse_loss(pred, target, reduction='mean')
    # 先求差 → 平方 → 取均值
    return ((pred - target) ** 2).mean()
  • 为什么用 .mean() 而非 .sum() MSE 的梯度为 2N(pred−target)\frac{2}{N}(pred−target)N2(pred−target) 。使用 mean 使梯度大小与 batch size 无关,切换 batch size 时无需调整学习率。若用 sum,梯度会随 N 线性增长,导致训练不稳定。
  • LLM 中的 MSE 场景:虽然语言建模主用 CrossEntropy,但 MSE 仍用于:① Value Head(RLHF/PPO);② Knowledge Distillation 中的 logits 对齐;③ Embedding 模型的对比学习辅助损失。
  • 数值安全(pred - target) ** 2 在 FP16/BF16 下通常安全,但若差值极大可能溢出。生产代码可考虑 F.mse_loss,其内部有额外的数值保护。

cross_entropy_loss(logits, target) : LLM 的核心损失**。目标:掌握分类交叉熵的正确调用方式,理解 logits vs probabilities 的关键区别。

python 复制代码
def cross_entropy_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    分类交叉熵损失
    
    Args:
        logits: 原始模型输出 (未过 softmax!), shape (batch, num_classes)
        target: 类别索引, shape (batch,), dtype=torch.long
               或 soft labels, shape (batch, num_classes), dtype=torch.float
    Returns:
        标量 loss
    """
    # F.cross_entropy 内部自动执行 log_softmax,数值更稳定
    # 切勿先手动 softmax 再传入!会导致双重 softmax + 数值不稳定
    return F.cross_entropy(logits, target)

# 验证示例
logits = torch.randn(4, 10)       # batch=4, vocab=10
target = torch.randint(0, 10, (4,)) # 硬标签
loss = cross_entropy_loss(logits, target)
print(f"CE Loss: {loss.item():.4f}")
做法 正确性 原因
F.cross_entropy(logits, target) 内部 fused log_softmax,数值稳定且快
F.nll_loss(F.log_softmax(logits), target) 等价写法,略慢
F.cross_entropy(F.softmax(logits), target) 双重 softmax,梯度错误
-torch.log(softmax(logits)[..., target]) 手动实现易溢出,不支持 label smoothing
  • Logits ≠ Probabilities :LLM 输出的最后一层线性层结果叫 logits,范围 (−∞,+∞)(−∞,+∞) 。永远不要对 logits 做 softmax 后再传给 F.cross_entropy
  • Label Smoothing :LLM 训练几乎必用 label smoothing(通常 0.1)。F.cross_entropy(logits, target, label_smoothing=0.1) 一行搞定,防止模型过度自信,提升泛化。
  • Ignore Index :处理 padding token 时用 ignore_index=-100,被忽略位置的 loss 和梯度均为 0,避免无效 token 污染训练信号。

train_one_step(...): 最小训练步模板目标:将前向、反向、更新串联为原子操作,建立正确的梯度管理肌肉记忆。

python 复制代码
def train_one_step(
    model: torch.nn.Module,
    x: torch.Tensor,
    target: torch.Tensor,
    optimizer: torch.optim.Optimizer
) -> float:
    """
    最小完整训练步
    Args:
        model: 模型(已处于 train() 模式)
        x: 输入数据
        target: 标签
        optimizer: 优化器
    Returns:
        本步 loss 值(Python float,脱离计算图)
    """
    # Step 0: 清零梯度(必须在 backward 之前!)
    optimizer.zero_grad()
    # Step 1: 前向传播
    logits = model(x)
    loss = cross_entropy_loss(logits, target)
    # Step 2: 反向传播(梯度自动累积到 .grad)
    loss.backward()
    # Step 3: 参数更新
    optimizer.step()
    # 返回 .item() 脱离计算图,防止内存泄漏
    return loss.item()
复制代码
zero_grad → forward → backward → step   正确
forward → backward → zero_grad → step   梯度被清空,step 无效
forward → backward → step → zero_grad   也正确(PyTorch 官方推荐)

PyTorch 官方文档现在推荐 step → zero_grad 的顺序,因为某些优化器(如 AdamW)在 step 时会利用当前梯度做 weight decay,先 zero_grad 不影响 step,但能让下一次 forward 开始时梯度缓冲区干净。两种顺序都正确,但绝不能把 zero_grad 放在 backward 和 step 之间

  • loss.item() 的重要性 :直接返回 loss 张量会保留整个计算图的引用,导致显存持续增长直至 OOM。记录 loss 曲线时必须.item()
  • LLM 训练增强版 :实际 LLM 训练循环还需加入:① scaler.scale(loss).backward() (AMP);② clip_grad_norm_ (梯度裁剪);③ scheduler.step() (学习率调度)。本题是最小骨架,这些组件在此基础上叠加。

检查项 通过标准 常见错误
MSE .mean(),梯度与 BS 无关 .sum() 导致梯度爆炸
CE 传入 raw logits,不手动 softmax 先 softmax 再传 CE
训练步 zero_grad 不在 backward/step 之间 忘记 .item() 导致 OOM
设备一致性 model/x/target/optimizer 同设备 CPU/GPU 混用报错
Loss 下降 小样本上 10 步内 loss 明显降低 学习率过大/过小或 bug

**必做验证:Loss 是否真的在下降?**完成三个函数后,必须用以下代码验证端到端训练有效性:

python 复制代码
# 构造一个可被完美拟合的小数据集
torch.manual_seed(42)
model = torch.nn.Linear(8, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

x = torch.randn(16, 8)
target = torch.randint(0, 5, (16,))

losses = []
for step in range(50):
    loss_val = train_one_step(model, x, target, optimizer)
    losses.append(loss_val)

# 验证:最终 loss 应显著低于初始 loss
assert losses[-1] < losses[0] * 0.5, f"Loss not decreasing! {losses[0]:.4f} → {losses[-1]:.4f}"
print(f"Training verified: {losses[0]:.4f} → {losses[-1]:.4f}")

如果 loss 不下降,按此顺序排查:① 学习率是否合理(AdamW 通常 1e-4 ~ 3e-4);② target dtype 是否为 long;③ 模型是否有 requires_grad=True 的参数;④ 是否意外在 no_grad() 上下文中训练。

掌握这三个练习后,已具备 LLM 训练循环的最小可行基础。下一步建议在此基础上叠加 AMP + 梯度累积 + LR Scheduler,构建完整的预训练/微调 pipeline。

一个生产级的训练循环不仅仅是 for 循环的嵌套,它是状态管理、资源调度和可观测性的综合体。在 LLM 训练中,一次中断可能意味着数百万美元的算力浪费,因此"断点续训"、"指标监控"和"防过拟合"是工程底线。

完整训练流程:三合一工程模板

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import os

# ==================== 1. 早停机制 ====================
class EarlyStopping:
    """验证指标停滞时自动终止训练"""
    def __init__(self, patience=5, min_delta=0.0, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
            return True
        
        improved = (score < self.best_score - self.min_delta) if self.mode == 'min' \
              else (score > self.best_score + self.min_delta)
        
        if improved:
            self.best_score = score
            self.counter = 0
            return True
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

# ==================== 2. 核心训练/验证函数 ====================
def train_one_epoch(model, loader, criterion, optimizer, device, writer=None, epoch=None):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]", leave=False)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})
    
    avg_loss = total_loss / len(loader)
    accuracy = 100.0 * correct / total
    
    if writer and epoch is not None:
        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Acc/train', accuracy, epoch)
    
    return avg_loss, accuracy

@torch.no_grad()
def validate(model, loader, criterion, device, writer=None, epoch=None):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    
    for inputs, targets in tqdm(loader, desc="Validating", leave=False):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    
    avg_loss = total_loss / len(loader)
    accuracy = 100.0 * correct / total
    
    if writer and epoch is not None:
        writer.add_scalar('Loss/val', avg_loss, epoch)
        writer.add_scalar('Acc/val', accuracy, epoch)
    
    return avg_loss, accuracy

# ==================== 3. 主训练入口 ====================
def main():
    # --- 配置 ---
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_EPOCHS = 100
    SAVE_DIR = './checkpoints'
    os.makedirs(SAVE_DIR, exist_ok=True)
    
    # --- 数据 & 模型 ---
    train_loader, val_loader = create_dummy_dataset()  # 使用题目提供的函数
    model = SimpleClassifier().to(DEVICE)              # 使用题目提供的模型
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    
    # --- 调度器 & 早停 & 日志 ---
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    early_stopping = EarlyStopping(patience=7, mode='min')
    writer = SummaryWriter('runs/exp_001')
    
    best_val_loss = float('inf')
    
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE, writer, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, DEVICE, writer, epoch)
        
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('LR', current_lr, epoch)
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f} Acc={train_acc:.2f}% | "
              f"Val Loss={val_loss:.4f} Acc={val_acc:.2f}% | LR={current_lr:.2e}")
        
        # 学习率调度(基于验证损失)
        scheduler.step(val_loss)
        
        # 保存最佳模型 + 完整 Checkpoint
        is_best = early_stopping(val_loss)
        if is_best:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, os.path.join(SAVE_DIR, 'best_checkpoint.pt'))
            print(f"  💾 Best model saved (val_loss={val_loss:.4f})")
        
        # 早停检查
        if early_stopping.early_stop:
            print(f"\n Early stopping triggered at epoch {epoch}")
            break
    
    writer.close()
    print(f"\n Training complete. Best Val Loss: {best_val_loss:.4f}")

核心组件深度解析与避坑指南

model.train() vs model.eval():不是装饰器,是行为开关**

层类型 train()模式 eval()模式 忘记切换的后果
Dropout 随机置零,输出乘以 1/(1−p)1/(1−p) 恒等映射 推理时输出幅度偏移,精度暴跌
BatchNorm 用当前 batch 统计量,更新 running stats 用冻结的 running_mean/var 推理结果随 batch size 变化,不可复现
Linear/Conv 无区别 无区别 无影响

致命陷阱model.eval() 不会 禁用梯度计算!必须配合 with torch.no_grad(): 使用。仅用 eval() 做验证,显存仍会被计算图占满,导致 OOM。

Checkpoint 保存策略:为什么不能只存 state_dict**?**在生产环境中,永远保存完整 checkpoint 而非仅模型权重:

python 复制代码
# 仅存权重:无法恢复训练,优化器动量丢失
torch.save(model.state_dict(), 'model.pt')

# 完整 checkpoint:支持断点续训,实验完全可复现
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),   # Adam 的 m/v 状态
    'scheduler_state_dict': scheduler.state_dict(),   # LR 调度进度
    'best_val_loss': best_val_loss,
    'config': {'lr': 1e-3, 'batch_size': 32},         # 超参数快照
}
torch.save(checkpoint, 'checkpoint.pt')
  • Optimizer State 的重要性 :Adam/AdamW 维护每个参数的一阶矩(m)和二阶矩(v)。如果只恢复模型权重而重置优化器,相当于从第 0 步重新开始自适应学习率调整,前几十步的训练效果会严重退化
  • 加载时的设备安全torch.load(path, map_location=device) 防止在无 GPU 机器上加载 CUDA checkpoint 时报错。

学习率调度器选择指南

调度器 适用场景 LLM 推荐度 注意事项
StepLR 传统 CV,固定衰减 过于粗糙,LLM 几乎不用
ReduceLROnPlateau 通用微调,验证驱动 ⭐⭐⭐ 本题推荐;patience 通常 3-5
CosineAnnealingLR LLM 预训练标配 ⭐⭐⭐⭐⭐ 配合 warmup 使用
OneCycleLR 快速收敛,超参搜索 ⭐⭐ 对峰值 LR 敏感

LLM 预训练补充 :实际 LLM 预训练使用 Warmup + Cosine Decay ,PyTorch 原生不支持 warmup,需用 transformers.get_cosine_schedule_with_warmup 或自定义 LambdaLR。本题的微调场景用 ReduceLROnPlateau 是最稳妥的选择。

日志与可观测性

  • tqdm :提供实时反馈,set_postfix 显示关键指标。注意在分布式训练中只在 rank 0 启用 tqdm,否则终端输出混乱。
  • TensorBoard :记录标量(loss/acc/lr)、直方图(梯度/权重分布)、文本(生成样本)。务必记录学习率曲线,它是诊断训练异常的第一线索。
  • WandB(进阶):LLM 项目事实标准,支持超参搜索、团队协作、artifact 版本管理。API 与 TensorBoard 类似,迁移成本低。
检查项 通过标准 常见错误
模式切换 train/eval 成对出现,eval 配 no_grad 验证时未切 eval,BN/Dropout 污染结果
梯度清零 zero_grad()backward() 放在 step 后也可,但绝不能夹在中间
Checkpoint 包含 optimizer + scheduler state 只存 model weights,断点续训失效
早停 基于验证指标,有 patience 缓冲 基于训练 loss 早停(永远不触发或过早触发)
设备安全 所有 tensor/model/scheduler 同设备 CPU/GPU 混用;load 未指定 map_location
Loss 记录 使用 .item() 脱离计算图 直接累加 tensor 导致显存泄漏

从练习到生产的下一步

掌握此模板后,向 LLM 训练进阶需叠加以下组件:

  1. 混合精度 (AMP)torch.cuda.amp.autocast + GradScaler,FP16/BF16 训练提速 2-3x
  2. 梯度累积:等效大 batch,见 Autograd 章节练习 2
  3. 分布式训练:DDP/FSDP,多卡/多机扩展
  4. Gradient Checkpointing:以计算换显存,支撑更长序列
  5. Flash Attention:IO-aware 的注意力实现,LLM 训练必备

此模板是所有这些高级技术的骨架。确保骨架正确,再逐步叠加血肉,是避免 LLM 训练中"玄学 bug"的最佳实践。

激活函数是神经网络中引入非线性的核心组件。在 LLM 时代,ReLU 已逐渐退出主干网络,GELU 和 SiLU(Swish)成为 Transformer 架构的标准配置。理解它们的数学本质与工程实现,是阅读和修改现代 LLM 代码的前提。

relu: 经典基线

python 复制代码
import torch
import torch.nn.functional as F
import math

def relu(x: torch.Tensor) -> torch.Tensor:
    """
    Rectified Linear Unit
    公式: max(0, x)
    """
    # 推荐:调用官方算子,底层 C++/CUDA 优化,支持 autograd
    return F.relu(x)

    # 不推荐手写:torch.maximum(x, torch.zeros_like(x))
    # 原因:产生额外张量分配,autograd 图更复杂,性能差
  • Shape 不变性:ReLU 是逐元素运算,输入输出 shape 完全一致。这是所有激活函数的基本契约。
  • LLM 中的地位 :ReLU 在现代 LLM 主干中已被淘汰,原因是其在 x=0x =0 处不可导且负半轴梯度为零(dying ReLU 问题)。但在以下场景仍可见到:① 早期模型(GPT-1);② 某些 MLP 的 gate 分支变体;③ 作为教学基线。生产代码中优先使用 GELU/SiLU
  • 为什么不用手写? F.relu 在 CUDA 上是单个 fused kernel,而手写 maximum 会触发多个 kernel launch + 临时显存分配。在 LLM 训练中,激活函数被调用数十亿次,微小开销会被放大。

gelu_exact: LLM 的事实标准

python 复制代码
def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    """
    Gaussian Error Linear Unit (精确版本)
    公式: x * Φ(x),其中 Φ(x) 是标准正态分布的累积分布函数
    
    注意:PyTorch F.gelu(x) 默认即为精确版本(approximate='none')
    """
    # 精确 GELU:使用 erf 函数计算 CDF
    # Φ(x) = 0.5 * (1 + erf(x / sqrt(2)))
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf

    # 等价官方调用(验证用):
    # return F.gelu(x, approximate='none')

精确 vs 近似:必须区分的两个版本

版本 公式 PyTorch API 使用场景
Exact x⋅Φ(x) F.gelu(x, approximate='none') BERT, GPT-2, RoBERTa
Tanh Approx 0.5x(1+tanh⁡(2/π(x+0.044715x3)))0.5x(1+tanh⁡(\sqrt{2/π}(x+0.044715x^3)))0.5x(1+tanh⁡(2/π (x+0.044715x3))) F.gelu(x, approximate='tanh') GPT-J, LLaMA-1, StarCoder
  • 为什么有近似版? 精确版需要计算 erf,在早期 GPU 上较慢。Tanh 近似仅用基础算术运算,速度快约 5-10%,且最大绝对误差 < 10−410−4 。
  • LLM 选型指南 :加载预训练权重时,必须匹配原始论文使用的版本 。用错版本会导致 logits 分布偏移,下游任务性能下降 1-3%。HuggingFace Transformers 的配置文件中 hidden_act 字段明确指定了版本。
  • 数值安全erf 在 FP16 下对大值饱和为 ±1,不会溢出。BF16 下精度更好,推荐 LLM 训练使用 BF16。

silu: Swish 的现代命名

python 复制代码
def silu(x: torch.Tensor) -> torch.Tensor:
    """
    Sigmoid Linear Unit (aka Swish)
    公式: x * σ(x) = x / (1 + exp(-x))
    
    PyTorch 中 F.silu 即为此函数
    """
    # 官方实现,内部 fused sigmoid + multiply
    return F.silu(x)

    # 手动等价实现(仅供理解,勿用于生产):
    # return x * torch.sigmoid(x)

LLM 中的关键角色

  • SwiGLU 的核心 :现代 LLM(LLaMA-2/3, Qwen, Mistral, Gemma)的 FFN 几乎全部采用 SwiGLU 结构: SwiGLU(x)=SiLU(xWgate)⊗(xWup)SwiGLU(x)=SiLU(xWgate)⊗(xWup)SwiGLU(x)=SiLU(xWgate)⊗(xWup)。SiLU 作为 gate 激活函数,其平滑的非线性特性使梯度流更稳定,训练 loss 比 GELU 低 0.1-0.3。
  • vs GELU 对比:SiLU 在负半轴有非零输出(类似 leaky),信息保留更多;GELU 负半轴衰减更快。实证表明 SiLU 在大模型规模下略优,但差距随 scale 增大而缩小。
  • Fused Kernel :在实际 LLM 框架(如 Megatron-LM, vLLM)中,SiLU + element-wise multiply 通常被融合为单个 CUDA kernel(silu_and_mul),避免中间张量的显存读写。写自定义算子时务必考虑这种融合机会

activation_summary: 可复用的激活工厂

python 复制代码
from typing import Callable, Dict

def activation_summary() -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
    """
    返回可用激活函数的注册表
    用于模型配置驱动的动态模块构建
    """
    return {
        "relu": relu,
        "gelu": gelu_exact,           # 精确版
        "gelu_tanh": lambda x: F.gelu(x, approximate="tanh"),  # 近似版
        "silu": silu,
        "swish": silu,                # swish 是 silu 的别名
    }

# 使用示例:配置驱动的 MLP 构建
def get_activation(name: str) -> Callable:
    registry = activation_summary()
    if name not in registry:
        raise ValueError(f"Unknown activation '{name}'. Available: {list(registry.keys())}")
    return registry[name]

# 在 nn.Module 中使用
class ConfigurableMLP(torch.nn.Module):
    def __init__(self, d_in, d_hidden, d_out, act_name="silu"):
        super().__init__()
        self.fc1 = torch.nn.Linear(d_in, d_hidden)
        self.act = get_activation(act_name)  # 动态选择
        self.fc2 = torch.nn.Linear(d_hidden, d_out)
    
    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
  • 配置驱动 :LLM 实验中频繁切换激活函数做 ablation study。硬编码 F.gelu 需要改源码;通过 registry + config YAML,无需触碰模型代码即可切换。
  • HuggingFace 对齐 :Transformers 库的 ACT2FN 字典就是此模式的工业实现。自建 registry 时保持相同键名("gelu", "silu", "gelu_pytorch_tanh"),便于权重迁移。
  • 可扩展性:新增自定义激活(如 QuickGELU、ReGLU)只需在 registry 中添加一行,符合开闭原则。

每个自定义激活函数必须通过以下验证:

python 复制代码
def test_activations():
    """确保手写实现与 PyTorch 官方数值一致"""
    x = torch.randn(4, 8, dtype=torch.float64)  # double 精度测试
    
    # ReLU
    assert torch.allclose(relu(x), F.relu(x), atol=1e-10)
    
    # GELU exact
    assert torch.allclose(gelu_exact(x), F.gelu(x, approximate='none'), atol=1e-10)
    
    # SiLU
    assert torch.allclose(silu(x), F.silu(x), atol=1e-10)
    
    # Shape 守恒
    for fn in [relu, gelu_exact, silu]:
        assert fn(x).shape == x.shape, f"{fn.__name__} changed shape!"
    
    # Autograd 连通性
    for fn in [relu, gelu_exact, silu]:
        x_grad = torch.randn(4, 8, requires_grad=True)
        fn(x_grad).sum().backward()
        assert x_grad.grad is not None, f"{fn.__name__} gradient disconnected!"
    
    print("All activation tests passed!")

test_activations()

激活函数选型速查表

激活函数 公式 LLM 代表模型 推荐度 备注
SiLU/Swish xσ(x) LLaMA, Qwen, Mistral ⭐⭐⭐⭐⭐ 当前主流,SwiGLU 标配
GELU (tanh) tanh 近似 GPT-J, LLaMA-1 ⭐⭐⭐⭐ 兼容性好,速度略快
GELU (exact) xΦ(x) BERT, GPT-2 ⭐⭐⭐ 老模型权重加载必需
ReLU max⁡(0,x) GPT-1 仅作基线,新模型避免
QuickGELU xσ(1.702x) OpenCLIP ⭐⭐ GELU 的快速近似

掌握这四个练习后,已具备阅读和修改任何 LLM FFN/Attention 代码中激活函数部分的能力。下一步建议结合 SwiGLU 结构,实现一个完整的 LLM-style FeedForward 模块,并用 activation_summary 使其支持配置化切换。

归一化技术是深度学习中稳定训练、加速收敛的基石。在 LLM 时代,BatchNorm 因序列长度动态变化和因果掩码兼容性问题已被淘汰,LayerNorm(及其变体 RMSNorm)成为 Transformer 的唯一标准。但理解 BatchNorm 的训练/推理双态机制,是掌握所有归一化技术的必经之路。

batch_norm_train: 训练态手写实现

python 复制代码
import torch
import math

def batch_norm_train(
    x: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    训练态 BatchNorm(沿 batch 维统计)
    
    Args:
        x: shape (N, C, *) --- N=batch, C=channels/features
        gamma: scale 参数, shape (C,)
        beta: shift 参数, shape (C,)
        eps: 数值稳定性常数
    
    Returns:
        y: 归一化输出, shape 同 x
        mean: 当前 batch 均值, shape (C,) --- 用于更新 running stats
        var: 当前 batch 方差, shape (C,) --- 用于更新 running stats
    """
    # 沿 batch 维度(dim=0)计算统计量,保留 channel 维度
    # keepdim=True 保证广播正确性
    reduce_dims = [0] + list(range(2, x.dim()))  # (0, 2, 3, ...) for 4D input
    mean = x.mean(dim=reduce_dims, keepdim=True)
    var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
    
    # 归一化 + 仿射变换
    x_hat = (x - mean) / torch.sqrt(var + eps)
    y = gamma * x_hat + beta
    
    return y, mean.squeeze(), var.squeeze()
  • unbiased=False :BatchNorm 使用总体方差 (除以 N),而非样本方差(除以 N-1)。这与 PyTorch nn.BatchNorm 默认行为一致。用错会导致推理时数值偏移。
  • 返回值设计 :训练态必须返回当前 batch 的 mean/var,供 update_running_stats 使用。这是训练/推理分离的核心接口。
  • Shape 约定 :BatchNorm 对 (N, C, H, W)(N, C) 都适用,reduce_dims 的动态计算使其通用。

batch_norm_eval: 推理态手写实现

python 复制代码
def batch_norm_eval(
    x: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-5
) -> torch.Tensor:
    """
    推理态 BatchNorm(使用累积统计量)
    
    Args:
        x: shape (N, C, *)
        running_mean: 训练期间累积的均值, shape (C,)
        running_var: 训练期间累积的方差, shape (C,)
    """
    # 推理时完全不依赖当前 batch 的统计量
    # reshape 为 (1, C, 1, ...) 以支持广播
    shape = [1, -1] + [1] * (x.dim() - 2)
    rm = running_mean.view(shape)
    rv = running_var.view(shape)
    
    x_hat = (x - rm) / torch.sqrt(rv + eps)
    return gamma * x_hat + beta

为什么推理不能用当前 batch 统计?

场景 使用当前 batch stats 使用 running stats
单样本推理 方差=0,除零错误 正常
小 batch 推理 统计量噪声大,结果不稳定 稳定
批量推理 结果随 batch 内容变化 确定性输出
  • 确定性要求:生产环境中,相同输入必须产生相同输出。若用当前 batch stats,同一张图片在不同 batch 中会得到不同结果,这在部署时是不可接受的 bug。
  • model.eval() 的本质 :就是切换到此函数路径。忘记调用 eval() 是 LLM 推理性能下降的最常见原因之一。

layer_norm_last_dim: Transformer 的标准归一化

python 复制代码
def layer_norm_last_dim(
    x: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-5
) -> torch.Tensor:
    """
    LayerNorm 沿最后一个维度归一化
    等价于 nn.LayerNorm(normalized_shape=x.shape[-1])
    
    Args:
        x: shape (*, D) --- D 是特征维度
        gamma, beta: shape (D,)
    """
    # 仅沿最后一维统计,每个样本独立归一化
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

为什么 LLM 只用 LayerNorm?

  • 序列长度无关:LayerNorm 对每个 token 独立归一化,不依赖 batch 内其他样本。变长序列、padding、因果掩码都不影响统计量。

  • 训练/推理一致:无 running stats,无需区分 train/eval 模式。这简化了分布式训练和推理部署。

  • Pre-Norm vs Post-Norm :现代 LLM 几乎全部采用 Pre-Norm(归一化在 Attention/FFN 之前)。Post-Norm(原始 Transformer)梯度流不稳定,深层网络难以训练。

    Pre-Norm: x → LN → Sublayer → Add(x, ·) LLaMA/GPT-3/Qwen
    Post-Norm: x → Sublayer → Add(x, ·) → LN 仅原始 Transformer


update_running_stats: 指数移动平均

python 复制代码
def update_running_stats(
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    batch_mean: torch.Tensor,
    batch_var: torch.Tensor,
    momentum: float = 0.1
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    指数移动平均更新累积统计量
    
    Args:
        momentum: EMA 系数,PyTorch 默认 0.1
                  new = (1-momentum)*old + momentum*batch
    """
    # 原地更新避免额外内存分配(生产代码推荐 inplace)
    new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    new_var = (1.0 - momentum) * running_var + momentum * batch_var
    
    return new_mean, new_var
  • momentum 语义反转 :PyTorch 的 momentum=0.1 表示新值权重为 0.1(即旧值权重 0.9)。这与优化器中 momentum=0.9 的含义相反。混用会导致 running stats 要么震荡不收敛,要么更新过慢。
  • 初始化running_mean 初始化为 0,running_var 初始化为 1。错误的初始值(如 var=0)会导致前几个 batch 的推理结果异常。
  • LLM 中的替代方案:由于 LLM 不使用 BatchNorm,此函数主要用于理解原理。实际 LLM 训练中,RMSNorm 完全不需要 running stats,这也是其被广泛采用的原因之一------消除了状态同步开销,对分布式训练更友好。

python 复制代码
def test_normalizations():
    torch.manual_seed(42)
    N, C, H, W = 8, 16, 4, 4
    
    # === BatchNorm 测试 ===
    x_bn = torch.randn(N, C, H, W, dtype=torch.float64)
    gamma_bn = torch.ones(C, dtype=torch.float64)
    beta_bn = torch.zeros(C, dtype=torch.float64)
    
    # 训练态对齐
    y_custom, bm, bv = batch_norm_train(x_bn, gamma_bn, beta_bn)
    bn_layer = torch.nn.BatchNorm2d(C).double()
    bn_layer.train()
    y_official = bn_layer(x_bn)
    assert torch.allclose(y_custom, y_official, atol=1e-8), "BN train mismatch!"
    
    # 推理态对齐
    bn_layer.eval()
    y_eval_official = bn_layer(x_bn)
    y_eval_custom = batch_norm_eval(x_bn, gamma_bn, beta_bn, 
                                     bn_layer.running_mean, bn_layer.running_var)
    assert torch.allclose(y_eval_custom, y_eval_official, atol=1e-8), "BN eval mismatch!"
    
    # === LayerNorm 测试 ===
    D = 768
    x_ln = torch.randn(4, 10, D, dtype=torch.float64)
    gamma_ln = torch.ones(D, dtype=torch.float64)
    beta_ln = torch.zeros(D, dtype=torch.float64)
    
    y_ln_custom = layer_norm_last_dim(x_ln, gamma_ln, beta_ln)
    ln_layer = torch.nn.LayerNorm(D).double()
    y_ln_official = ln_layer(x_ln)
    assert torch.allclose(y_ln_custom, y_ln_official, atol=1e-8), "LN mismatch!"
    
    print("All normalization tests passed!")

test_normalizations()

归一化技术选型速查表

技术 统计维度 Running Stats LLM 适用性 代表模型
BatchNorm Batch (dim=0) 需要 不适用 ResNet, CNN
LayerNorm Feature (last dim) 不需要 标准 GPT, BERT
RMSNorm Feature (last dim) 不需要 主流 LLaMA, Qwen, Gemma
GroupNorm Group of channels 不需要 特殊场景 Diffusion, ConvNeXt

LLM 进阶:从 LayerNorm 到 RMSNorm 。掌握本题后,下一步应实现 RMSNorm (Root Mean Square Normalization):RMSNorm(x)=xmean(x2)+ϵ⋅γRMSNorm(x)=\frac{x}{mean(x2)+ϵ}⋅γRMSNorm(x)=mean(x2)+ϵx⋅γ。相比 LayerNorm,RMSNorm 去掉了均值中心化(subtract mean),仅做缩放。实证表明这对 LLM 性能无损,但减少了约 50% 的计算量和一次全局 reduction 操作,在大模型训练中节省可观算力。LLaMA-2/3、Qwen-2、Gemma 等主流开源模型均采用 RMSNorm + Pre-Norm 组合。

Attention 机制是 Transformer 和所有现代 LLM 的基石。手写 Scaled Dot-Product Attention 不仅是理解论文公式的最佳途径,更是后续实现 Multi-Head Attention、Flash Attention 以及 KV Cache 等高级优化的前置条件。

build_causal_mask: 构建因果掩码

python 复制代码
import torch
import torch.nn.functional as F
import math

def build_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    构建下三角因果掩码(Causal Mask)
    
    Args:
        seq_len: 序列长度 T
        device: 目标设备
        
    Returns:
        mask: shape (T, T), dtype=torch.bool
              True 表示该位置【允许】被 attend,False 表示被屏蔽
    """
    # 使用 tril 生成下三角矩阵
    # diagonal=0 包含对角线(当前 token 可以 attend 自己)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    return mask
  • 为什么需要 Causal Mask? 自回归语言模型在训练时使用 Teacher Forcing(并行处理整个序列),但推理时是逐 token 生成的。Mask 确保第 tt 个 token 只能看到 0,t 的信息,防止"偷看未来",保证训练与推理的行为一致性。
  • Bool vs Float Mask :推荐使用 torch.bool 类型。PyTorch 的 masked_fill 和 Flash Attention 原生支持 bool mask,比 float mask(0/-inf)更省显存且语义清晰。仅在需要与旧版代码兼容时才用 float mask。
  • Shape 扩展 :实际 MHA 中 mask 需广播到 (B, num_heads, T, T)。返回 (T, T) 作为基础模板,由上层函数按需 unsqueeze,保持接口纯净。

masked_softmax: 数值稳定的带掩码 Softmax

python 复制代码
def masked_softmax(
    scores: torch.Tensor,
    mask: torch.Tensor = None,
    dim: int = -1
) -> torch.Tensor:
    """
    带掩码的数值稳定 softmax
    
    Args:
        scores: attention scores, shape (*, T)
        mask: bool mask, True=保留, False=屏蔽; 或 None
        dim: softmax 维度
    """
    if mask is not None:
        # ✅ 将被屏蔽位置设为 -inf,softmax 后精确为 0
        # 使用 float('-inf') 而非 -1e9,避免大值 logits 下泄漏
        scores = scores.masked_fill(~mask, float('-inf'))
    
    # ✅ PyTorch softmax 内部已做 max 减法,数值稳定
    # ❌ 不要手写 exp(x)/sum(exp(x)),极易溢出
    return F.softmax(scores, dim=dim)
做法 正确性 原因
masked_fill(~mask, float('-inf')) softmax(-inf) = 0,精确屏蔽
masked_fill(~mask, -1e9) ⚠️ FP16/BF16 下 -1e9 可能被截断;大 logits 时泄漏非零概率
scores * mask.float() 乘法后未屏蔽位置仍参与 softmax 分母计算,梯度错误
手动 exp(x-max)/sum(...) 冗余实现,无性能收益且易出错
  • NaN 陷阱 :当某行全部被 mask (如 padding 行全为 False)时,softmax([-inf, -inf, ...]) 会产生 NaN。生产代码需在 softmax 后将全 mask 行的输出置零:attn_weights = attn_weights.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
  • FP16/BF16 安全float('-inf') 在 FP16/BF16 下均可正确表示。但 -1e9 在 FP16 中会被 clamp 到 -65504,导致屏蔽失效。LLM 训练中永远使用 -inf

attention_weights: 计算注意力权重矩阵

python 复制代码
def attention_weights(
    q: torch.Tensor,
    k: torch.Tensor,
    scale: float = None,
    mask: torch.Tensor = None
) -> torch.Tensor:
    """
    计算归一化注意力权重
    
    Args:
        q: Query, shape (B, T_q, D)
        k: Key, shape (B, T_k, D)
        scale: 缩放因子,默认 1/sqrt(D)
        mask: causal mask or padding mask, shape broadcastable to (B, T_q, T_k)
    
    Returns:
        weights: shape (B, T_q, T_k),每行和为 1
    """
    D = q.size(-1)
    if scale is None:
        scale = 1.0 / math.sqrt(D)
    
    # QK^T 点积 + 缩放
    # matmul 自动处理 batch 维广播
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    
    # 带掩码的数值稳定 softmax
    weights = masked_softmax(scores, mask=mask, dim=-1)
    
    return weights

为什么要 Scale(除以 D)?

  • 假设 Q,K的元素独立同分布、均值 0、方差 1,则 QKT 每个元素的方差为 D。
  • 当 D 很大(LLM 通常 4096+)时,点积值的绝对值极大,softmax 进入梯度饱和区(输出接近 one-hot),梯度几乎为零,训练停滞。
  • 除以 D 将方差重新归一化为 1,使 softmax 工作在梯度敏感区间。
  • 注意 :scale 应在 softmax 之前应用。先 softmax 再 scale 等价于改变 temperature,语义完全不同。

scaled_dot_product_attention: 完整 SDPA

python 复制代码
def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor = None,
    dropout_p: float = 0.0,
    training: bool = True
) -> torch.Tensor:
    """
    完整的 Scaled Dot-Product Attention
    
    Args:
        q: (B, T_q, D)
        k: (B, T_k, D)
        v: (B, T_k, D_v)
        mask: broadcastable to (B, T_q, T_k)
        dropout_p: attention dropout 概率
        training: 是否处于训练模式
    
    Returns:
        output: (B, T_q, D_v)
    """
    # Step 1: 计算注意力权重
    attn_weights = attention_weights(q, k, mask=mask)
    
    # Step 2: Attention Dropout(仅在训练时生效)
    if dropout_p > 0.0 and training:
        attn_weights = F.dropout(attn_weights, p=dropout_p, training=True)
    
    # Step 3: 加权求和 V
    output = torch.matmul(attn_weights, v)
    
    return output
  • Attention Dropout ≠ Regular Dropout:Attention dropout 作用于权重矩阵(哪些连接被断开),而非特征值。它正则化的是"注意力模式"本身,防止模型过度依赖特定 token pair。LLM 中常用 0.0~0.1。

  • 生产环境请用 F.scaled_dot_product_attention:PyTorch 2.0+ 的原生 SDPA 会自动选择最优后端(Flash Attention / Memory-Efficient / Math),比手写快 2-5x 且省显存。手写版本仅用于学习和调试。

    python 复制代码
    # 生产代码一行搞定
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, 
                                          dropout_p=0.0, is_causal=True)
  • is_causal=True 优化 :当传入 is_causal=True 时,Flash Attention 后端会使用专门的 causal kernel,跳过上三角区域的计算,进一步提速 ~30%。永远优先使用此参数而非手动构建 mask tensor


python 复制代码
def test_sdpa():
    torch.manual_seed(42)
    B, T, D = 2, 8, 64
    
    q = torch.randn(B, T, D, dtype=torch.float64)
    k = torch.randn(B, T, D, dtype=torch.float64)
    v = torch.randn(B, T, D, dtype=torch.float64)
    
    # Causal mask
    mask = build_causal_mask(T)
    
    # 自定义实现
    out_custom = scaled_dot_product_attention(q, k, v, mask=mask, dropout_p=0.0)
    
    # PyTorch 官方 SDPA(math backend 确保精确对比)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
        out_official = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0
        )
    
    assert torch.allclose(out_custom, out_official, atol=1e-8), \
        f"SDPA mismatch! Max diff: {(out_custom - out_official).abs().max()}"
    
    # 验证因果性:上三角应为 0(通过梯度间接验证)
    # 直接验证:检查 weights 的上三角
    weights = attention_weights(q, k, mask=mask)
    upper_tri = torch.triu(weights, diagonal=1)
    assert torch.allclose(upper_tri, torch.zeros_like(upper_tri), atol=1e-10), \
        "Causal mask leaked future tokens!"
    
    print("All SDPA tests passed!")

test_sdpa()

Attention 变体速查表

变体 公式差异 LLM 代表 备注
Standard SDPA softmax(QKTD)Vsoftmax(\frac{QK^T}{\sqrt D})Vsoftmax(D QKT)V GPT-2, BERT 本题实现
Multi-Head 拆分 D→H×d,并行 SDPA,concat 所有 Transformer 下一章内容
GQA/MQA K/V head 数 < Q head 数 LLaMA-3, Qwen-2 减少 KV Cache 显存
RoPE 旋转位置编码注入 Q/K LLaMA, Qwen, Mistral 替代绝对位置编码
ALiBi 线性偏置加到 scores 上 BLOOM, MPT 无需位置嵌入

从本题到 LLM 的进阶路径,掌握 SDPA 后,按以下顺序进阶:

  1. Multi-Head Attention :理解 head 拆分/合并、view vs reshape 的内存连续性陷阱
  2. KV Cache :推理时缓存历史 K/V,将 O(T2)O (T 2) 降为 O(T)O (T) ,是自回归解码的核心
  3. Flash Attention :IO-aware 的分块算法,理解为什么它不存储完整 T×TT ×T 矩阵
  4. RMSNorm + RoPE + GQA:组装现代 LLM 的标准 Attention 模块

这四个练习构成了上述所有高级技术的原子操作。确保每一步的 shape、dtype、数值稳定性都经过严格验证,后续叠加复杂度时才能快速定位问题。

PyTorch Profiling 是从"代码能跑"到"代码跑得快"的关键跨越。在 LLM 训练中,性能优化直接等同于节省数十万美元的算力成本。Profiler 不仅是测量工具,更是理解 CPU-GPU 异步执行模型内存分配模式算子融合机会的诊断器。

分析模型性能瓶颈(完整解答)

python 复制代码
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def analyze_model():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyModel().to(device)
    # 使用真实 batch size,小 batch 无法触发 GPU 并行优势
    inputs = torch.randn(32, 3, 32, 32, device=device)
    
    # Warmup:前几步 GPU kernel 编译/缓存未命中,数据不具代表性
    for _ in range(3):
        _ = model(inputs)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    activities = [ProfilerActivity.CPU]
    if device.type == 'cuda':
        activities.append(ProfilerActivity.CUDA)
    
    with profile(
        activities=activities,
        record_shapes=True,      # 按 shape 分组,区分不同尺寸的同类算子
        profile_memory=True,     # 追踪显存峰值与分配
        with_stack=True,         # 定位源码行号
        on_trace_ready=tensorboard_trace_handler('./log/model_profile')
    ) as prof:
        for _ in range(5):       # 多步采样取平均
            output = model(inputs)
            prof.step()
    
    # 最耗时操作
    sort_key = "cuda_time_total" if device.type == 'cuda' else "cpu_time_total"
    print("=" * 80)
    print("TOP 10 HOTTEST OPERATORS")
    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by=sort_key, row_limit=10
    ))
    
    # CPU vs GPU 时间比
    cpu_total = sum(e.cpu_time_total for e in prof.key_averages())
    gpu_total = sum(e.cuda_time_total for e in prof.key_averages()) if device.type == 'cuda' else 0
    ratio = gpu_total / max(cpu_total, 1)
    print(f"\n⏱️  CPU total: {cpu_total/1e3:.2f}ms | GPU total: {gpu_total/1e3:.2f}ms | Ratio: {ratio:.2f}")
    if ratio < 0.5:
        print(" GPU underutilized! Possible CPU bottleneck or excessive sync.")
    
    # 内存分析
    print("\n MEMORY SUMMARY")
    print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=5))
    
    print(f"\n TensorBoard: tensorboard --logdir=./log/model_profile")

analyze_model()
  • Self CPU % vs CPU total %Self 是算子自身开销(如 kernel launch),total 包含子调用。若 aten::linear 的 Self 很高但 aten::addmm 的 CUDA 时间很低,说明 kernel launch overhead 主导,应考虑算子融合或增大 batch。
  • group_by_stack_n=5 :将同名算子按调用栈分组。例如两个 aten::relu 分别来自 conv1 和 fc1,不加此参数会合并显示,无法定位具体层。
  • GPU 利用率判断 :Ratio > 2 通常健康;< 0.5 意味着 GPU 大量空闲等待 CPU。常见原因:DataLoader 慢、频繁 .item() 同步、Python 循环过重。

Batch Size 性能对比(含吞吐量计算)

python 复制代码
import time

def benchmark_batch_sizes(model_class, input_shape_no_batch, batch_sizes, device='cuda'):
    """系统化 batch size 扫描"""
    results = []
    
    for bs in batch_sizes:
        model = model_class().to(device)
        x = torch.randn(bs, *input_shape_no_batch, device=device)
        
        # Warmup
        for _ in range(5):
            _ = model(x)
        if device == 'cuda':
            torch.cuda.synchronize()
        
        # Benchmark
        start_event = torch.cuda.Event(enable_timing=True) if device == 'cuda' else None
        end_event = torch.cuda.Event(enable_timing=True) if device == 'cuda' else None
        
        if device == 'cuda':
            start_event.record()
        else:
            t0 = time.perf_counter()
            
        num_runs = 50
        for _ in range(num_runs):
            _ = model(x)
            
        if device == 'cuda':
            end_event.record()
            torch.cuda.synchronize()
            elapsed_ms = start_event.elapsed_time(end_event) / num_runs
        else:
            elapsed_ms = (time.perf_counter() - t0) / num_runs * 1000
        
        throughput = bs / (elapsed_ms / 1000)  # samples/sec
        
        # 显存占用
        mem_mb = torch.cuda.max_memory_allocated(device) / 1024**2 if device == 'cuda' else 0
        
        results.append({
            'batch_size': bs,
            'latency_ms': round(elapsed_ms, 3),
            'throughput_sps': round(throughput, 1),
            'gpu_mem_mb': round(mem_mb, 1)
        })
        print(f"BS={bs:>4d} | Latency={elapsed_ms:>7.3f}ms | "
              f"Throughput={throughput:>8.1f} s/s | Mem={mem_mb:>7.1f}MB")
        
        torch.cuda.empty_cache()
    
    return results

# 使用
results = benchmark_batch_sizes(
    MyModel, 
    input_shape_no_batch=(3, 32, 32),
    batch_sizes=[1, 8, 32, 64, 128, 256]
)
错误做法 正确做法 原因
time.time() 测 GPU cuda.Eventsynchronize GPU 异步执行,time.time() 只测 launch 时间
跳过 warmup 至少 3-5 步预热 CUDA kernel JIT、cuDNN autotune、cache 填充
单次测量 ≥50 次取均值 消除 OS 调度、热节流噪声
忽略 OOM 恢复 try/except + empty_cache 一个 BS 失败不应中断整个扫描
  • 吞吐量 vs 延迟的权衡 :BS=1 延迟最低但吞吐最差;BS 增大吞吐提升但边际递减。最优 BS 通常在 GPU 显存占满 80-90% 时达到。超过后可能因 swap 或碎片化反而下降。
  • LLM 特殊性 :LLM 的 batch size 受序列长度影响极大。应固定 tokens_per_batch = bs × seq_len 做等 token 量对比,而非固定 bs。

DataLoader num_workers 优化

python 复制代码
from torch.utils.data import DataLoader, TensorDataset
import time

def benchmark_dataloader(dataset, batch_size=64, worker_counts=[0, 2, 4, 8, 16]):
    """找到数据加载的最优 worker 数"""
    for nw in worker_counts:
        loader = DataLoader(
            dataset, 
            batch_size=batch_size, 
            num_workers=nw,
            pin_memory=True,      # 加速 CPU→GPU 传输
            persistent_workers=True if nw > 0 else False  # 避免每 epoch 重建进程
        )
        
        # Warmup:首个 epoch worker 初始化开销大
        for i, (x, y) in enumerate(loader):
            if i >= 5: break
            _ = x.to('cuda', non_blocking=True)
        torch.cuda.synchronize()
        
        # Measure full epoch
        start = time.perf_counter()
        count = 0
        for x, y in loader:
            _ = x.to('cuda', non_blocking=True)
            count += x.size(0)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        
        throughput = count / elapsed
        print(f"num_workers={nw:>2d} | Epoch={elapsed:.2f}s | "
              f"Throughput={throughput:.0f} samples/s")

# 模拟大数据集(实际替换为真实数据集)
dataset = TensorDataset(torch.randn(10000, 3, 224, 224), torch.randint(0, 10, (10000,)))
benchmark_dataloader(dataset)

num_workers 调优经验公式

  • 起点num_workers = min(os.cpu_count(), 4 * num_gpus)
  • 上限信号:当增加 worker 后 epoch 时间不再缩短甚至变长,说明 CPU 争抢或 IPC 序列化成为新瓶颈。
  • persistent_workers=True必开选项。默认情况下每个 epoch 结束都会销毁并重建所有 worker 进程,对于大数据集这会产生数秒的固定开销。开启后 worker 跨 epoch 复用。
  • pin_memory=TrueGPU 训练必开 。锁页内存允许 DMA 异步传输,配合 non_blocking=True 实现数据传输与计算重叠。不开此项,.to('cuda') 是同步阻塞的。
  • LLM 特殊注意 :Tokenization 通常是 DataLoader 瓶颈。若 profiler 显示 collate_fn 或 tokenizer 占用大量 CPU 时间,应将 tokenization 预处理到磁盘(memory-mapped dataset),而非运行时动态 tokenize。

掌握基础 Profiler 后,以下是 LLM 工程中真正决定效率的认知升级:

理解异步执行的本质

复制代码
CPU Timeline:  [launch A][launch B][launch C][...idle...][sync]
GPU Timeline:  [........][====A====][====B====][====C====]
  • Profiler 中看到的 CPU 时间大多是 kernel launch overhead(~5-20μs/op),不是真正的计算。
  • 当 GPU 有连续的大 kernel 时,CPU 的空闲是正常的。只有当 GPU 出现空隙(gap)而 CPU 忙碌时,才是真正的 CPU bound
  • 使用 torch.compile 可将多个小 op 融合为单个 kernel,大幅减少 launch 次数。

显存分析比时间分析更重要,LLM 训练中,OOM 是第一优先级问题。Profiler 的内存视图关注:

  • Activation Checkpointing 效果:对比开启前后的 activation memory peak
  • KV Cache 增长:推理时 KV cache 应线性增长,若超线性说明有泄漏
  • 碎片化reserved - allocated 差值过大说明碎片严重,需调整分配策略或使用 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

从 Profile 到优化的决策树

复制代码
GPU 利用率低?
├── GPU timeline 有间隙 → CPU bound
│   ├── DataLoader 慢 → 增加 workers / 预处理 / FFCV/WebDataset
│   ├── Python 循环重 → torch.compile / Triton / C++ extension
│   └── 频繁同步 → 移除 .item() / 异步指标聚合
├── GPU timeline 连续但 kernel 小 → Launch overhead
│   └── torch.compile / 手动 fuse / Flash Attention
└── GPU kernel 本身慢 → Compute bound
    ├── 精度过高 → FP16/BF16 / INT8 quantization
    ├── 算法低效 → Flash Attention / PagedAttention
    └── 硬件未打满 → 增大 batch / tensor parallelism

生产级 Profiling Checklist

检查项 标准 工具
Warmup ≥3 steps before profiling 手动 / schedule(wait=N)
多步采样 active≥3, repeat≥2 profiler.schedule
Shape 分组 record_shapes=True key_averages(group_by_input_shape=True)
调用栈 with_stack=True group_by_stack_n=5
内存追踪 profile_memory=True Memory View in TensorBoard
非侵入式 不改业务代码 context manager / decorator
可复现 固定 seed + deterministic torch.use_deterministic_algorithms(True)

掌握这些内容后,已具备独立诊断和优化 LLM 训练/推理性能的能力。下一步建议结合 torch.compile 和 Triton 自定义算子,将 Profiler 发现的瓶颈转化为实际的性能收益。

显存是 LLM 训练和推理中最稀缺的资源。理解显存的构成、掌握分析工具、熟练运用优化技术,是从"跑通代码"到"规模化训练"的分水岭。以下是对三个实战练习的工程级解答,以及超越基础文档的 LLM 显存管理心法。

分析模型显存占用(完整解答)

python 复制代码
import torch
import torch.nn as nn

class LargeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(2048, 2048) for _ in range(20)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

def analyze_memory():
    device = 'cuda'
    model = LargeModel().to(device)
    batch_size, seq_len = 32, 512
    x = torch.randn(batch_size, seq_len, 2048, device=device)
    
    #  模型参数显存
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    grad_bytes = sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad)
    print(f" Parameters: {param_bytes / 1024**2:.1f} MB")
    print(f" Gradients:  {grad_bytes / 1024**2:.1f} MB")
    
    #  前向传播峰值显存
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    with torch.no_grad():  # 隔离前向,不含梯度
        y = model(x)
    fwd_peak = torch.cuda.max_memory_allocated()
    print(f"\n  Forward Peak:  {fwd_peak / 1024**2:.1f} MB")
    
    #  反向传播峰值显存(含激活值 + 梯度)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    y = model(x)
    y.sum().backward()
    bwd_peak = torch.cuda.max_memory_allocated()
    print(f"  Backward Peak: {bwd_peak / 1024**2:.1f} MB")
    
    #  激活值显存估算
    activation_bytes = bwd_peak - param_bytes - grad_bytes
    print(f"\n Activations (est): {activation_bytes / 1024**2:.1f} MB")
    print(f"   占总峰值比例: {activation_bytes / bwd_peak * 100:.1f}%")
    
    # 理论验证:每层 Linear(2048,2048) 保存输入用于反向
    # 激活 ≈ 20 layers × (32 × 512 × 2048 × 4 bytes) ≈ 2621 MB
    theoretical_act = 20 * batch_size * seq_len * 2048 * 4
    print(f"   理论激活值:     {theoretical_act / 1024**2:.1f} MB")

analyze_memory()

显存构成公式:Total VRAM=P⏟params+G⏟gradients+O⏟optimizer states+A⏟activations+T⏟temp buffers

组件 FP32 FP16/BF16 Adam Optimizer 备注
Parameters 4B/param 2B/param --- 固定开销
Gradients 4B/param 2B/param --- 与 params 同精度
Optimizer States --- --- 8B/param (FP32) Adam: m + v 各 4B
Activations 4B/elem 2B/elem --- 随 batch×seq 线性增长
  • 关键洞察 :在 LLM 训练中,激活值通常占显存的 60-80%,远超参数本身。这就是为什么 Gradient Checkpointing 和序列并行比量化参数更重要。
  • Adam 的隐藏成本:7B 模型的 Adam 优化器状态需要 ~56GB 显存(7B × 8B),比模型参数本身还大。这是 ZeRO / FSDP 分片的核心动机。

显存优化对比(完整实现 + 基准测试)

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
import time

# === 四种训练策略实现 ===

def train_baseline(model, x, target, optimizer, criterion):
    """基线:FP32 全量训练"""
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_grad_accum(model, x, target, optimizer, criterion, accum_steps=4):
    """梯度累积:模拟 4x batch,显存不变"""
    optimizer.zero_grad()
    # 将 input 切分为 accum_steps 份
    chunks_x = x.chunk(accum_steps)
    chunks_t = target.chunk(accum_steps)
    total_loss = 0
    for cx, ct in zip(chunks_x, chunks_t):
        output = model(cx)
        loss = criterion(output, ct) / accum_steps
        loss.backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss

def train_amp(model, x, target, optimizer, criterion, scaler):
    """混合精度:FP16 前向+反向,FP32 权重更新"""
    optimizer.zero_grad()
    with autocast(dtype=torch.float16):
        output = model(x)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

def train_checkpoint(model, x, target, optimizer, criterion):
    """梯度检查点:用 ~30% 额外计算换 50-80% 激活显存"""
    optimizer.zero_grad()
    output = model(x, use_checkpoint=True)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()

# === 修改模型支持 checkpoint ===
class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(2048, 2048) for _ in range(20)])
    
    def forward(self, x, use_checkpoint=False):
        for layer in self.layers:
            if use_checkpoint:
                x = checkpoint(lambda inp, l=layer: torch.relu(l(inp)), 
                              x, use_reentrant=False)
            else:
                x = torch.relu(layer(x))
        return x

# === 统一基准测试 ===
def benchmark_strategies():
    device = 'cuda'
    bs, seq, dim = 32, 512, 2048
    
    strategies = {}
    
    # Baseline
    model_base = LargeModel().to(device)
    opt_base = torch.optim.Adam(model_base.parameters(), lr=1e-3)
    crit = nn.CrossEntropyLoss()
    strategies["Baseline"] = lambda: train_baseline(
        model_base, torch.randn(bs, seq, dim, device=device),
        torch.randint(0, 10, (bs,), device=device), opt_base, crit)
    
    # Gradient Accumulation
    model_ga = LargeModel().to(device)
    opt_ga = torch.optim.Adam(model_ga.parameters(), lr=1e-3)
    strategies["Grad Accum (4x)"] = lambda: train_grad_accum(
        model_ga, torch.randn(bs, seq, dim, device=device),
        torch.randint(0, 10, (bs,), device=device), opt_ga, crit)
    
    # AMP
    model_amp = LargeModel().to(device)
    opt_amp = torch.optim.Adam(model_amp.parameters(), lr=1e-3)
    scaler = GradScaler()
    strategies["AMP (FP16)"] = lambda: train_amp(
        model_amp, torch.randn(bs, seq, dim, device=device),
        torch.randint(0, 10, (bs,), device=device), opt_amp, crit, scaler)
    
    # Gradient Checkpointing
    model_cp = CheckpointedModel().to(device)
    opt_cp = torch.optim.Adam(model_cp.parameters(), lr=1e-3)
    strategies["Grad Checkpoint"] = lambda: train_checkpoint(
        model_cp, torch.randn(bs, seq, dim, device=device),
        torch.randint(0, 10, (bs,), device=device), opt_cp, crit)
    
    # 运行基准
    print(f"{'Strategy':<22} {'Peak Mem (MB)':<16} {'Time (ms)':<12} {'Mem Saving':<12}")
    print("-" * 65)
    
    baseline_mem = None
    for name, func in strategies.items():
        # Warmup
        for _ in range(3):
            func()
        torch.cuda.synchronize()
        
        # Measure
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        
        start = time.perf_counter()
        for _ in range(10):
            func()
        torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) / 10 * 1000
        
        peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        if baseline_mem is None:
            baseline_mem = peak_mb
            saving = "---"
        else:
            saving = f"{(1 - peak_mb/baseline_mem)*100:.1f}%"
        
        print(f"{name:<22} {peak_mb:<16.1f} {elapsed_ms:<12.1f} {saving:<12}")

benchmark_strategies()

优化策略选择决策树

复制代码
OOM?
├─ 仅差一点 → 减小 batch size / gradient accumulation
├─ 差很多但可接受减速 → Gradient Checkpointing
├─ 需要速度+显存双收益 → AMP (BF16 > FP16)
└─ 单卡无论如何放不下 → ZeRO/FSDP 多卡分片
  • AMP 优先用 BF16 :BF16 动态范围与 FP32 相同,无需 GradScaler,训练更稳定。FP16 在 LLM 尺度下容易溢出。autocast(dtype=torch.bfloat16)
  • Checkpoint 粒度 :不要对每一层都做 checkpoint。只对激活值最大的模块(如 Attention、FFN)做,Embedding/LayerNorm 等轻量层保留缓存。过度 checkpoint 会导致重计算开销超过显存收益。
  • 梯度累积的正确性 :loss 必须除以 accum_steps,否则等效学习率被放大 N 倍。BatchNorm 在梯度累积下统计量会偏移,LLM 使用 LayerNorm/RMSNorm 无此问题。

排查显存泄漏(修复 + 诊断工具)

原始代码的问题分析

python 复制代码
def train_with_leak(model, train_loader, optimizer, criterion):
    model.train()
    losses = []
    
    for epoch in range(10):
        for inputs, targets in train_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.append(loss)  # BUG: 保留了完整计算图
            
            loss.backward()
            optimizer.step()
            # BUG: 缺少 optimizer.zero_grad(),梯度无限累积
    
    return losses

两个致命问题

  1. losses.append(loss) 保留了整个计算图的引用。每个 loss tensor 持有从输入到输出的所有中间激活值的引用,导致激活值永远无法释放。10 个 epoch 后显存必然 OOM。
  2. 缺少 optimizer.zero_grad(),梯度在参数上无限累加,不仅浪费显存(梯度 tensor 不会释放),还会导致训练完全错误。

修复后的代码

python 复制代码
def train_fixed(model, train_loader, optimizer, criterion, device='cuda'):
    model.train()
    losses = []
    
    for epoch in range(10):
        epoch_losses = []
        for inputs, targets in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # 只保留标量值,断开计算图引用
            epoch_losses.append(loss.item())
            
            optimizer.zero_grad(set_to_none=True)  # set_to_none 比 zero_ 更省显存
            loss.backward()
            optimizer.step()
        
        # 每个 epoch 结束后记录平均 loss,而非保留所有 step 的 tensor
        losses.append(sum(epoch_losses) / len(epoch_losses))
    
    return losses

显存泄漏诊断工具箱

python 复制代码
import gc
import torch

def diagnose_memory_leak():
    """系统化排查显存泄漏"""
    
    # Step 1: 强制 GC + 清空缓存
    gc.collect()
    torch.cuda.empty_cache()
    
    # Step 2: 快照当前所有 GPU tensor
    before_tensors = set()
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                before_tensors.add(id(obj))
        except Exception:
            pass
    
    # === 执行可疑代码 ===
    # ... your training loop ...
    
    # Step 3: 再次 GC
    gc.collect()
    torch.cuda.empty_cache()
    
    # Step 4: 找出新增的 tensor
    leaked = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda and id(obj) not in before_tensors:
                leaked.append(obj)
        except Exception:
            pass
    
    if leaked:
        print(f"  Found {len(leaked)} leaked tensors:")
        for t in leaked[:20]:  # 只显示前 20 个
            print(f"   shape={tuple(t.shape)}, dtype={t.dtype}, "
                  f"size={t.numel()*t.element_size()/1024**2:.2f}MB")
    else:
        print(" No memory leak detected.")
    
    # Step 5: 监控 reserved vs allocated 差距(碎片化指标)
    alloc = torch.cuda.memory_allocated() / 1024**2
    resv = torch.cuda.memory_reserved() / 1024**2
    frag = (resv - alloc) / resv * 100 if resv > 0 else 0
    print(f"\n Allocated: {alloc:.1f}MB | Reserved: {resv:.1f}MB | Fragmentation: {frag:.1f}%")
    if frag > 30:
        print("High fragmentation! Consider PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")

高频泄漏模式速查

模式 症状 修复
list.append(tensor) 显存随 epoch 线性增长 .append(tensor.item()).detach()
缺少 zero_grad 显存缓慢增长 + loss 异常 optimizer.zero_grad(set_to_none=True)
闭包捕获 tensor 显存不释放但找不到引用 检查 lambda/functools.partial
全局变量缓存 显存只增不减 weakref 或显式清理
Dataloader worker 泄漏 CPU 内存增长 persistent_workers=True + 检查 collate_fn
  • set_to_none=Truezero_grad(set_to_none=True) 将梯度设为 None 而非零张量,节省一份梯度显存。PyTorch 2.x 默认行为,1.x 需显式指定。
  • detach() vs item()detach() 返回新 tensor 但仍占显存;item() 返回 Python float,完全释放。记录指标永远用 item()
  • 碎片化是隐形杀手 :即使 allocated 未达上限,reserved 中的碎片也可能导致 OOM。PyTorch 2.1+ 的 expandable_segments 大幅缓解此问题,生产环境务必开启。

显存优化的优先级排序

复制代码
1. 算法级:Flash Attention(消除 T² 激活)、RoPE(无需位置嵌入缓存)
2. 系统级:Gradient Checkpointing、Activation Recomputation
3. 精度级:BF16/FP16 混合精度、INT8/INT4 量化
4. 工程级:梯度累积、batch size 调优、inplace 操作
5. 分布式:ZeRO Stage 1/2/3、FSDP、Tensor Parallelism

永远从算法级开始优化。Flash Attention 一项就能节省 80%+ 的 Attention 激活显存,效果远超其他所有工程技巧的叠加。

LLM 特有的显存陷阱

  • KV Cache:推理时 KV cache 随序列长度线性增长。7B 模型在 4K context 下 KV cache ≈ 2GB,128K context 下 ≈ 64GB。必须使用 PagedAttention (vLLM) 或 MQA/GQA 架构。
  • Tokenizer Padding:变长序列 padding 到 max_length 会浪费大量激活显存。使用 packing(多条短序列拼接)或 varlen attention。
  • Logits Tensor :vocab_size=128K 时,logits tensor (B, T, 128K) 在 FP32 下占用巨大。使用 fused cross-entropy 避免物化完整 logits。

生产级显存监控 Checklist

检查项 标准 工具
峰值显存 < GPU 总显存 90% max_memory_allocated()
碎片率 < 20% reserved - allocated
激活占比 已知且合理 Profiler Memory View
泄漏检测 多 epoch 显存稳定 gc + tensor snapshot
OOM 恢复 graceful fallback try/except + reduce batch
分布式一致性 各卡显存均衡 nvidia-smi / DDP profiler

掌握这些内容后,已具备独立诊断和解决 LLM 训练中各类显存问题的能力。下一步建议深入研究 Flash Attention 的实现原理和 ZeRO/FSDP 的分片策略,将显存优化从单卡扩展到千卡规模。

在深度学习(尤其是 LLM 训练)中,NaN/Inf 是最令人头疼的问题。它们具有传染性 (一个 NaN 会通过反向传播污染所有梯度)和隐蔽性(可能在 loss 爆炸前数十步就已产生)。掌握系统化的数值健康检查方法,是从"盲目试错"到"精准定位"的关键跨越。

调试工具函数实现

python 复制代码
import torch
import math
from typing import Dict, List, Tuple, Optional

def tensor_health(x: torch.Tensor, name: str = "tensor") -> Dict:
    """
    全面诊断张量的数值健康状态
    
    Returns:
        包含 min/max/mean/std/nan_count/inf_count 的字典
    """
    with torch.no_grad():
        # 转为 float32 统计,避免 FP16/BF16 下 inf/nan 判断失真
        xf = x.detach().float()
        
        is_nan = torch.isnan(xf)
        is_inf = torch.isinf(xf)
        finite_mask = torch.isfinite(xf)
        
        stats = {
            "name": name,
            "shape": tuple(x.shape),
            "dtype": str(x.dtype),
            "nan_count": is_nan.sum().item(),
            "inf_count": is_inf.sum().item(),
            "finite_ratio": finite_mask.float().mean().item(),
        }
        
        # 仅在存在有限值时计算统计量,否则返回 NaN 会误导
        if stats["finite_ratio"] > 0:
            xf_finite = xf[finite_mask]
            stats.update({
                "min": xf_finite.min().item(),
                "max": xf_finite.max().item(),
                "mean": xf_finite.mean().item(),
                "std": xf_finite.std().item(),
                "abs_max": xf_finite.abs().max().item(),
            })
        else:
            stats.update({"min": None, "max": None, "mean": None, "std": None, "abs_max": None})
        
        return stats


def has_nonfinite(x: torch.Tensor) -> bool:
    """快速检查张量是否包含 NaN 或 Inf"""
    # isfinite 比 isnan + isinf 更快(单次 kernel launch)
    return not torch.isfinite(x).all().item()


def gradient_l2_norm(model: torch.nn.Module) -> float:
    """
    计算模型所有梯度的全局 L2 范数
    等价于 torch.nn.utils.clip_grad_norm_ 的内部计算(但不裁剪)
    """
    total_norm_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            # 用 float32 累加,避免 FP16 梯度平方溢出
            g = p.grad.detach().float()
            total_norm_sq += g.norm(2).item() ** 2
    
    return math.sqrt(total_norm_sq)


def find_nonfinite_grad_names(model: torch.nn.Module) -> List[str]:
    """
    找出所有梯度中包含 NaN/Inf 的参数名
    用于精确定位问题源头
    """
    bad_names = []
    for name, p in model.named_parameters():
        if p.grad is not None and has_nonfinite(p.grad):
            nan_c = torch.isnan(p.grad).sum().item()
            inf_c = torch.isinf(p.grad).sum().item()
            bad_names.append(f"{name} (nan={nan_c}, inf={inf_c})")
    return bad_names

关键设计决策解析

设计点 错误做法 正确做法 原因
统计精度 直接在 FP16 上算 mean/std .detach().float() 后统计 FP16 的 inf 阈值仅 65504,正常大值会被误判;BF16 虽好但 std 精度仍不足
非有限值处理 x.mean() 含 NaN → 结果 NaN 先 mask 再统计 一个 NaN 污染整个统计量,完全失去诊断价值
L2 Norm 累加 FP16 下直接累加平方和 float32 累加 FP16 平方极易溢出为 inf,导致 norm 永远为 inf
Non-finite 检测 `isnan(x) isinf(x)` ~isfinite(x)
梯度访问 p.grad.data p.grad.detach() .data 绕过 autograd 追踪,可能导致隐性 bug;.detach() 是安全标准

LLM 训练中 NaN/Inf 的系统化排查流程

当训练中出现 NaN 时,按以下决策树逐层缩小范围:

复制代码
Loss = NaN?
├── Step 0 就 NaN → 初始化/数据问题
│   ├── 检查输入数据:has_nonfinite(batch) / tensor_health(batch)
│   ├── 检查权重初始化:tensor_health(param) for param in model.parameters()
│   └── 检查 LR:过大导致第一步就发散
│
├── 训练中途突然 NaN → 数值不稳定
│   ├── 1. 定位首个 NaN step(二分法 / anomaly detection)
│   ├── 2. 该 step 前向检查:每层输出 tensor_health()
│   │   └── 找到首个 NaN 层 → 检查该层的输入/权重/操作
│   ├── 3. 该 step 反向检查:find_nonfinite_grad_names(model)
│   │   └── 梯度 NaN 的层 ≠ 前向 NaN 的层 → 反向传播中的数值问题
│   └── 4. 检查 loss 本身:log(0)、除零、softmax 全 mask
│
└── 逐渐增大后 NaN → 梯度爆炸
    ├── gradient_l2_norm(model) 监控趋势
    ├── 加 grad clipping(max_norm=1.0)
    ├── 降低 LR / 增加 warmup
    └── 切换 FP16 → BF16(动态范围大 8 个数量级)

LLM 特有的高频 NaN 陷阱

Softmax 全 Mask 行

python 复制代码
# 因果掩码 + padding 导致某行全为 -inf
scores = scores.masked_fill(~mask, float('-inf'))
attn = F.softmax(scores, dim=-1)  # softmax([-inf,...,-inf]) = [NaN,...,NaN]

# 修复:softmax 后将全 mask 行置零
attn = F.softmax(scores, dim=-1)
attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)

这是 LLM 训练中最常见的 NaN 来源,尤其在变长序列 packing 场景下。

Log Probabilities 中的 log(0)

python 复制代码
# CrossEntropyLoss 内部 log(softmax),若 softmax 输出精确 0 → log(0) = -inf
# PyTorch CE 已做数值稳定,但自定义 loss 需注意
# 使用 log_softmax + nll_loss 组合,或直接 F.cross_entropy(fused 且稳定)

FP16 下的梯度平方溢出

FP16 最大值为 65504。当梯度 > 255 时,grad² > 65504 → overflow to inf → Adam 的 v 变为 inf → 更新量变为 NaN。BF16 最大值约 3.4×10³⁸,彻底消除这一问题。这也是为什么现代 LLM 训练几乎全部使用 BF16。

Gradient Checkpointing 的重计算不一致

如果 checkpoint 包裹的模块在前向和重计算时行为不同(如 dropout 未固定 seed、BN 在 train/eval 间切换),重计算的激活值与原始值不一致,导致梯度错误甚至 NaN。

python 复制代码
# 确保 checkpoint 内的操作是确定性的
# use_reentrant=False(PyTorch 2.x 推荐)自动处理大部分情况
checkpoint(fn, *args, use_reentrant=False)

集成到训练循环的调试 Hook

python 复制代码
class NaNDebugger:
    """插入训练循环的自动 NaN 检测器"""
    
    def __init__(self, model: torch.nn.Module, check_every: int = 1):
        self.model = model
        self.check_every = check_every
        self.step = 0
    
    def after_backward(self, loss: torch.Tensor):
        self.step += 1
        if self.step % self.check_every != 0:
            return
        
        # 1. 检查 loss
        if has_nonfinite(loss):
            print(f" Step {self.step}: Loss is non-finite!")
            print(f"   Loss value: {loss.item()}")
            bad = find_nonfinite_grad_names(self.model)
            print(f"   Bad grads: {bad[:5]}")  # 只显示前 5 个
            raise RuntimeError(f"NaN detected at step {self.step}")
        
        # 2. 检查梯度范数趋势
        gnorm = gradient_l2_norm(self.model)
        if gnorm > 100.0:
            print(f"  Step {self.step}: Grad norm = {gnorm:.2f} (abnormally high)")
        
        # 3. 可选:抽样检查中间激活(开销较大,仅在排查时开启)
        # for name, buf in self.model.named_buffers():
        #     if has_nonfinite(buf):
        #         print(f" Buffer '{name}' is non-finite: {tensor_health(buf)}")

调试工具选型速查

场景 工具 开销 适用阶段
快速检查单个 tensor has_nonfinite() 极低 实时 / 生产
详细诊断异常 tensor tensor_health() 排查时
监控梯度整体健康 gradient_l2_norm() 每步 / 每 N 步
定位具体坏梯度参数 find_nonfinite_grad_names() NaN 发生时
自动捕获首次 NaN torch.autograd.set_detect_anomaly(True) 极高 (10-100x) 仅小样本复现
大规模训练 NaN 定位 自定义 Hook + 日志 可控 生产训练

detect_anomaly 警告 :此模式会为每个算子插入额外检查,使训练速度下降 10-100 倍。绝不要在全量训练中开启。正确用法:用小数据集 + 小模型复现 NaN → 开 anomaly → 获得精确堆栈 → 关闭 anomaly → 修复代码 → 全量验证。

掌握这四个函数和排查流程后,已具备独立应对 LLM 训练中各类数值异常的能力。下一步建议结合 torch.compile 的数值一致性检查和分布式训练中的跨 rank NaN 同步检测,将调试能力扩展到千卡规模。

相关推荐
AndrewHZ3 小时前
【LLM技术全景】大模型能力探秘:In-Context Learning与思维链(CoT)
人工智能·语言模型·大模型·llm·cot·思维链·icl
枫子有风4 小时前
LLM-Agent智能体(大厂面试常问)
面试·职场和发展·llm·agent
昵称好难啊4 小时前
7.OpenClaw源码解析——可靠消息投递
人工智能·llm·agent
董厂长7 小时前
Loop Engineering:停止手动提示,开始设计自动提示的系统
大数据·人工智能·驱动开发·llm
把你拉进白名单8 小时前
7.OpenClaw源码解析——可靠消息投递
人工智能·llm·agent
武子康8 小时前
调查研究-180 roboflow/supervision:计算机视觉工程里的“胶水层“,为什么值得关注?
人工智能·opencv·计算机视觉·chatgpt·llm·向量化
Liigo9 小时前
【AI对话实录】大模型自行删减原文并编造虚假URL链接
ai·llm·deepseek·liigo·faking
chenjim1 天前
你的 Agent 是个黑箱:eBPF 如何看见它真正在做什么
llm·agent
Lkstar1 天前
万字长文Query改写与多路召回实战|从HyDE到RRF融合,召回率提升22%的完整方案
数据库·人工智能·llm