【Bug1】RuntimeError: DataLoader worker (pid(s) 15904) exited unexpectedly
Environment

```
Windows 11
Python 3.10
torch 2.0.1
numpy 1.25.0
```
Problem details

This error occurred while using PyTorch's DataLoader:

```
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...
...
RuntimeError: DataLoader worker (pid(s) ) exited unexpectedly
```

In other words, this is a runtime error caused by attempting to start a new process before the main process has finished its bootstrapping phase.
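The cause can be reproduced with the standard library alone. On Windows, multiprocessing starts child processes with the "spawn" method, which re-imports the main module in each child, so process startup must happen under the `__main__` guard. A minimal sketch (the `worker` and `launch` names here are illustrative, not part of PyTorch):

```python
import multiprocessing as mp

def worker(q):
    # Runs in the child process
    q.put("child ran")

def launch():
    # Force the "spawn" start method (the Windows default). The child
    # re-imports the main module, so this must only be called from code
    # that is guarded by if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=worker, args=(q,))
    p.start()
    msg = q.get()
    p.join()
    return msg

if __name__ == "__main__":
    print(launch())
```

Calling `launch()` at module level (outside the guard) would make every spawned child re-run it on import, which is exactly the bootstrapping error above.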
Code that triggers the error

```python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Suppose we have some simple data
data = np.array([1, 2, 3, 4, 5, 6, 7])        # np arrays and tensors both work
targets = torch.tensor([1, 1, 1, 1, 0, 0, 0])  # labels

# Define a custom dataset
class SimpleDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y

    def __len__(self):
        return len(self.data)

# Instantiate the dataset
dataset = SimpleDataset(data, targets)

# Create the DataLoader. With multi-process loading (num_workers >= 1), the
# startup code must be placed under if __name__ == "__main__": or this error is raised
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Iterate over the data with the DataLoader
for i, (batch_data, batch_label) in enumerate(dataloader):
    print(f"Batch {i}: batch_data: {batch_data}, batch_label: {batch_label}")
```
Solutions

【Method 1】(not recommended)

The error comes from misusing multi-process data loading, so the simplest workaround is to load data in a single process: leave num_workers unset or set it to 0.

```python
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
```
【Method 2】

Move the code that uses the DataLoader under if __name__ == "__main__":. The modified code is as follows:

```python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Define a custom dataset
class SimpleDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y

    def __len__(self):
        return len(self.data)

def train():
    # Suppose we have some simple data
    data = np.array([1, 2, 3, 4, 5, 6, 7])        # np arrays and tensors both work
    targets = torch.tensor([1, 1, 1, 1, 0, 0, 0])  # labels

    # Instantiate the dataset
    dataset = SimpleDataset(data, targets)

    # Create the DataLoader. With multi-process loading, the startup code must
    # be placed under if __name__ == "__main__": or this error is raised
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # Iterate over the data with the DataLoader
    for i, (batch_data, batch_label) in enumerate(dataloader):
        print(f"Batch {i}: batch_data: {batch_data}, batch_label: {batch_label}")

if __name__ == "__main__":
    train()
```
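As the traceback hints, multiprocessing.freeze_support() can additionally be placed at the top of the guard. It only matters when the script is frozen into a Windows executable (e.g. by a packaging tool) and is a no-op otherwise. A minimal sketch with a stub train() standing in for the real training entry point:

```python
import multiprocessing

def train():
    # Stand-in for the real training entry point
    return "done"

if __name__ == "__main__":
    # No-op unless running as a frozen Windows executable, where it
    # lets spawned worker processes bootstrap correctly
    multiprocessing.freeze_support()
    print(train())
```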