【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
一水鉴天5 分钟前
整体设计 全面梳理复盘之30 Transformer 九宫格三层架构 Designer 全部功能定稿(初稿)之2
前端·人工智能
luoganttcc6 分钟前
DiffusionVLA 与BridgeVLA 相比 在 精度和成功率和效率上 有什么 优势
人工智能·算法
飞哥数智坊7 分钟前
TRAE CN + K2 Thinking,我试着生成了一个简版的在线 PS
人工智能·ai编程·trae
moeyui7058 分钟前
Python文件编码读取和处理整理知识点
开发语言·前端·python
caiyueloveclamp13 分钟前
AI一键生成PPT的实用软件与网站推荐TOP10
人工智能·powerpoint·ai生成ppt·aippt·免费aippt
张较瘦_16 分钟前
[论文阅读] AI+ | AI重构工业数字孪生!新一代iDTS破解数据稀缺、智能不足难题,附3大落地案例
论文阅读·人工智能·重构
Studying 开龙wu19 分钟前
目标检测标注工具常用的三种:LabelImg、CVAT、Roboflow
人工智能·目标检测·计算机视觉
程序员爱钓鱼28 分钟前
Python编程实战 - Python实用工具与库 - 正则表达式匹配(re 模块)
后端·python·面试
程序员爱钓鱼30 分钟前
Python编程实战 - Python实用工具与库 - 爬取并存储网页数据
后端·python·面试
bin915332 分钟前
PHP文档保卫战:AI自动生成下的创意守护与反制指南
开发语言·人工智能·php·工具·ai工具