【AI知识点】模型训练优化之——混合精度训练

混合精度训练

混合精度训练是现代深度学习训练中的关键技术,它通过在不同计算环节使用不同精度(fp32, fp16, bf16)的数值表示来加速训练并减少内存占用。

为什么需要混合精度?

深度学习模型训练默认使用 32 位浮点数(FP32) 进行计算和参数存储,但实践中发现:

  • 计算效率:FP16(16 位浮点数)或 BF16(脑浮点数)的计算速度比 FP32 快 2-8 倍(尤其在支持 CUDA 的 GPU 上,如 NVIDIA 的 Tensor Core 专门优化低精度计算)。

  • 内存占用:低精度数据类型的内存占用仅为 FP32 的 1/2(FP16/BF16),可支持更大的 batch size、更深的模型或更高分辨率的输入。

  • 精度冗余:模型参数和计算过程中存在精度冗余,并非所有操作都需要 FP32 精度才能保持模型性能。

混合精度训练的核心是 "按需分配精度":对精度敏感的操作(如参数更新、损失计算)保留高精度(FP32),对精度不敏感的计算(如卷积、矩阵乘法)使用低精度(FP16/BF16),兼顾效率与精度。

混合精度训练中各个阶段的参数精度
  1. 模型初始化: 模型权重以 FP32 形式存储,保证权重的精确性。
  2. 前向传播阶段: 前向传播时,会复制一份 FP32 格式的权重并强制转化为 FP16 格式进行计算,利用 FP16 计算速度快和显存占用少的优势加速运算。
  3. 损失计算阶段: 通常与前向传播一致,使用 FP16 精度计算损失
  4. 损失缩放阶段: FP16 精度 。由于反向传播采用 FP16 格式计算梯度,而损失值可能很小,容易出现数值稳定性问题(如梯度下溢),所以引入损失缩放。将损失值乘以一个缩放因子,把可能下溢的数值提升到 FP16 可以表示的范围,确保梯度在 FP16 精度下能被有效表示。
  5. 反向传播阶段: 计算权重的梯度(FP16 精度),以加快计算速度。
  6. 权重更新阶段: 先将FP16 梯度反缩放(除以缩放因子,恢复原始幅值),此时梯度仍为 FP16,然后将其转换为 FP32 ,用于优化器更新,然后用FP32的梯度(AdamW的FP32的一阶矩和二阶矩)更新 FP32 的权重
python 复制代码
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    # 反向传播:先缩放损失,再计算梯度(避免 FP16 梯度下溢)
    scaler.scale(loss).backward()
    #反缩放(因为梯度裁剪需要在原始梯度上进行)
    scaler.unscale_(optimizer)
    # 梯度裁剪(可选,防止梯度爆炸)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
     # 更新参数:用缩放后的梯度更新,内部会自动调整缩放因子
    scaler.step(optimizer)
    # 更新缩放因子
    scaler.update()
相关推荐
醒了就刷牙7 小时前
MovieNet
论文阅读·人工智能·论文笔记
传说故事7 小时前
【论文自动阅读】RoboBrain 2.0
人工智能·具身智能
MaoziShan7 小时前
[ICLR 2026] 一文读懂 AutoGEO:生成式搜索引擎优化(GEO)的自动化解决方案
人工智能·python·搜索引擎·语言模型·自然语言处理·内容运营·生成式搜索引擎
LS_learner7 小时前
理解Clawdbot 的本质
人工智能
方见华Richard7 小时前
整数阶时间重参数化:基于自适应豪斯多夫维数的偏微分方程正则化新框架
人工智能·笔记·交互·原型模式·空间计算
盼小辉丶7 小时前
PyTorch实战(27)——自动混合精度训练
pytorch·深度学习·混合精度训练
aihuangwu7 小时前
如何把豆包的回答导出
人工智能·ai·deepseek·ds随心转
好奇龙猫7 小时前
【人工智能学习-AI入试相关题目练习-第十六次】
人工智能·学习
bing.shao7 小时前
Golang 开发者视角:解读《“人工智能 + 制造” 专项行动》的技术落地机遇
人工智能·golang·制造
LOnghas12117 小时前
玉米目标检测实战:基于YOLO13-C3k2-RFAConv的优化方案_1
人工智能·目标检测·计算机视觉