大模型系列3--pytorch dataloader的原理

pytorch dataloader运行原理

  • [1. 背景](#1. 背景)
  • [2. 环境搭建](#2. 环境搭建)
    • [2.1. 安装WSL & vscode](#2.1. 安装WSL & vscode)
    • [2.2. 安装conda & pytorch_gpu环境 & pytorch 2.11](#2.2. 安装conda & pytorch_gpu环境 & pytorch 2.11)
    • [2.3 命令行验证python环境](#2.3 命令行验证python环境)
    • [2.4. vscode启用pytorch_cpu虚拟环境](#2.4. vscode启用pytorch_cpu虚拟环境)
  • [3. 调试工具](#3. 调试工具)
    • [3.1. vscode 断点调试](#3.1. vscode 断点调试)
    • [3.2. py-spy代码栈探测](#3.2. py-spy代码栈探测)
    • [3.3. gdb attach](#3.3. gdb attach)
    • [3.4. 查看进程访问的系统调用](#3.4. 查看进程访问的系统调用)
  • [4. DataLoader代码分析](#4. DataLoader代码分析)

1. 背景

工作中遇到需要跟踪dataloader访问IO卡住的问题,有一个类似于IO read的堆栈的hang,需要判断是否是真的IO hang住,于是乎趁着周末仔细阅读一下dataloader的代码,了解下torch dataloader的内部原理。作为一个初学者,这个文章会比较杂一些,请各位读者谅解。

为了和linux相配套,本文拟采用WSL环境来搭建conda + torch的开发环境。

2. 环境搭建

2.1. 安装WSL & vscode

参考系列中的一篇文章:环境部署

2.2. 安装conda & pytorch_gpu环境 & pytorch 2.11

下载conda

在WSL中安装conda,通过以下命令下载sh脚本
wget https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh

有另外一个镜像站,下载很快:https://mirrors.sustech.edu.cn/anaconda/archive/

对下载的内容进行SHA-256校验

  • Get-FileHash filename -Algorithm SHA256
  • c536ddb7b4ba738bddbd4e581b29308cb332fa12ae3fa2cd66814bd735dff231

安装conda

  • bash Anaconda3-2024.02-1-Linux-x86_64.sh
  • 按照提示,填yes,设置安装目录,更新shell,随后重启WSL的terminal界面。可以看到如下图,zshrc环境已经被更新了,重启shell会默认进入到(base)环境。

创建python虚拟环境

创建python虚拟环境pytorch_cpu,并激活它

安装2.1版本pytorch

  • conda install pytorch==2.1 cpuonly -c pytorch

安装pandas

  • conda install pandas

2.3 命令行验证python环境

准备构造一段数据:使用ChatGPT写一段代码,要求生成1-100个文件,采用pickle + gzip的模式,命名为1-100.pkl.gz,每个文件中是10个随机的kv对,k和v都是随机数字转换成的字符串。构造的代码如下:

python 复制代码
import os
import pickle
import gzip
import random
import string

# 解释代码 | 注释代码 | 生成单测 |
def generate_random_dict():
    random_dict = {}
    for _ in range(10):
        key = ''.join(random.choice(string.digits) for _ in range(5))
        value = ''.join(random.choice(string.digits) for _ in range(5))
        random_dict[key] = value
    return random_dict

def generate_files():
    file_names = [f'{i}.pkl.gz' for i in range(1, 101)]
    for file_name in file_names:
        with gzip.open(file_name, 'wb') as f:
            random_dict = generate_random_dict()
            pickle.dump(random_dict, f)
    for file_name in file_names:
        print(file_name)

if __name__ == "__main__":
    os.chdir("c:\\workspace\\llm\\hello_project_1\\dataset\\data\\filelist")
    generate_files()

运行上述代码:

  • python demo_gen_pkl_gz.py

输出结果如下:

2.4. vscode启用pytorch_cpu虚拟环境

vscode中启动WSL,然后打开一个python文件,点击vscode屏幕右下角的python环境,默认是/usr/bin/python,会自动提示多个python环境,选择pytorch_cpu环境,如下图所示:

打开上述python文件demo_gen_pkl_gz.py,点击右上角的三角符号,选择Run Python File,即可run此python文件。

3. 调试工具

为了更方便地进行问题跟踪,我们需要学习几种调试工具

3.1. vscode 断点调试

  • 在相应的代码增加断点
  • 点击右上角的Python Debugger: Debugger using launch.json 按钮
  • 它会自动在断点处停下来
    • 查看local和global的变量,主动添加新的监视
    • 查看线程堆栈
    • 单步运行或者继续或者停止均可

      如果将断点放在内部库的代码,例如在gzip.open实现内部打断点,会发现断点不生效。需要在lanuch.json中增加一行配置:"justMyCode": false,就可以使得断点生效了。

3.2. py-spy代码栈探测

  • pip3 install py-spy
  • py-spy dump --pid ${pid}
  • 支持的一些有用的参数

3.3. gdb attach

  • conda install gdb
  • apt-get install python3-dbg
  • gdb -p ${pid} 加载进程,即可使用各种命令进行调试

3.4. 查看进程访问的系统调用

  • strace -f -p ${pid} -s 1024

4. DataLoader代码分析

4.1. DataLoader代码示例

下面是一个采用多进程来读取数据的代码,它的代码逻辑很简单。首先创建一个DataLoader结构,它传入的最关键的参数为dataset,用以从dataset数据集中读取数据;最后通过for data in dataloader:将数据从dataloader中打印出来。可以通过调整num_workers来设置是否启动后台进程进行load数据

python 复制代码
import gzip
import os
import pickle
import random
import time

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


def load_gzip_pickle(pkl_fpath):
    with gzip.open(pkl_fpath, "rb") as f:
        data = pickle.load(f)
    return data


class MapDataSet(Dataset):
    def __init__(self, index_list_fpath):
        self.index_list = pd.read_csv(index_list_fpath)

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        pkl_fpath = self.index_list.iloc[idx].tolist()[0]
        pkl_fpath = f"filelist/{pkl_fpath}"
        print("try to simulate slow io wait...")
        #time.sleep(10)
        data = load_gzip_pickle(pkl_fpath)
        # post processing
        print("try to simulate slow data processing...")
        #time.sleep(10)
        print(pkl_fpath, ": idx:", idx, ": data:", data.keys(), ": len", len(data), ": pid:", os.getpid())
        return data


def get_data_loader(index_list_fpath, batch_size=1, num_workers=16):
    dataset = MapDataSet(index_list_fpath=index_list_fpath)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=lambda batch: batch[0])


def test_dataloader(index_list_fpath):
    batch_size = 1
    num_workers = 0
    dataloader = get_data_loader(index_list_fpath=index_list_fpath, batch_size=batch_size, num_workers=num_workers)
    for data in dataloader:
        print(data.keys(), ": len", len(data), ": pid:", os.getpid())


if __name__ == "__main__":
    os.chdir("c:\\workspace\\llm\\hello_project_1\\dataset\\data")
    index_list_fpath = "filelist.csv"
    test_dataloader(index_list_fpath)

输出结果

try to simulate slow io wait...
try to simulate slow data processing...
filelist/1.pkl.gz : idx: 0 : data: dict_keys(['86099', '83840', '15119', '03197', '57912', '42663', '32969', '49818', '47455', '53997']) : len 10 : pid: 9724
dict_keys(['86099', '83840', '15119', '03197', '57912', '42663', '32969', '49818', '47455', '53997']) : len 10 : pid: 9724
try to simulate slow io wait...
try to simulate slow data processing...
filelist/2.pkl.gz : idx: 1 : data: dict_keys(['91534', '12121', '94084', '12699', '03382', '10877', '21595', '20303', '41507', '47594']) : len 10 : pid: 9724
dict_keys(['91534', '12121', '94084', '12699', '03382', '10877', '21595', '20303', '41507', '47594']) : len 10 : pid: 9724
try to simulate slow io wait...
try to simulate slow data processing...
filelist/3.pkl.gz : idx: 2 : data: dict_keys(['85974', '89204', '39248', '46884', '09986', '30033', '97369', '18704', '24227', '15649']) : len 10 : pid: 9724
dict_keys(['85974', '89204', '39248', '46884', '09986', '30033', '97369', '18704', '24227', '15649']) : len 10 : pid: 9724
try to simulate slow io wait...
.......

4.2.

相关推荐
赛逸展张胜1 分钟前
CES Asia是一个关于什么的展会?
大数据·人工智能·科技
从以前9 分钟前
【算法题解】Bindian 山丘信号问题(E. Bindian Signaling)
开发语言·python·算法
海绵波波10721 分钟前
flask后端开发(9):ORM模型外键+迁移ORM模型
后端·python·flask
余生H26 分钟前
前端Python应用指南(二)深入Flask:理解Flask的应用结构与模块化设计
前端·后端·python·flask·全栈
Coovally AI模型快速验证30 分钟前
YOLO11全解析:从原理到实战,全流程体验下一代目标检测
人工智能·yolo·目标检测·机器学习·计算机视觉·目标跟踪·yolo11
CriticalThinking1 小时前
Pycharm不正常识别包含中文路径的解释器
ide·python·pycharm
湫ccc1 小时前
《Opencv》基础操作详解(2)
人工智能·opencv·计算机视觉
羑悻的小杀马特1 小时前
【AIGC篇】畅谈游戏开发设计中AIGC所发挥的不可或缺的作用
c++·人工智能·aigc·游戏开发
火山方舟1 小时前
解密!企业级智能客服高效运营的秘密武器 | 大模型流程设计与Prompt模版
前端·人工智能·稀土