【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
weixin_75033552几秒前
李沐 X 动手学深度学习--第九章 现代循环神经网络
人工智能·rnn·深度学习
摸鱼仙人~1 分钟前
深度学习数据集划分比例多少合适
人工智能·深度学习
Niuguangshuo3 分钟前
Python 设计模式:外观模式
python·设计模式·外观模式
矩阵猫咪6 分钟前
基于时间卷积网络TCN实现电力负荷多变量时序预测(PyTorch版)
pytorch·深度学习·tcn·时序预测·时间卷积网络·电力负荷
Blossom.11822 分钟前
《探索边缘计算:重塑未来智能物联网的关键技术》
人工智能·深度学习·神经网络·物联网·机器学习·计算机视觉·边缘计算
果冻人工智能26 分钟前
探索 AI 思维的剖析
人工智能
XINVRY-FPGA1 小时前
Xilinx FPGA XCVC1902-2MSEVSVA2197 Versal AI Core系列芯片的详细介绍
人工智能·嵌入式硬件·5g·ai·fpga开发·云计算·fpga
wgc2k1 小时前
吴恩达深度学习复盘(6)神经网络的矢量化原理
python·深度学习·矩阵
jndingxin1 小时前
OpenCV 图形API(16)将极坐标(magnitude 和 angle)转换为笛卡尔坐标(x 和 y)函数polarToCart()
人工智能·opencv·计算机视觉
?Agony1 小时前
P17_ResNeXt-50
人工智能·pytorch·python·算法