梯度检查点是通过只保存部分中间激活值、反向时重算前向来节省显存的技术,能降低40%~60%显存但增加15%~30%训练时间,要求模块前向可重入且无副作用。梯度检查点是什么,为什么能省显存梯度检查点(torch.utils.checkpoint.checkpoint)不是"不存梯度",而是"不存中间激活值"。反向传播时需要前向计算的中间结果来算梯度,常规训练会把所有层的输出全存着,显存占用和网络深度线性增长;而检查点只存部分层的输入,反向时临时重跑对应前向------用时间换空间。典型节省比例:ResNet-50 训练 batch size 可从 16 提到 32,ViT-L 类模型显存常降 40%~60%。但注意:它只对**前向可重入、无副作用**的模块有效,比如不能包裹含 nn.Dropout(训练态随机行为不可复现)或修改全局状态的自定义层。怎么加 checkpoint,最简安全写法别直接套整个模型,先从单个 nn.Sequential 或自定义 forward 块开始。PyTorch 官方推荐方式是用 checkpoint.checkpoint 包裹函数调用,而不是用装饰器(后者容易隐式捕获非 tensor 参数)。必须确保被包裹函数只接收 Tensor 参数,且不依赖闭包变量(如 self.training)若模块含 training 切换逻辑(如 Dropout),改用 checkpoint.checkpoint_sequential 或手动拆分 + torch.no_grad() 重算示例:对 Transformer 层列表做检查点from torch.utils.checkpoint import checkpointdef custom_forward(x, layer): return layer(x)# 替换原循环:x = layer(x)x = checkpoint(custom_forward, x, layer)常见报错和绕过方法RuntimeError: Trying to backward through the graph a second time:说明 checkpoint 内部用了被复用的 Tensor(比如共享 embedding),或者你在检查点外又对同一张量调了 backward()。根本原因是计算图被意外保留。 VWO 一个A/B测试工具
相关推荐
米高梅狮子1 小时前
13.ETCD 存储系统、生产环境 Kubernetes 集群部署和Kubernetes 集群升级Yupureki1 小时前
《MySQL数据库基础》6.表的增删查改北顾笙9801 小时前
MySQL-day1QQ24221997910 小时前
基于python+微信小程序的家教管理系统_mh3j9RSTJ_162510 小时前
PYTHON+AI LLM DAY THREETY-SEVEN阿波罗尼亚10 小时前
数据库序列(Sequence)郝学胜-神的一滴10 小时前
深度学习优化核心:梯度下降与网络训练全解析Aision_10 小时前
Agent 为什么需要 Checkpoint?清水白石00810 小时前
《Python性能深潜:从对象分配开销到“小对象风暴”的破解之道(含实战与最佳实践)》