技巧-PyTorch中num_works的作用和实验测试

本专栏为深度学习的一些技巧,方法和实验测试,偏向于实际应用,后续不断更新,感兴趣童鞋可关,方便后续推送

简介

在 PyTorch 中,num_workers 是 DataLoader 中的一个参数,用于控制数据加载的并发线程数。它允许您在数据加载过程中使用多个线程,以提高数据加载的效率。

具体来说,num_workers 参数指定了 DataLoader 在加载数据时将创建的子进程数量。当 num_workers 大于 0 时,DataLoader 会自动利用多个子进程来加速数据加载。这有助于减少主进程的等待时间,并使得数据加载更加并行化。

例如,如果您有一个大型数据集需要加载,而且您的系统有多个 CPU 核心可用,您可以使用 num_workers 参数来提高数据加载的效率。假设您的系统有 4 个 CPU 核心,您可以将 num_workers 设置为 4,以使 DataLoader 在每个核心上创建一个子进程,并行加载数据.

使用方法

下面是一个示例代码,演示了如何使用 num_workers 参数来加速数据加载:

cpp 复制代码
python
import torch  
from torch.utils.data import DataLoader  
from torchvision import datasets, transforms  
  
# 定义数据预处理操作  
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])  
  
# 加载数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
  
# 创建 DataLoader,设置 num_workers 为 4  
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)  
 # 训练模型...

在上述示例中,我们使用 MNIST 数据集,设置了 num_workers 为 4,以利用系统的 4 个 CPU 核心并行加载数据。这将加速数据加载的过程,使得模型训练更加高效。

实测效果

我采用MMDetetion训练,它可以通过钩子函数统计每一iter的数据读取耗时(data_time)和总耗时(time)

当num_works设置为1时打印结果如下:
当num_works设置为4时打印结果如下:
实验效果与理论一致

相关推荐
做咩啊~几秒前
CentOS 7部署OpenLDAP+phpLDAPadmin实现统一认证
linux·运维·centos
Justinyh1 分钟前
Notion同步到CSDN + 构建Obsidian本地博客系统指南
python·csdn·图床·notion·obsidian·文档同步·piclist
大千AI助手3 分钟前
多维空间的高效导航者:KD树算法深度解析
数据结构·人工智能·算法·机器学习·大千ai助手·kd tree·kd树
Coding茶水间4 分钟前
基于深度学习的西红柿成熟度检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
roman_日积跬步-终至千里4 分钟前
【模式识别与机器学习(11)】数据预处理(第三部分):高级技术与质量保证
人工智能·机器学习·支持向量机
^乘风破浪^6 分钟前
Centos升级openssh及openssl
linux·运维·centos
HX4366 分钟前
Swift - Sendable (not just Sendable)
人工智能·ios·全栈
大白的编程笔记7 分钟前
大语言模型(Large Language Model, LLM)系统详解
人工智能·语言模型·自然语言处理
满天星83035777 分钟前
【Linux】【进程间通信】管道
linux·运维·服务器
赖small强7 分钟前
【Linux驱动开发】Linux EXT4文件系统技术深度解析与实践指南
linux·驱动开发·ext4·superblock·super block·block bitmap·inode bitmap