【DataWhale组队学习】DIY-LLM Task2 PyTorch 与资源核算

原文链接

0. 引言:此章并非Pytorch入门

这一章表面上在讲pytorch的张量、矩阵乘法、反向传播、参数初始化、数据加载、优化器、训练循环,但真正的重点应该是从资源的角度重新理一遍这些本来就会写的代码。

我对pytorch的训练原本也不算陌生,至少会写个能跑的baseline,但显然我们的时间金钱不是无限的,学了这章我需要开始关注每一步到底在算什么、吃了多少显存、要多少FLOPs、为什么要特定的精度格式等等,对模型训练需要的时间和内存有一个大概的概念。

此篇笔记将不再赘述我原本熟悉的东西。

1. 张量

显然张量并不是一个数组或者一个单纯的容器,在pytorch中输入、参数、激活值、梯度、优化器状态,本质上都可以看成张量,是整个训练过程最基础的表达形式。

1.1 view、transpose和contiguous

张量看上去是多维的,但层本质上仍是一块线性内存。viewtranspose这样的操作不会重新复制一遍数据,而只是改变了张量对底层内存的解释方式,数据在内存上物理的排列没有改变,这叫做zero-copy,十分高效。

transpose也会带来non-contiguous的问题,即便数据在物理上的排列没有改变,但数据的逻辑排列还是变了,即元数据形状欸或步长变了,逻辑索引顺序与内存存储顺序不再一致,一个转置之后的张量不能再直接view,必须先.contiguous(),而.contiguous()会创建新的连续内存有额外的内存和时间开销。

1.2 einops和jaxtyping

用原生的.view().transpose()操作维度时,需要时刻记住张量的维度顺序,很容易搞晕,一旦模型维度变复杂,代码就很容易变得难读。

课程推荐用jaxtyping给维度加标签,用einops做einsumreducerearrange这类操作,尽管增加了少量语法开销,但其清晰的维度命名显著降低了调试难度,更多用法可见 Einops tutorial

2. 内存

2.1 内存不只属于参数

除了参数本身,还有梯度、激活值和优化器状态,在资源核算时也要按这几类来算。

不能简单地用 总显存 / 参数字节数 去简单估算模型是否可跑,尤其是用了 AdamW这类优化器,优化器状态本身会占掉很大一部分显存。

2.2 FP32、FP16和BF16

  • FP32 训练稳定,但显存和计算代价高。
  • FP16 省显存、速度快,但动态范围太小,训练过程中容易溢出造成致命后果。
  • BF16 大小不变的情况下,牺牲一点精度,保留了更大的动态范围,训练里更稳,现在的主流选择。
  • FP8 极致压缩显存,精度极低,训练极不稳定,主要用于推理的量化。

深度学习训练里,损失小数点后10位的精度并无所谓,但一旦溢出就是致命的,所以BF16截断了FP32的尾数,保留了指数位,它能表示的数值范围与FP32一样大,极大地提升了训练稳定性,优于FP16。

3. 矩阵乘法、FLOPs和MFU

3.1 矩阵乘法是主要计算开销来源

深度学习模型里大部分的计算都来自矩阵乘法,无论是线性层还是Transformer底层核心都离不开它。

如果输入x的形状是 (B, D),权重w的形状是(D, K),一次矩阵乘法可以写成:

C i k = ∑ D A i j B j k C_{ik} = \sum^{D} A_{ij}B_{jk} Cik=∑DAijBjk

乘法 D D D 次,加法 D − 1 D-1 D−1 次,则单次元素的浮点运算数近似为 2 × D 2 \times D 2×D 次

总输出元素个数为 B × K B\times K B×K

那么y = x @ w的FLOPs近似为:

2 × B × D × K 2 \times B \times D \times K 2×B×D×K

这个公式虽然简单,但它是后面训练成本估算的基础。

3.2 FLOPs、FLOP/s 和 MFU

几个概念:

  • FLOPs:总计算量
  • FLOP/s:每秒浮点运算次数
  • MFU:实际吞吐相对理论峰值的利用率,MFU=实测FLOPS硬件理论峰值 / FLOPS,MFU >= 0.5被认为是相当不错的性能,但这个公式忽略了通信和系统开销,只关注纯粹的计算效率。

4. 训练开销

文档里用线性模型推出了一个很重要的经验结论:

  • 前向传播约为 2 × 参数量 × token数
  • 反向传播大约是前向的两倍
  • 所以前向 + 反向总共可近似写成:
    6 × 参数量 × token数 6 \times \text{参数量} \times \text{token数} 6×参数量×token数
    它背后其实就是矩阵乘法在forward和backward里的工作量拆分。

5. 模型训练的数据加载

  • numpy.memmap
    • 对于超大规模语料,可以用numpy.memmap把磁盘文件映射成一个"按需访问"的映射对象,就像一个磁盘上的"指针"一样,当真正访问某一段数据时,系统才会把那部分数据调入内存。
  • pin_memory()
    • 默认情况下,CPU张量放在可分页内存里,而GPU在搬运数据前,需要先复制到固定的非分页内存区域,调用pin_memory()后,张量会被放到固定内存中,这样GPU能更直接地访问,减少一次中间拷贝,提高传输效率。
  • non_blocking=True
    • 如果数据已经在固定内存里,那么.to(device, non_blocking=True)时,传输就可以尽量异步地进行,不必让Python线程原地等待。这样一来,GPU处理当前batch的同时,CPU还能继续准备下一个batch,不容易出现GPU等数据空转。

6. 理解与反思

本章核心是带来了资源核算的视角,以往我基本只关注模型运行效果,现在需要兼顾显存、精度、FLOPs、优化器状态、数据传输等硬件与系统层面问题,将训练从单纯的优化问题,拓展为资源约束下的系统工程问题。

资源开销估算也建立起了对训练成本的直观认知,能快速判断模型配置的开销与性能瓶颈。承接前文分词器的基础内容,本章补齐了模型完整训练的核心逻辑,是后续学习大模型训练细节的关键铺垫。

相关推荐
Elastic 中国社区官方博客2 小时前
Elastic Security、Observability 和 Search 现在在你的 AI 工具中提供交互式 UI
大数据·运维·人工智能·elasticsearch·搜索引擎·安全威胁分析·可用性测试
一碗白开水一2 小时前
【目标跟踪综述】目标跟踪近3年技术研究,全面了解目标跟踪发展
人工智能·计算机视觉·目标跟踪
Promise微笑3 小时前
AI搜索时代的流量重构:GEO优化深度执行细节与把控体系
人工智能·重构
言萧凡_CookieBoty3 小时前
比 Vibe Coding 更可怕的,是 Vibe Design 吧
人工智能·ai编程
Rick19933 小时前
Spring AI 如何进行权限控制
人工智能·python·spring
Theodore_10223 小时前
深度学习(15):倾斜数据集 & 精确率-召回率权衡
人工智能·笔记·深度学习·机器学习·知识图谱
IT_陈寒3 小时前
SpringBoot自动配置这破玩意儿又坑我一次
前端·人工智能·后端
TechubNews3 小时前
Base 发布首个独立 OP Stack 框架的网络升级 Azul,将是 L2 自主迭代的开端?
大数据·网络·人工智能·区块链·能源
啦啦啦_99993 小时前
1.机器学习概述
人工智能·机器学习