【AI知识点】模型训练优化之——混合精度训练

混合精度训练

混合精度训练是现代深度学习训练中的关键技术,它通过在不同计算环节使用不同精度(fp32, fp16, bf16)的数值表示来加速训练并减少内存占用。

为什么需要混合精度?

深度学习模型训练默认使用 32 位浮点数(FP32) 进行计算和参数存储,但实践中发现:

  • 计算效率:FP16(16 位浮点数)或 BF16(脑浮点数)的计算速度比 FP32 快 2-8 倍(尤其在支持 CUDA 的 GPU 上,如 NVIDIA 的 Tensor Core 专门优化低精度计算)。

  • 内存占用:低精度数据类型的内存占用仅为 FP32 的 1/2(FP16/BF16),可支持更大的 batch size、更深的模型或更高分辨率的输入。

  • 精度冗余:模型参数和计算过程中存在精度冗余,并非所有操作都需要 FP32 精度才能保持模型性能。

混合精度训练的核心是 "按需分配精度":对精度敏感的操作(如参数更新、损失计算)保留高精度(FP32),对精度不敏感的计算(如卷积、矩阵乘法)使用低精度(FP16/BF16),兼顾效率与精度。

混合精度训练中各个阶段的参数精度
  1. 模型初始化: 模型权重以 FP32 形式存储,保证权重的精确性。
  2. 前向传播阶段: 前向传播时,会复制一份 FP32 格式的权重并强制转化为 FP16 格式进行计算,利用 FP16 计算速度快和显存占用少的优势加速运算。
  3. 损失计算阶段: 通常与前向传播一致,使用 FP16 精度计算损失
  4. 损失缩放阶段: FP16 精度 。由于反向传播采用 FP16 格式计算梯度,而损失值可能很小,容易出现数值稳定性问题(如梯度下溢),所以引入损失缩放。将损失值乘以一个缩放因子,把可能下溢的数值提升到 FP16 可以表示的范围,确保梯度在 FP16 精度下能被有效表示。
  5. 反向传播阶段: 计算权重的梯度(FP16 精度),以加快计算速度。
  6. 权重更新阶段: 先将FP16 梯度反缩放(除以缩放因子,恢复原始幅值),此时梯度仍为 FP16,然后将其转换为 FP32 ,用于优化器更新,然后用FP32的梯度(AdamW的FP32的一阶矩和二阶矩)更新 FP32 的权重
python 复制代码
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    # 反向传播:先缩放损失,再计算梯度(避免 FP16 梯度下溢)
    scaler.scale(loss).backward()
    #反缩放(因为梯度裁剪需要在原始梯度上进行)
    scaler.unscale_(optimizer)
    # 梯度裁剪(可选,防止梯度爆炸)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
     # 更新参数:用缩放后的梯度更新,内部会自动调整缩放因子
    scaler.step(optimizer)
    # 更新缩放因子
    scaler.update()
相关推荐
郝学胜-神的一滴3 小时前
主成分分析(PCA)在计算机图形学中的深入解析与应用
开发语言·人工智能·算法·机器学习·1024程序员节
飞哥数智坊3 小时前
想用好 AI 编程?你可能得先学点管理
人工智能·ai编程
golang学习记3 小时前
太卷了,蚂蚁又发布了新一代Code Agent!
人工智能
StarPrayers.3 小时前
神经网络中的 HWC→CHW 格式转换
人工智能·深度学习·神经网络
ModelWhale4 小时前
和鲸科技入选《大模型一体机产业图谱》,以一体机智驱科研、重塑教学
人工智能·科研·高等教育
区块block4 小时前
DeFi中的自主代理:用AI重塑金融
人工智能·金融
数据科学作家4 小时前
如何入门python机器学习?金融从业人员如何快速学习Python、机器学习?机器学习、数据科学如何进阶成为大神?
大数据·开发语言·人工智能·python·机器学习·数据分析·统计分析
GJGCY4 小时前
金融智能体技术解读:十大应用场景与AI Agent架构设计思路
人工智能·经验分享·ai·金融·自动化
文火冰糖的硅基工坊4 小时前
[人工智能-大模型-57]:模型层技术 - 软件开发的不同层面(如底层系统、中间件、应用层等),算法的类型、设计目标和实现方式存在显著差异。
人工智能·算法·中间件
Coovally AI模型快速验证4 小时前
突破性开源模型DepthLM问世:视觉语言模型首次实现精准三维空间理解
人工智能·语言模型·自然语言处理·ocr·音视频·ai编程