【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
CodeDevMaster2 分钟前
browser-use:AI驱动的浏览器自动化工具使用指南
python·llm
程序员辣条4 分钟前
深度测评 RAG 应用评估框架:指标最全面的 RAGas
人工智能·程序员
curdcv_po5 分钟前
字节跳动Trae:一款革命性的免费AI编程工具完全评测
人工智能·trae
程序员辣条6 分钟前
为什么需要提示词工程?什么是提示词工程(prompt engineering)?为什么需要提示词工程?收藏我这一篇就够了!
人工智能·程序员·产品经理
孔令飞10 分钟前
Go:终于有了处理未定义字段的实用方案
人工智能·云原生·go
清流君23 分钟前
【MySQL】数据库 Navicat 可视化工具与 MySQL 命令行基本操作
数据库·人工智能·笔记·mysql·ue5·数字孪生
Blossom.11830 分钟前
人工智能在智能家居中的应用与发展
人工智能·深度学习·机器学习·智能家居·vr·虚拟现实·多模态融合
biter008832 分钟前
ubuntu(28):ubuntu系统多版本conda和多版本cuda共存
linux·人工智能·ubuntu·conda
内网渗透32 分钟前
Python 虚拟环境管理:venv 与 conda 的选择与配置
开发语言·python·conda·虚拟环境·venv
薄荷很无奈40 分钟前
CuML + Cudf (RAPIDS) 加速python数据分析脚本
python·机器学习·数据分析·gpu算力