torch.cuda.amp.GradScaler
是 PyTorch 中的一个用于自动混合精度(Automatic Mixed Precision, AMP)训练的工具。AMP 允许在训练深度学习模型时动态切换浮点数的精度(例如,使用半精度浮点数 float16
而非 float32
),以减少显存占用和加速计算,同时保持模型的精度。
1. GradScaler
的作用
在混合精度训练中,模型的某些部分以半精度(float16
)计算,而其他部分仍然以全精度(float32
)计算。使用 float16
进行计算可以显著提高计算速度和减少显存占用,但也可能导致数值不稳定或梯度下溢(gradient underflow)。GradScaler
通过动态缩放损失值来缓解这些问题,并在反向传播过程中对缩放后的梯度进行适当调整,确保训练过程稳定。
2. 混合精度训练的基本步骤
2. 1. 初始化 GradScaler
:
scaler = torch.cuda.amp.GradScaler()
2. 2. 在前向传播中使用 autocast
上下文管理器:
在模型的前向传播中,使用 torch.cuda.amp.autocast
上下文管理器将部分计算切换到半精度。
with torch.cuda.amp.autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
在 PyTorch 中,autocast
是一个用于自动混合精度训练的上下文管理器。
2. 3 使用 scaler.scale
缩放损失并反向传播:
在计算损失并调用 backward()
前,通过 scaler.scale()
对损失进行缩放。
scaler.scale(loss).backward()
2. 4 使用 scaler.step
进行优化器更新:
使用 scaler.step()
来执行优化器的 step()
操作。
scaler.step(optimizer)
2.5 调用 scaler.update
: 通过 scaler.update()
来更新缩放因子,并根据需要调整精度。
scaler.update()
3. 完整的示例代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
# 假设你有一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 初始化模型、损失函数和优化器
model = SimpleModel().cuda()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
# 初始化 GradScaler
scaler = GradScaler()
# 假设你有输入和目标
inputs = torch.randn(64, 10).cuda()
targets = torch.randn(64, 1).cuda()
# 训练循环中的一次前向和反向传播
for epoch in range(10):
optimizer.zero_grad()
# 前向传播,使用 autocast 进行混合精度计算
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# 反向传播,使用 scaler 进行梯度缩放
scaler.scale(loss).backward()
# 使用 scaler 进行优化器步进
scaler.step(optimizer)
# 更新缩放因子
scaler.update()
print(f"Epoch [{epoch+1}/10], Loss: {loss.item()}")
4. 总结
torch.cuda.amp.GradScaler
是用于混合精度训练的工具,通过动态缩放损失值来提高数值稳定性。- 使用
autocast
上下文管理器来自动处理前向传播中的精度切换。 - 在反向传播和优化器更新时,通过
scaler
来处理损失缩放和梯度计算。
混合精度训练能够在现代 GPU 上显著提升训练速度和效率,同时通过 GradScaler
保持模型的稳定性和精度。