1. 图像数据格式:灰度与彩色数据
- 灰度图像:单通道数据,每个像素用 0~255 或 0~1 表示亮度,无色彩信息,存储空间小。
- 彩色图像 :常见为 RGB 三通道,每个通道分别表示红、绿、蓝分量,也有 RGBA(含透明度)、HSV 等格式;在深度学习中通常以
[C, H, W]或[H, W, C]张量形式存储。
2. 模型的定义
在深度学习语境下,模型是由多层神经网络(如卷积层、全连接层)堆叠而成的计算图,包含可学习参数(权重、偏置),用于从输入数据中学习特征并完成特定任务(如分类、分割)。
3. 显存占用的 4 个核心部分
- a. 模型参数 + 梯度参数模型权重本身占用显存,训练时还需存储对应梯度,两者数量级一致(如 AdamW 优化器下,权重和梯度各占一份)。
- b. 优化器参数如 Adam 类优化器会维护一阶矩(momentum)、二阶矩(variance)等状态变量,这部分显存通常与模型参数相当,甚至更多。
- c. 数据批量所占显存输入图像批次(batch)、标签等数据会加载到显存中,batch size 越大,这部分占用越高。
- d. 神经元输出中间状态各层的激活值、中间特征图等前向传播结果,训练时需保留用于反向传播计算梯度,是显存占用的重要来源。
4. batch size 与训练的关系
- 显存层面:batch size 越大,输入数据、中间特征图占用显存越多,越容易触发显存不足;反之可减小 batch size 或使用梯度累积。
- 训练稳定性:较大的 batch size 能降低梯度噪声,使训练过程更稳定,收敛更平滑;过小的 batch size 会导致梯度波动大,可能影响收敛效果。
- 泛化能力:适度小的 batch size 往往带来更好的泛化性能,其噪声相当于一种隐式正则化;过大的 batch size 可能导致模型过拟合。
- 计算效率:在硬件允许范围内,增大 batch size 可提高 GPU 利用率,加快单步训练速度,但总迭代次数会减少,需权衡 epoch 数。
💡 补充提示:在实际训练中,可通过混合精度训练 、梯度检查点 、梯度累积等技术优化显存使用,在有限硬件下平衡 batch size 与训练效果。
优先用:低代价优化
- 混合精度
- PyTorch: torch.cuda.amp.autocast() + GradScaler
- 收益:显存减半,速度提升
- 风险:极少,注意梯度缩放
- 梯度检查点
- PyTorch: torch.utils.checkpoint.checkpoint
- 收益:中间激活显存大幅减少
- 代价:训练速度变慢约 20%
- 梯度累积
- 每 N 步执行一次 optimizer.step (),等效增大 batch size
- 收益:不额外占用显存,等效扩大 batch
- 注意:学习率需要根据步数调整
- 减小 batch size 或输入分辨率
- 显存占用与 batch size 成正比,与分辨率平方成正比
- 收益:效果立竿见影
- 代价:训练稳定性略有下降
进阶用:精度敏感场景
- LoRA/QLoRA:大模型微调首选,显存占用可降低 90%
- FlashAttention:Transformer 架构专用,大幅减少注意力层显存占用
- ZeRO/FSDP:多卡环境下将参数 / 梯度 / 优化器状态分片,极致节省单卡显存
- 参数剪枝 / 共享:压缩模型参数数量,需要微调恢复精度
临时清理显存技巧
- 删除无用变量:del 变量名,再执行 torch.cuda.empty_cache () 释放 PyTorch 缓存
- 推理阶段使用 torch.no_grad () 禁用梯度计算
- 避免在计算图中保留不必要的中间变量
显存估算公式
总显存 ≈ 模型参数 (4B) + 梯度 (4B) + 优化器状态 (8B/Adam) + 数据 (batchC HW4B) + 中间激活 (≈模型参数量级)
避坑提醒
- 先定位显存瓶颈:使用 torch.cuda.memory_summary () 查看显存分配详情
- 不要同时开启过多优化手段,避免调试复杂度急剧上升
- 小模型优先调整 batch size / 混合精度,大模型优先尝试 LoRA/ZeRO
快速排查流程
- 开启混合精度 → 2. 尝试梯度检查点 → 3. 使用梯度累积 → 4. 降低 batch size / 分辨率 → 5. 采用进阶优化方案