PyTorch下,使用list放置模块,导致计算设备不一的报错

报错

在复现 Transformer 代码的训练阶段时,发生报错:

bash 复制代码
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

解决方案

通过next(linear.parameters()).device确定 model 已经在 cuda:0 上了,同时输入 model.forward()的张量也位于 cuda:0。输入的张量没什么好推敲的,于是考虑到模型具有多层结构,遂输出每层结构的设备信息,model.encoder -> model.encoder.sublayer[0] ··· ···

测试发现,model.encoder.sublayer[0] 之后的模块的设备信息均位于 cpu,原因是构造这部分模块时,由于需要多个相同的模块,使用了 list 来存放模块:

python 复制代码
# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
    return [copy.deepcopy(module) for _ in range(n)]

显然 list 不支持 GPU,需要用 PyTorch 提供的代替:

python 复制代码
def clones(module, n: int):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

ModuleList 把子模块存入列表,能像 Python 里普通的列表被索引,最重要的是能使内部的模块被正确注册,并对所有的 Module 方法可见。[Source]

成功解决!

相关环境

bash 复制代码
python                    3.11.7               he1021f5_0
pytorch                   2.1.2           py3.11_cuda12.1_cudnn8_0    
相关推荐
有为少年1 小时前
Welford 算法 | 优雅地计算海量数据的均值与方差
人工智能·深度学习·神经网络·学习·算法·机器学习·均值算法
Ven%1 小时前
从单轮问答到连贯对话:RAG多轮对话技术详解
人工智能·python·深度学习·神经网络·算法
谈笑也风生1 小时前
经典算法题型之复数乘法(二)
开发语言·python·算法
阿_旭1 小时前
【PyTorch】20个核心概念详解:从基础到实战的深度学习指南
人工智能·pytorch·深度学习
先知后行。2 小时前
python的类
开发语言·python
dyxal2 小时前
Python包导入终极指南:子文件如何成功调用父目录模块
开发语言·python
nnerddboy2 小时前
解决传统特征波段选择的不可解释性:2. SHAP和LIME
python·机器学习
电商API&Tina2 小时前
【电商API接口】关于电商数据采集相关行业
java·python·oracle·django·sqlite·json·php
yy_xzz2 小时前
002 PyTorch实战:神经网络回归任务 - 气温预测
pytorch·神经网络·回归
weixin_421585012 小时前
解释代码:val_pred = vxm_model.predict(val_input)--与tensor对比
python