共享内存和Pytorch中的Dataloader结合

dataloader中通常使用num_workers来指定多线程来进行数据的读取。可以使用共享内存进行加速。

代码地址:https://github.com/POSTECH-CVLab/point-transformer/blob/master/util/s3dis.py

文章目录

    • [1. 共享内存和dataloader结合](#1. 共享内存和dataloader结合)
      • [1.1 在init中把所有的data存储到共享内存中](#1.1 在init中把所有的data存储到共享内存中)
      • [1.2 在getitem从共享内存中读出data](#1.2 在getitem从共享内存中读出data)
    • [2. 怎么查询key在不在共享内存中](#2. 怎么查询key在不在共享内存中)
    • [3. 共享内存的地址是什么,怎么查看](#3. 共享内存的地址是什么,怎么查看)
    • [4. 共享内存有多大](#4. 共享内存有多大)
    • [5. 共享怎么删除](#5. 共享怎么删除)

1. 共享内存和dataloader结合

复制代码
class S3DIS(Dataset):
    def __init__(self, split='train', data_root='trainval', test_area=5, voxel_size=0.04, voxel_max=None, transform=None, shuffle_index=False, loop=1):
        super().__init__()
        self.split, self.voxel_size, self.transform, self.voxel_max, self.shuffle_index, self.loop = split, voxel_size, transform, voxel_max, shuffle_index, loop
        data_list = sorted(os.listdir(data_root))
        data_list = [item[:-4] for item in data_list if 'Area_' in item]
        if split == 'train':
            self.data_list = [item for item in data_list if not 'Area_{}'.format(test_area) in item]
        else:
            self.data_list = [item for item in data_list if 'Area_{}'.format(test_area) in item]
        for item in self.data_list:
            if not os.path.exists("/dev/shm/{}".format(item)):
                data_path = os.path.join(data_root, item + '.npy')
                data = np.load(data_path)  # xyzrgbl, N*7
                sa_create("shm://{}".format(item), data)
        self.data_idx = np.arange(len(self.data_list))
        print("Totally {} samples in {} set.".format(len(self.data_idx), split))

    def __getitem__(self, idx):
        data_idx = self.data_idx[idx % len(self.data_idx)]
        data = SA.attach("shm://{}".format(self.data_list[data_idx])).copy()
        coord, feat, label = data[:, 0:3], data[:, 3:6], data[:, 6]
        coord, feat, label = data_prepare(coord, feat, label, self.split, self.voxel_size, self.voxel_max, self.transform, self.shuffle_index)
        return coord, feat, label

    def __len__(self):
        return len(self.data_idx) * self.loop

1.1 在init中把所有的data存储到共享内存中

复制代码
for item in self.data_list:
    if not os.path.exists("/dev/shm/{}".format(item)):
        data_path = os.path.join(data_root, item + '.npy')
        data = np.load(data_path)  # xyzrgbl, N*7
        sa_create("shm://{}".format(item), data)

key就是文件名,存储在self.data_list中

1.2 在getitem从共享内存中读出data

复制代码
data = SA.attach("shm://{}".format(self.data_list[data_idx])).copy()

2. 怎么查询key在不在共享内存中

os.path.exists("/dev/shm/{}".format(item))能够查询该key在不在共享内存中。

3. 共享内存的地址是什么,怎么查看

复制代码
$ ls /dev/shm  
Area_5_hallway

通过/dev/shm地址访问,但是这部分数据存储在内存中。

4. 共享内存有多大

复制代码
$ df -h   
Filesystem      Size  Used Avail Use% Mounted on
tmpfs           7.8G   44M  7.8G   1% /dev/shm

大小是物理内存的一半

5. 共享怎么删除

复制代码
SA.delete("shm://{}".format('Area_5_hallway'))

SA.delete删除key

相关推荐
武子康1 小时前
AI研究-117 特斯拉 FSD 视觉解析:多摄像头 - 3D占用网络 - 车机渲染,盲区与低速复杂路况安全指南
人工智能·科技·计算机视觉·3d·视觉检测·特斯拉·model y
Geoking.1 小时前
PyTorch torch.unique() 基础与实战
人工智能·pytorch·python
AndrewHZ1 小时前
【图像处理基石】如何在图像中实现光晕的星芒效果?
图像处理·opencv·计算机视觉·cv·图像增强·算法入门·星芒效果
熊猫_豆豆1 小时前
神经网络的科普,功能用途,包含的数学知识
人工智能·深度学习·神经网络
俊俊谢1 小时前
【第一章】金融数据的获取——金融量化学习入门笔记
笔记·python·学习·金融·量化·akshare
你也渴望鸡哥的力量么2 小时前
基于边缘信息提取的遥感图像开放集飞机检测方法
人工智能·计算机视觉
xian_wwq2 小时前
【学习笔记】深度学习中梯度消失和爆炸问题及其解决方案研究
人工智能·深度学习·梯度
闲人编程2 小时前
现代Python开发环境搭建(VSCode + Dev Containers)
开发语言·vscode·python·容器·dev·codecapsule
XIAO·宝4 小时前
深度学习------图像分割项目
人工智能·深度学习·图像分割
nvd114 小时前
python异步编程 -- 深入理解事件循环event-loop
python