PyTorch下,使用list放置模块,导致计算设备不一的报错

报错

在复现 Transformer 代码的训练阶段时,发生报错:

bash 复制代码
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

解决方案

通过next(linear.parameters()).device确定 model 已经在 cuda:0 上了,同时输入 model.forward()的张量也位于 cuda:0。输入的张量没什么好推敲的,于是考虑到模型具有多层结构,遂输出每层结构的设备信息,model.encoder -> model.encoder.sublayer[0] ··· ···

测试发现,model.encoder.sublayer[0] 之后的模块的设备信息均位于 cpu,原因是构造这部分模块时,由于需要多个相同的模块,使用了 list 来存放模块:

python 复制代码
# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
    return [copy.deepcopy(module) for _ in range(n)]

显然 list 不支持 GPU,需要用 PyTorch 提供的代替:

python 复制代码
def clones(module, n: int):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

ModuleList 把子模块存入列表,能像 Python 里普通的列表被索引,最重要的是能使内部的模块被正确注册,并对所有的 Module 方法可见。[Source]

成功解决!

相关环境

bash 复制代码
python                    3.11.7               he1021f5_0
pytorch                   2.1.2           py3.11_cuda12.1_cudnn8_0    
相关推荐
码农小白AI几秒前
AI报告审核进入技术驱动时代:IACheck如何从规则引擎走向深度学习,构建检测报告审核“技术矩阵”
人工智能·深度学习
ZC跨境爬虫6 分钟前
Scrapy分布式爬虫(单机模拟多节点):豆瓣Top250项目设置与数据流全解析
分布式·爬虫·python·scrapy
Zzj_tju18 分钟前
大语言模型技术指南:SFT、RLHF、DPO 怎么串起来?对齐训练与关键参数详解
人工智能·深度学习·语言模型
sg_knight19 分钟前
设计模式实战:命令模式(Command)
python·设计模式·命令模式
石榴树下的七彩鱼24 分钟前
图片修复 API 接入实战:网站如何自动去除图片水印(Python / PHP / C# 示例)
图像处理·后端·python·c#·php·api·图片去水印
Polar__Star37 分钟前
C#怎么操作Chart图表控件 C#如何用WinForms Chart控件绑定数据绘制统计图表【控件】
jvm·数据库·python
2401_8971905541 分钟前
CSS如何制作数字滚动效果_利用transform位移数字
jvm·数据库·python
2301_813599552 小时前
HTML图片怎么用UnoCSS对齐_UnoCSS原子化CSS图片对齐实战
jvm·数据库·python
m0_377618232 小时前
c++怎么在不加载整个大文件的情况下获取其SHA256校验值【进阶】
jvm·数据库·python
LN花开富贵2 小时前
【ROS】鱼香ROS2学习笔记二
linux·笔记·python·学习·嵌入式