【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
CSTechEi5 分钟前
【SPIE/EI/Scopus检索】2026 年第三届数据挖掘与自然语言处理国际会议 (DMNLP 2026)
人工智能·自然语言处理·数据挖掘
GJGCY6 分钟前
技术剖析:智能体工作流与RPA流程自动化的架构差异与融合实现
人工智能·经验分享·ai·自动化·rpa
UMI赋能企业10 分钟前
制造业流程自动化提升生产力的全面分析
大数据·人工智能
MediaTea31 分钟前
Python 第三方库:matplotlib(科学绘图与数据可视化)
开发语言·python·信息可视化·matplotlib
说私域37 分钟前
“开源AI大模型AI智能名片S2B2C商城小程序”视角下的教育用户策略研究
人工智能·小程序
草莓熊Lotso40 分钟前
C++ 方向 Web 自动化测试入门指南:从概念到 Selenium 实战
前端·c++·python·selenium
gddkxc1 小时前
AI CRM中的数据分析:悟空AI CRM如何帮助企业优化运营
人工智能·信息可视化·数据分析
我是李武涯1 小时前
PyTorch Dataloader工作原理 之 default collate_fn操作
pytorch·python·深度学习
AI视觉网奇1 小时前
Python 检测运动模糊 源代码
人工智能·opencv·计算机视觉
东隆科技1 小时前
PRIMES推出SFM 2D全扫描场分析仪革新航空航天LPBF激光增材制造
人工智能·制造