测试数据集

临时想测试某个神经网络模块,真实数据集有时候还得到处找。这里整理一下pytorch官方库一些自带的数据集,以便直接拿来用。

一、图像类数据

这类数据集体积小、下载快,是测试分类 / 卷积模块的首选,全部可通过torchvision.datasets自动下载。

  1. MNIST手写数据集

    这个数据集过于经典,数据也小(几十 MB),8x28 灰度图,10 分类,适合测试基础分类网络。

    python 复制代码
    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    
    # 数据预处理(适配PyTorch张量格式)
    transform = transforms.Compose([
        transforms.ToTensor(),  # 转为张量并归一化到[0,1]
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST官方均值/方差
    ])
    
    # 自动下载并加载真实数据
    train_dataset = datasets.MNIST(
        root='./data',  # 数据保存路径
        train=True,     # 训练集
        download=True,  # 自动下载(首次运行会下载,后续直接加载)
        transform=transform
    )
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 构建DataLoader(和真实训练流程一致)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    # 验证真实数据
    for x, y in train_loader:
        print(f"MNIST真实数据形状:{x.shape}(批次, 通道, 高, 宽)")
        print(f"真实标签示例:{y[:5]}")  # 打印前5个真实标签
        break
  2. Fashion-MNIST (服装分类,比 MNIST 稍复杂)

    格式和 MNIST 完全一致,只是类别为服装(T 恤、裤子等 10 类),代码仅需替换数据集名称。

    python 复制代码
    train_dataset = datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform
    )
  3. CIFAR-10/CIFAR-100(彩色图像,测试卷积网络)

    python 复制代码
    # CIFAR-10预处理(彩色图需3通道归一化)
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])
    
    # 加载CIFAR-10真实数据
    cifar_train = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=cifar_transform
    )
    cifar_loader = DataLoader(cifar_train, batch_size=32, shuffle=True)
    
    # 验证
    for x, y in cifar_loader:
        print(f"CIFAR-10真实数据形状:{x.shape}")  # (32, 3, 32, 32)
        break

二、文本类(真实文本数据集,测试 NLP 模块)

  1. NLTK 内置真实语料库 (提前下载后离线用)
    NLTK 包含大量真实文本语料(如《傲慢与偏见》《爱丽丝梦游仙境》、总统就职演说等),只需在能联网时执行一次下载,之后永久离线使用。
    联网时下载语料(仅需一次):

    python 复制代码
    import nltk
    nltk.download('gutenberg')  # 古腾堡真实书籍语料
    nltk.download('inaugural')  # 美国总统就职演说真实文本
    
    from nltk.corpus import gutenberg, inaugural
    
    # 读取《傲慢与偏见》真实文本
    pride_prejudice = gutenberg.raw('austen-persuasion.txt')
    print("《傲慢与偏见》片段:")
    print(pride_prejudice[:500])  # 打印前500字符
    
    # 读取2021年美国总统就职演说真实文本
    biden_speech = inaugural.raw('2021-Biden.txt')
    print("\n拜登就职演说片段:")
    print(biden_speech[:500])
  2. 实用补充:spaCy (带真实文本的预训练模型)
    下载模型(仅需一次)

    bash 复制代码
    # 终端执行(推荐小模型,约100M,包含真实文本标注数据)
    pip install spacy
    python -m spacy download en_core_web_sm  # 英文小模型
    # python -m spacy download zh_core_web_sm  # 中文小模

    离线使用真实文本分析(模型下载后无需外网)

    python 复制代码
    import spacy
    
    # 加载离线模型(已下载的模型)
    nlp = spacy.load("en_core_web_sm")  # 英文
    # nlp = spacy.load("zh_core_web_sm")  # 中文
    
    # 1. 模型内置的真实文本示例分析
    real_text = "Apple announced the launch of the new iPhone in California on September 12, 2023."
    doc = nlp(real_text)
    
    # 提取真实文本中的实体(公司、地点、日期等)
    print("真实文本实体识别:")
    for ent in doc.ents:
        print(f"文本片段:{ent.text} | 类型:{ent.label_}")
    
    # 2. 结合NLTK真实语料做深度分析
    from nltk.corpus import gutenberg
    # 读取《白鲸记》真实文本并做句法分析
    moby_dick = gutenberg.raw('melville-moby_dick.txt')[:1000]  # 取前1000字符
    doc_moby = nlp(moby_dick)
    print("\n《白鲸记》真实文本句法分析(名词短语):")
    for chunk in doc_moby.noun_chunks:
        print(chunk.text)

    大型的文本预训练数据,还是要在Hugging Face上找。

三.时序类数据

  1. statsmodels (计量统计库,自带经典时序数据)

    python 复制代码
    import statsmodels.api as sm
    import pandas as pd
    
    # 1. 航空乘客数据集(1949-1960年月度航班乘客数,最经典的时序预测数据集)
    # 自动下载并加载
    air_passengers = sm.datasets.get_rdataset("AirPassengers", "datasets")
    ap_df = air_passengers.data
    # 添加时间索引(匹配数据的时间范围)
    ap_df["date"] = pd.date_range(start="1949-01-01", periods=len(ap_df), freq="M")
    ap_df = ap_df.set_index("date")
    print("航空乘客数据集(真实时序数据):")
    print(ap_df.head())
    print(f"数据时间范围:{ap_df.index.min()} 到 {ap_df.index.max()}")
    
    # 2. CO2浓度数据集(1958-2001年月度大气CO2浓度观测值)
    # 自动下载并加载
    co2_data = sm.datasets.co2.load_pandas()
    co2_df = co2_data.data
    print("\nCO2浓度数据集(真实时序数据):")
    print(co2_df.head())
    print(f"数据时间范围:{co2_df.index.min()} 到 {co2_df.index.max()}")
    
    # 3. 英国汽车产量数据集(月度汽车产量)
    car_production = sm.datasets.get_rdataset("UKDriverDeaths", "datasets")
    car_df = car_production.data
    car_df["date"] = pd.date_range(start="1969-01-01", periods=len(car_df), freq="M")
    car_df = car_df.set_index("date")
    print("\n英国汽车产量数据集(真实时序数据):")
    print(car_df.head())
  2. sktime (专业时序机器学习库,内置数据集)

    python 复制代码
    from sktime.datasets import load_airline, load_lynx, load_UCR_UEA_dataset
    
    # 1. 加载航空乘客数据集(和statsmodels一致,专为时序分析优化)
    y = load_airline()
    print("sktime航空乘客时序数据:")
    print(y.head())
    
    # 2. 加载加拿大山猫捕获量数据集(年度时序数据)
    lynx = load_lynx()
    print("\n加拿大山猫捕获量时序数据:")
    print(lynx.head())
    
    # 3. 加载UCR/UEA时序数据集(国际经典时序数据集,自动下载)
    # 示例:下载冰箱温度时序数据集(真实传感器数据)
    X, y = load_UCR_UEA_dataset("FreezerRegularTrain")
    print("\n冰箱温度时序数据集(真实传感器数据):")
    print(X.head())

这里只是简单记录一些可用的测试数据集。仅供参考。

相关推荐
啃火龙果的兔子2 小时前
Pyglet开发游戏流程详解
python·游戏·pygame
古城小栈2 小时前
PyO3 库全介绍
python·rust
技术工小李2 小时前
2026马年年会“接福袋”游戏
python
0思必得02 小时前
[Web自动化] Requests模块请求参数
运维·前端·python·自动化·html
计算机毕设指导62 小时前
基于微信小程序的个性化漫画阅读推荐系统【源码文末联系】
java·python·微信小程序·小程序·tomcat·maven·intellij-idea
百锦再3 小时前
开发抖音小程序组件大全概述
人工智能·python·ai·小程序·aigc·notepad++·自然语言
沃斯堡&蓝鸟3 小时前
DAY34 文件的规范拆分和写法
开发语言·python
ss2733 小时前
final关键字如何创造线程安全的对象
开发语言·python
大得3693 小时前
gpt-oss:20b大模型知识库,ai大模型
人工智能·python·gpt