【代码模板】Pytorch AMP 混合精度训练

背景

当使用AMP混合精度训练时,可以提升训练速度,并降低对显存的占用。下面提供一个使用AMP训练的代码demo。

Demo

python 复制代码
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

参考

Automatic Mixed Precision

相关推荐
西西o几秒前
SpringAi GA1.0.0入门到源码完整系列课
人工智能·语言模型
IT_陈寒几秒前
Vite 5个隐藏功能大揭秘:90%的开发者都不知道这些提速技巧!
前端·人工智能·后端
得贤招聘官4 分钟前
第六代AI面试智能体:重塑招聘流程的高效解决方案
人工智能·面试·职场和发展
阿杰学AI6 分钟前
AI核心知识27——大语言模型之AI Agent(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·agent·ai agent
视***间8 分钟前
视程空间展示亮相强悍的机器人AI运算模组
人工智能
whaosoft-1439 分钟前
51c视觉~合集54
人工智能
九千七52613 分钟前
sklearn学习(4)K近邻(KNN)
人工智能·学习·机器学习·sklearn·knn·近邻搜索
沫儿笙15 分钟前
kuka库卡弧焊接机器人保护气节气装置
人工智能·物联网·机器人
路边草随风16 分钟前
flink实现变更算子checkpoint断点续传依然生效
大数据·人工智能·flink
西猫雷婶18 分钟前
CNN卷积计算|多维卷积核自动计算
人工智能·pytorch·深度学习·神经网络·机器学习·cnn