PyTorch下,使用list放置模块,导致计算设备不一的报错

报错

在复现 Transformer 代码的训练阶段时,发生报错:

bash 复制代码
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

解决方案

通过next(linear.parameters()).device确定 model 已经在 cuda:0 上了,同时输入 model.forward()的张量也位于 cuda:0。输入的张量没什么好推敲的,于是考虑到模型具有多层结构,遂输出每层结构的设备信息,model.encoder -> model.encoder.sublayer[0] ··· ···

测试发现,model.encoder.sublayer[0] 之后的模块的设备信息均位于 cpu,原因是构造这部分模块时,由于需要多个相同的模块,使用了 list 来存放模块:

python 复制代码
# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
    return [copy.deepcopy(module) for _ in range(n)]

显然 list 不支持 GPU,需要用 PyTorch 提供的代替:

python 复制代码
def clones(module, n: int):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

ModuleList 把子模块存入列表,能像 Python 里普通的列表被索引,最重要的是能使内部的模块被正确注册,并对所有的 Module 方法可见。[Source]

成功解决!

相关环境

bash 复制代码
python                    3.11.7               he1021f5_0
pytorch                   2.1.2           py3.11_cuda12.1_cudnn8_0    
相关推荐
七夜zippoe几秒前
异步编程实战:构建高性能Python网络应用
开发语言·python·websocket·asyncio·aiohttp
tianyuanwo1 分钟前
Python虚拟环境深度解析:从virtualenv到virtualenvwrapper
开发语言·python·virtualenv
越甲八千1 分钟前
ORM 的优势
数据库·python
斯外戈的小白2 分钟前
【NLP】Transformer在pytorch 的实现+情感分析案例+生成式任务案例
pytorch·自然语言处理·transformer
是有头发的程序猿4 分钟前
Python爬虫防AI检测实战指南:从基础到高级的规避策略
人工智能·爬虫·python
grd44 分钟前
Electron for OpenHarmony 实战:Pagination 分页组件实现
python·学习
CryptoRzz5 分钟前
印度交易所 BSE 与 NSE 实时数据 API 接入指南
java·c语言·python·区块链·php·maven·symfony
山土成旧客12 分钟前
【Python学习打卡-Day35】从黑盒到“玻璃盒”:掌握PyTorch模型可视化、进度条与推理
pytorch·python·学习
@zulnger13 分钟前
python 学习笔记(循环)
笔记·python·学习
cyyt15 分钟前
深度学习周报(25.12.29~26.1.4)
人工智能·深度学习