PyTorch Deep Learning in Practice (36): Mixed Precision Training and Gradient Scaling

In the previous article we explored graph generative models and molecular design. This article takes a close look at **mixed precision training** and **gradient scaling**, techniques that can significantly speed up training and reduce GPU memory usage while preserving model accuracy. We will implement them on an image classification task using PyTorch's AMP (Automatic Mixed Precision) module.

I. Mixed Precision Training Fundamentals

1. Precision Types Compared

| Precision | Bits | Representable range | Memory per parameter | Compute speed |
|-----------|------|---------------------|----------------------|---------------|
| FP32 | 32-bit | ~1e-38 to ~3e38 | 4 bytes | baseline |
| FP16 | 16-bit | ~6e-5 to 65504 | 2 bytes | 2-8x speedup |
| BF16 | 16-bit | ~1e-38 to ~3e38 | 2 bytes | similar to FP16 |
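
These limits can be checked directly with `torch.finfo`; a quick sketch:

```python
import torch

# Print the numeric limits behind the table above.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} min normal={info.tiny:.3e}  max={info.max:.3e}  eps={info.eps:.3e}")
```

This makes the trade-off visible: FP16 narrows the exponent range (hence the underflow risk discussed below), while BF16 keeps FP32's range at the cost of coarser precision (a larger eps).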

2. The Three Components of Mixed Precision Training

```python
class MixedPrecisionComponents:
    def __init__(self):
        self.fp16_operations = ["matmul", "convolution"]   # ops that are fast and safe in FP16
        self.fp32_operations = ["Softmax", "LayerNorm"]    # ops that need FP32 precision
        self.gradient_scaling = True                       # guards against gradient underflow
```
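
`torch.autocast` applies exactly this split automatically. A minimal sketch (assuming a CUDA device) showing how outputs pick up different dtypes inside an autocast region:

```python
import torch

a = torch.randn(8, 8, device="cuda")
with torch.autocast("cuda"):
    print((a @ a).dtype)             # torch.float16: matmul is autocast to FP16
    print(a.softmax(dim=-1).dtype)   # torch.float32: softmax runs in FP32
```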

3. The Mixed Precision Training Workflow

Forward pass (autocast, FP16) → loss (FP32) → scale the loss → backward (scaled gradients) → unscale / optional clipping → optimizer step → update the scale factor. A minimal sketch of one step follows.
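
In PyTorch this flow maps onto four `GradScaler` calls. A sketch of a single training step, assuming `model`, `criterion`, `optimizer`, a `torch.amp.GradScaler` instance `scaler`, and a CUDA batch already exist:

```python
import torch

def amp_step(model, data, target, criterion, optimizer, scaler):
    """One mixed precision training step following the flow above."""
    optimizer.zero_grad()
    with torch.autocast('cuda'):                  # forward in FP16 where safe
        loss = criterion(model(data), target)     # loss computed in FP32
    scaler.scale(loss).backward()                 # backward on the scaled loss
    scaler.unscale_(optimizer)                    # optional: unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                        # skipped if grads contain inf/NaN
    scaler.update()                               # grow or back off the scale factor
    return loss.item()
```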

II. Mixed Precision Training in Practice

1. Environment Setup

```bash
pip install torch torchvision torchmetrics
```

2. Basic Implementation (Manual Mode)

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
import torch.nn.functional as F

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Model definition
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Initialization
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler('cuda')  # gradient scaler

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)


# Training loop
def train_epoch(epoch):
    model.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Enable automatic mixed precision for the forward pass
        with autocast('cuda'):
            output = model(data)
            loss = criterion(output, target)

        # Backward pass on the scaled loss
        scaler.scale(loss).backward()

        # Optimizer step via the scaler (skipped if gradients overflowed)
        scaler.step(optimizer)

        # Update the scale factor
        scaler.update()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

    return total_loss / len(train_loader)


# Train for several epochs
for epoch in range(1, 11):
    avg_loss = train_epoch(epoch)
    print(f"Epoch {epoch} completed. Avg Loss: {avg_loss:.4f}")
```

Output:

```
Files already downloaded and verified
Epoch: 1 | Batch: 0/196 | Loss: 2.3047
Epoch: 1 | Batch: 100/196 | Loss: 1.3887
Epoch 1 completed. Avg Loss: 1.4791
Epoch: 2 | Batch: 0/196 | Loss: 1.2960
Epoch: 2 | Batch: 100/196 | Loss: 1.0821
Epoch 2 completed. Avg Loss: 1.0934
Epoch: 3 | Batch: 0/196 | Loss: 1.0498
Epoch: 3 | Batch: 100/196 | Loss: 0.9498
Epoch 3 completed. Avg Loss: 0.9368
Epoch: 4 | Batch: 0/196 | Loss: 0.8334
Epoch: 4 | Batch: 100/196 | Loss: 0.6887
Epoch 4 completed. Avg Loss: 0.8291
Epoch: 5 | Batch: 0/196 | Loss: 0.6790
Epoch: 5 | Batch: 100/196 | Loss: 0.8170
Epoch 5 completed. Avg Loss: 0.7436
Epoch: 6 | Batch: 0/196 | Loss: 0.5595
Epoch: 6 | Batch: 100/196 | Loss: 0.6540
Epoch 6 completed. Avg Loss: 0.6649
Epoch: 7 | Batch: 0/196 | Loss: 0.5427
Epoch: 7 | Batch: 100/196 | Loss: 0.5254
Epoch 7 completed. Avg Loss: 0.5915
Epoch: 8 | Batch: 0/196 | Loss: 0.5462
Epoch: 8 | Batch: 100/196 | Loss: 0.5190
Epoch 8 completed. Avg Loss: 0.5130
Epoch: 9 | Batch: 0/196 | Loss: 0.4183
Epoch: 9 | Batch: 100/196 | Loss: 0.4018
Epoch 9 completed. Avg Loss: 0.4439
Epoch: 10 | Batch: 0/196 | Loss: 0.5110
Epoch: 10 | Batch: 100/196 | Loss: 0.3564
Epoch 10 completed. Avg Loss: 0.3754
```
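
A side note on BF16: because it shares FP32's exponent range (see the table in Section I.1), the same loop usually works without a `GradScaler` at all. A sketch, assuming an Ampere-or-newer GPU and the `model`/`optimizer`/`criterion`/`train_loader` objects defined above:

```python
# BF16 variant: gradient scaling is typically unnecessary because BF16
# has the same exponent range as FP32, so gradients do not underflow.
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    with autocast('cuda', dtype=torch.bfloat16):
        loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
```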

3. Advanced AMP Configuration

```python
class PrecisionDebugger:
    def __init__(self, model):
        self.model = model
        self.fp16_tensors = []
        self.fp32_tensors = []
        self.last_scale = None

    def track_precision(self):
        # AMP keeps the master weights in FP32, so parameters normally all
        # report FP32; autocast only changes the dtype of activations.
        self.fp16_tensors = []
        self.fp32_tensors = []
        for name, param in self.model.named_parameters():
            if param.dtype == torch.float16:
                self.fp16_tensors.append(name)
            else:
                self.fp32_tensors.append(name)

        print("FP16 params:", self.fp16_tensors)
        print("FP32 params:", self.fp32_tensors)

    def detect_overflow(self, scaler):
        # GradScaler keeps its inf/NaN bookkeeping private, so the reliable
        # public signal is a drop in the scale factor: update() multiplies
        # the scale by backoff_factor whenever scaled gradients held inf/NaN.
        if not scaler.is_enabled():
            return False
        scale = scaler.get_scale()
        overflowed = self.last_scale is not None and scale < self.last_scale
        self.last_scale = scale
        if overflowed:
            print("Warning: gradient overflow detected! Scale factor is now:", scale)
        return overflowed


# Advanced AMP configuration
def train_with_custom_amp(epochs=10):
    # Create a gradient scaler with custom parameters
    scaler = GradScaler(
        'cuda',
        init_scale=2. ** 16,    # initial scale factor
        growth_factor=2.0,      # multiplier applied after growth_interval clean steps
        backoff_factor=0.5,     # multiplier applied after an overflow
        growth_interval=2000,   # steps without overflow before the scale grows
        enabled=True            # can be toggled dynamically
    )
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Initialize the debugger
    debugger = PrecisionDebugger(model)

    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Customized autocast region
            with autocast('cuda', dtype=torch.float16, cache_enabled=True):
                # Ops in this region automatically pick a suitable precision
                output = model(data)
                loss = criterion(output, target)

            # Scaled backward pass with gradient clipping
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            num_batches += 1

            # Check every 100 batches
            if batch_idx % 100 == 0:
                print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

                # Use the debugger to inspect precision and overflow
                debugger.track_precision()
                if debugger.detect_overflow(scaler):
                    print("Gradient overflow detected; scale factor adjusted automatically")

                # Print the current scale factor
                print(f"Current scale factor: {scaler.get_scale()}")
    return total_loss / num_batches

train_with_custom_amp()
```

Output:

```
Epoch: 0 | Batch: 0/196 | Loss: 0.3563
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 65536.0
Epoch: 0 | Batch: 100/196 | Loss: 0.2108
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 65536.0
Epoch: 1 | Batch: 0/196 | Loss: 0.1676
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Warning: gradient overflow detected! Scale factor is now: 32768.0
Gradient overflow detected; scale factor adjusted automatically
Current scale factor: 32768.0
Epoch: 1 | Batch: 100/196 | Loss: 0.2804
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 2 | Batch: 0/196 | Loss: 0.1914
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 2 | Batch: 100/196 | Loss: 0.1878
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 3 | Batch: 0/196 | Loss: 0.1588
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 3 | Batch: 100/196 | Loss: 0.1360
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 4 | Batch: 0/196 | Loss: 0.0888
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 4 | Batch: 100/196 | Loss: 0.1061
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 5 | Batch: 0/196 | Loss: 0.0906
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 5 | Batch: 100/196 | Loss: 0.0912
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 6 | Batch: 0/196 | Loss: 0.1038
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 6 | Batch: 100/196 | Loss: 0.0700
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 7 | Batch: 0/196 | Loss: 0.0480
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 7 | Batch: 100/196 | Loss: 0.0719
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 8 | Batch: 0/196 | Loss: 0.0289
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 8 | Batch: 100/196 | Loss: 0.0637
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 9 | Batch: 0/196 | Loss: 0.0236
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
Epoch: 9 | Batch: 100/196 | Loss: 0.0420
FP16 params: []
FP32 params: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
Current scale factor: 32768.0
```

III. Key Techniques Explained

1. How Gradient Scaling Works

Gradient scaling exists to solve FP16 gradient underflow. One training step proceeds as follows:

  1. Forward pass: computed in FP16 (inside autocast)

  2. Loss computation: in FP32

  3. Loss scaling: the loss is multiplied by the scale factor S before backpropagation

  4. Backward pass: the resulting FP16 gradients are S times larger, keeping them above FP16's underflow threshold

  5. Parameter update: gradients are divided by S (in FP32) and applied to the FP32 master weights

Mathematically, with scale factor $S$ and learning rate $\eta$:

$$
\tilde{L} = S \cdot L, \qquad \frac{\partial \tilde{L}}{\partial w} = S \cdot \frac{\partial L}{\partial w}, \qquad w \leftarrow w - \eta \cdot \frac{1}{S} \cdot \frac{\partial \tilde{L}}{\partial w} = w - \eta \cdot \frac{\partial L}{\partial w}
$$

so scaling changes nothing about the final update; it only keeps the intermediate FP16 gradients representable.
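
A quick numeric check of the underflow problem (FP16's smallest positive subnormal is about 6e-8, so values much below that flush to zero):

```python
import torch

g = 1e-8  # a gradient magnitude below FP16's representable range
print(torch.tensor(g, dtype=torch.float16))            # tensor(0., dtype=torch.float16): underflows to zero
print(torch.tensor(g * 2 ** 16, dtype=torch.float16))  # about 6.554e-4: survives once scaled by S = 2**16
```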

2. Debugging Precision Issues

The `PrecisionDebugger` class from the advanced configuration above (Section II.3) is the main tool here, and its output is worth reading carefully: every parameter is reported as FP32. That is expected, not a bug. AMP keeps the master weights in FP32; autocast only changes the dtype of op outputs and intermediate activations. To actually observe FP16 at work, you have to inspect tensors produced inside the autocast region rather than the parameters.
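
A minimal sketch of that kind of check (assuming a CUDA device; `report_dtypes` is an illustrative helper, not a library function): forward hooks record each leaf module's output dtype during one autocast forward pass:

```python
import torch

def report_dtypes(model, data):
    """Log each leaf module's output dtype during one autocast forward pass."""
    seen, hooks = {}, []
    for name, mod in model.named_modules():
        if len(list(mod.children())) == 0:  # leaf modules only
            def _hook(m, inp, out, name=name):
                seen[name] = out.dtype  # record; return None to leave the output unchanged
            hooks.append(mod.register_forward_hook(_hook))
    with torch.no_grad(), torch.autocast('cuda'):
        model(data)
    for h in hooks:
        h.remove()
    for name, dtype in seen.items():
        print(f"{name}: {dtype}")

# e.g. report_dtypes(model, data.to(device)) should show torch.float16
# for conv and linear layers even though the parameters remain FP32.
```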

3. Mixed Precision Best Practices

| Scenario | Recommended configuration | Rationale |
|----------|---------------------------|-----------|
| Large-batch training | `init_scale=2**16, growth_factor=2.0` | needs a larger scale factor |
| Small-batch training | `init_scale=2**10, growth_factor=1.5` | gradients are more stable |
| Numerically unstable models | disable AMP for the affected layers | avoids numerical issues |
| Multi-GPU training | keep the scale factor identical across devices | ensures consistency |
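
For the "numerically unstable models" row, autocast can be switched off locally for the offending layers. A sketch (`backbone`, `unstable_layer`, and `head` are hypothetical modules standing in for your own):

```python
import torch

def forward_with_fp32_island(backbone, unstable_layer, head, x):
    """Run one layer in FP32 inside an otherwise mixed precision forward."""
    with torch.autocast('cuda'):
        h = backbone(x)                            # FP16 where safe
        with torch.autocast('cuda', enabled=False):
            h = unstable_layer(h.float())          # cast inputs back to FP32 explicitly
        return head(h)                             # resume mixed precision
```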

IV. Performance Comparison Experiments

1. Benchmark Code

```python
from torch.utils.benchmark import Timer

def benchmark_training(precision='fp32', batch_size=256):
    model = CNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # optimizer must target this model's params
    data = torch.randn(batch_size, 3, 32, 32, device=device)
    target = torch.randint(0, 10, (batch_size,), device=device)

    scaler = GradScaler('cuda') if precision == 'fp16' else None

    # Warm-up
    for _ in range(10):
        optimizer.zero_grad()
        if precision == 'fp16':
            with autocast('cuda'):
                output = model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    # Timed runs
    stmt = """
    optimizer.zero_grad()
    if precision == 'fp16':
        with autocast('cuda'):
            output = model(data)
            loss = criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    """

    timer = Timer(
        stmt=stmt,
        globals={'model': model, 'data': data, 'target': target,
                 'criterion': criterion, 'optimizer': optimizer,
                 'precision': precision, 'scaler': scaler,
                 'autocast': autocast}
    )

    result = timer.timeit(100)
    print(f"{precision.upper()} mean time per step: {result.mean * 1000:.2f}ms")
    return result


# Run the benchmark
fp32_result = benchmark_training('fp32')
fp16_result = benchmark_training('fp16')
print(f"Speedup: {fp32_result.mean / fp16_result.mean:.2f}x")
```

Output:

```
FP32 mean time per step: 6.79ms
FP16 mean time per step: 4.04ms
Speedup: 1.68x
```

2. Typical Results

| Precision mode | Training speedup | GPU memory | Final accuracy |
|----------------|------------------|------------|----------------|
| FP32 | 1.0x (baseline) | 100% | 92.3% |
| FP16 (no scaling) | 1.8x | 55% | training fails |
| AMP (with scaling) | 2.5x | 60% | 92.1% |
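
The memory column can be reproduced with CUDA's peak-memory counters. A minimal sketch, reusing `model`, `data`, `target`, and `criterion` from the benchmark above:

```python
import torch

def peak_memory_mib(model, data, target, criterion, use_amp=True):
    """Peak CUDA memory for one forward/backward pass, in MiB."""
    torch.cuda.reset_peak_memory_stats()
    with torch.autocast('cuda', enabled=use_amp):
        loss = criterion(model(data), target)
    loss.backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2 ** 20

# Compare peak_memory_mib(..., use_amp=True) against use_amp=False.
```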

V. Summary and Outlook

This article walked through mixed precision training and gradient scaling in detail. Key takeaways:

  1. A complete AMP implementation: from basic usage to advanced configuration

  2. How gradient scaling works: the underlying math and implementation details

  3. Performance optimization: debugging methods and best practices

In the next article, we will move on to distributed training in practice (DP/DDP/DeepSpeed) and show how to scale training to multi-GPU and multi-node environments.
