【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
PILIPALAPENG4 分钟前
第3周 Day 2:Function Calling —— 让 Agent 听懂人话,自己干活
前端·人工智能·python
阿里云大数据AI技术15 分钟前
PAI Physical AI Notebook详解8:Isaac Lab Arena 全身机器人机动+操控工作流
人工智能
高木木的博客29 分钟前
数字架构智能化测试平台(1)--总纲
人工智能·python·nginx·架构
wanghowie31 分钟前
11. AI 客服系统架构设计:不是调 API,而是系统工程
人工智能·系统架构
zhangchaoxies34 分钟前
golang如何使用SQLx原生SQL查询_golang SQLx原生SQL查询使用方法
jvm·数据库·python
m0_7436239234 分钟前
mysql如何优化InnoDB缓冲池大小_mysql缓冲池内存调优
jvm·数据库·python
袋鼠云数栈UED团队37 分钟前
基于 OpenSpec 实现规范驱动开发
前端·人工智能
m0_6178814237 分钟前
如何操作 XML 数据_XMLTYPE 与 EXTRACT 函数解析节点
jvm·数据库·python
Raink老师39 分钟前
【AI面试临阵磨枪】什么是 Tokenization?子词分词(Subword)的优缺点?
人工智能·ai 面试
qq_3345635542 分钟前
golang如何实现SSTable持久化_golang SSTable持久化实现要点
jvm·数据库·python