深度学习中的混合精度训练

一、混合精度训练是什么?

混合精度训练 = 训练时同时用 FP16/BF16(快、省显存)+ FP32(稳、防梯度消失),

在几乎不掉精度的前提下,大幅提速、省显存。

·FP32:默认精度,稳,但慢、占显存

·FP16 / BF16:半精度,快、显存减半,但容易梯度下溢(变 0)

·混合精度:让该快的快,该稳的稳

二、核心原理(3 个关键点)

1. 为什么要用半精度?

显存减半:参数 / 激活从 4B → 2B

计算更快:GPU Tensor Core 专门加速 FP16/BF16

带宽压力减半:数据搬运更快

2. 直接全用 FP16 会崩,为什么?

梯度通常非常小 → FP16 表示范围窄 → 梯度直接变成 0 → 模型不收敛。

3. 混合精度怎么解决?

三大技术:

FP32 权重主副本(参数更新必须用 FP32)

损失缩放 Loss Scaling(把 loss 放大,防止梯度下溢)

自动精度调度(conv/matmul 用 FP16,loss/softmax 用 FP32)

三、混合精度训练具体怎么实现(标准流程)

标准 5 步(NVIDIA 官方)

1.模型权重存 FP32 主副本

权重实际存在 FP32主本,FP16的副本

2.前向传播用 FP16

前向计算用 权重的FP16 副本。

3.Loss 缩放(FP16 需要)

loss × scale → 梯度变大 → 不会变成 0,更新参数前 ÷ scale → 恢复真实值

4.反向传播算 FP16 梯度

反向传播时使用FP16

5.梯度转回 FP32,更新主权重

梯度算完 → 转回 FP32 更新主权重

四、哪些必须跑 FP32?

Softmax,LayerNorm / BatchNorm,Loss 计算,小梯度累加操作等操作需使用FP32,调用autocast 会自动实现混合精度计算,不用手动实现。

相关推荐
若丶相见1 小时前
AI 大模型零基础知识扫盲
人工智能
猿人谷2 小时前
不只是 CPU 阈值:STAR 如何用 GAT + Transformer 做容器级自动扩缩容?
人工智能·算法
说了很好4 小时前
PyTorch从零搭建DDPM:时间嵌入+UNet网络+扩散调度完整复现
人工智能
Bigfish_coding4 小时前
前端转agent-【python】-06 长期记忆(向量数据库 + 嵌入)
人工智能
小林ixn4 小时前
别再手写Prompt了!用AI Loop实现自动化自我迭代,效率提升10倍
人工智能·自动化运维
说了很好4 小时前
逐行注释DDPM源码:正向加噪、逆向去噪、MSE损失全流程复现
人工智能
Dilee4 小时前
Spring AI 1.1.7 接入 MCP:Filesystem Server 最小 Demo
人工智能·后端
Token炼金师4 小时前
大模型推理超参数原理详解
人工智能
Token炼金师4 小时前
大模型训练超参数:从Loss曲面到收敛策略的底层逻辑
人工智能
后端小肥肠4 小时前
Skill 囤了一堆却用不起来?我用 Codex 写了个整理神器
人工智能·agent