【AI知识点】模型训练优化之——混合精度训练

混合精度训练

混合精度训练是现代深度学习训练中的关键技术,它通过在不同计算环节使用不同精度(fp32, fp16, bf16)的数值表示来加速训练并减少内存占用。

为什么需要混合精度?

深度学习模型训练默认使用 32 位浮点数(FP32) 进行计算和参数存储,但实践中发现:

  • 计算效率:FP16(16 位浮点数)或 BF16(脑浮点数)的计算速度比 FP32 快 2-8 倍(尤其在支持 CUDA 的 GPU 上,如 NVIDIA 的 Tensor Core 专门优化低精度计算)。

  • 内存占用:低精度数据类型的内存占用仅为 FP32 的 1/2(FP16/BF16),可支持更大的 batch size、更深的模型或更高分辨率的输入。

  • 精度冗余:模型参数和计算过程中存在精度冗余,并非所有操作都需要 FP32 精度才能保持模型性能。

混合精度训练的核心是 "按需分配精度":对精度敏感的操作(如参数更新、损失计算)保留高精度(FP32),对精度不敏感的计算(如卷积、矩阵乘法)使用低精度(FP16/BF16),兼顾效率与精度。

混合精度训练中各个阶段的参数精度
  1. 模型初始化: 模型权重以 FP32 形式存储,保证权重的精确性。
  2. 前向传播阶段: 前向传播时,会复制一份 FP32 格式的权重并强制转化为 FP16 格式进行计算,利用 FP16 计算速度快和显存占用少的优势加速运算。
  3. 损失计算阶段: 通常与前向传播一致,使用 FP16 精度计算损失
  4. 损失缩放阶段: FP16 精度 。由于反向传播采用 FP16 格式计算梯度,而损失值可能很小,容易出现数值稳定性问题(如梯度下溢),所以引入损失缩放。将损失值乘以一个缩放因子,把可能下溢的数值提升到 FP16 可以表示的范围,确保梯度在 FP16 精度下能被有效表示。
  5. 反向传播阶段: 计算权重的梯度(FP16 精度),以加快计算速度。
  6. 权重更新阶段: 先将FP16 梯度反缩放(除以缩放因子,恢复原始幅值),此时梯度仍为 FP16,然后将其转换为 FP32 ,用于优化器更新,然后用FP32的梯度(AdamW的FP32的一阶矩和二阶矩)更新 FP32 的权重
python 复制代码
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    # 反向传播:先缩放损失,再计算梯度(避免 FP16 梯度下溢)
    scaler.scale(loss).backward()
    #反缩放(因为梯度裁剪需要在原始梯度上进行)
    scaler.unscale_(optimizer)
    # 梯度裁剪(可选,防止梯度爆炸)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
     # 更新参数:用缩放后的梯度更新,内部会自动调整缩放因子
    scaler.step(optimizer)
    # 更新缩放因子
    scaler.update()
相关推荐
aopstudio2 小时前
如何优雅地清理Hugging Face缓存到本地的模型文件(2025最新版)
人工智能·python·缓存·语言模型
叫我木子2 小时前
c语言,识别到黑色就自动开枪,4399单击游戏狙击战场,源码分享,豆包ai出品
c语言·人工智能·游戏
skywalk81633 小时前
Ascend C算子开发能力认证考试伴侣-昇腾Ascend C编程入门教程
人工智能·昇腾·ascendc
中电金信3 小时前
中电金信携手海光推出金融业云原生基础设施联合解决方案
大数据·人工智能
用户5191495848453 小时前
Braintree iOS Drop-in SDK - 一站式支付解决方案
人工智能·aigc
科技小郑3 小时前
吱吱企业即时通讯以安全为基,重塑安全办公新体验
大数据·网络·人工智能·安全·信息与通信·吱吱企业通讯
就叫飞六吧3 小时前
生产环境禁用AI框架工具回调:安全风险与最佳实践
人工智能·安全
胡乱编胡乱赢3 小时前
关于在pycharm终端连接服务器
人工智能·深度学习·pycharm·终端连接服务器
聚客AI4 小时前
⚠️Embedding选型指南:五步搞定数据规模、延迟与精度平衡!
人工智能·llm·掘金·日新计划