【AI知识点】模型训练优化之——混合精度训练

混合精度训练

混合精度训练是现代深度学习训练中的关键技术,它通过在不同计算环节使用不同精度(fp32, fp16, bf16)的数值表示来加速训练并减少内存占用。

为什么需要混合精度?

深度学习模型训练默认使用 32 位浮点数(FP32) 进行计算和参数存储,但实践中发现:

  • 计算效率:FP16(16 位浮点数)或 BF16(脑浮点数)的计算速度比 FP32 快 2-8 倍(尤其在支持 CUDA 的 GPU 上,如 NVIDIA 的 Tensor Core 专门优化低精度计算)。

  • 内存占用:低精度数据类型的内存占用仅为 FP32 的 1/2(FP16/BF16),可支持更大的 batch size、更深的模型或更高分辨率的输入。

  • 精度冗余:模型参数和计算过程中存在精度冗余,并非所有操作都需要 FP32 精度才能保持模型性能。

混合精度训练的核心是 "按需分配精度":对精度敏感的操作(如参数更新、损失计算)保留高精度(FP32),对精度不敏感的计算(如卷积、矩阵乘法)使用低精度(FP16/BF16),兼顾效率与精度。

混合精度训练中各个阶段的参数精度
  1. 模型初始化: 模型权重以 FP32 形式存储,保证权重的精确性。
  2. 前向传播阶段: 前向传播时,会复制一份 FP32 格式的权重并强制转化为 FP16 格式进行计算,利用 FP16 计算速度快和显存占用少的优势加速运算。
  3. 损失计算阶段: 通常与前向传播一致,使用 FP16 精度计算损失
  4. 损失缩放阶段: FP16 精度 。由于反向传播采用 FP16 格式计算梯度,而损失值可能很小,容易出现数值稳定性问题(如梯度下溢),所以引入损失缩放。将损失值乘以一个缩放因子,把可能下溢的数值提升到 FP16 可以表示的范围,确保梯度在 FP16 精度下能被有效表示。
  5. 反向传播阶段: 计算权重的梯度(FP16 精度),以加快计算速度。
  6. 权重更新阶段: 先将FP16 梯度反缩放(除以缩放因子,恢复原始幅值),此时梯度仍为 FP16,然后将其转换为 FP32 ,用于优化器更新,然后用FP32的梯度(AdamW的FP32的一阶矩和二阶矩)更新 FP32 的权重
python 复制代码
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    # 反向传播:先缩放损失,再计算梯度(避免 FP16 梯度下溢)
    scaler.scale(loss).backward()
    #反缩放(因为梯度裁剪需要在原始梯度上进行)
    scaler.unscale_(optimizer)
    # 梯度裁剪(可选,防止梯度爆炸)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
     # 更新参数:用缩放后的梯度更新,内部会自动调整缩放因子
    scaler.step(optimizer)
    # 更新缩放因子
    scaler.update()
相关推荐
不当菜鸡的程序媛11 分钟前
https://duoke360.com/post/35063
人工智能
IT_陈寒16 分钟前
SpringBoot3踩坑实录:一个@Async注解让我多扛了5000QPS
前端·人工智能·后端
_Meilinger_25 分钟前
碎片笔记|生成模型原理解读:AutoEncoder、GAN 与扩散模型图像生成机制
人工智能·生成对抗网络·gan·扩散模型·图像生成·diffusion model
Listennnn1 小时前
BEV query 式图片点云视觉特征融合
人工智能
DS-RAG1 小时前
万方智能体投票火热进行中~
人工智能
semantist@语校1 小时前
语校网500所里程碑:日本语言学校数据库的标准化与可追溯机制
大数据·数据库·人工智能·百度·语言模型·oracle·github
key062 小时前
《数据出境安全评估办法》企业应对策略
网络·人工智能·安全
key062 小时前
数据安全能力成熟度模型 (DSMM) 核心要点
大数据·人工智能
山海青风2 小时前
藏语自然语言处理入门 - 3 找关键词
人工智能·自然语言处理
Java与Android技术栈2 小时前
AI Coding 让我两天完成图像编辑器 Monica 的国际化与多主题
人工智能