临时想测试某个神经网络模块,真实数据集有时候还得到处找。这里整理一下pytorch官方库一些自带的数据集,以便直接拿来用。
一、图像类数据
这类数据集体积小、下载快,是测试分类 / 卷积模块的首选,全部可通过torchvision.datasets自动下载。
-
MNIST手写数据集
这个数据集过于经典,数据也小(几十 MB),8x28 灰度图,10 分类,适合测试基础分类网络。
pythonimport torch from torchvision import datasets, transforms from torch.utils.data import DataLoader # 数据预处理(适配PyTorch张量格式) transform = transforms.Compose([ transforms.ToTensor(), # 转为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST官方均值/方差 ]) # 自动下载并加载真实数据 train_dataset = datasets.MNIST( root='./data', # 数据保存路径 train=True, # 训练集 download=True, # 自动下载(首次运行会下载,后续直接加载) transform=transform ) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 构建DataLoader(和真实训练流程一致) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 验证真实数据 for x, y in train_loader: print(f"MNIST真实数据形状:{x.shape}(批次, 通道, 高, 宽)") print(f"真实标签示例:{y[:5]}") # 打印前5个真实标签 break -
Fashion-MNIST (服装分类,比 MNIST 稍复杂)
格式和 MNIST 完全一致,只是类别为服装(T 恤、裤子等 10 类),代码仅需替换数据集名称。
pythontrain_dataset = datasets.FashionMNIST( root='./data', train=True, download=True, transform=transform ) -
CIFAR-10/CIFAR-100(彩色图像,测试卷积网络)
python# CIFAR-10预处理(彩色图需3通道归一化) cifar_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) ]) # 加载CIFAR-10真实数据 cifar_train = datasets.CIFAR10( root='./data', train=True, download=True, transform=cifar_transform ) cifar_loader = DataLoader(cifar_train, batch_size=32, shuffle=True) # 验证 for x, y in cifar_loader: print(f"CIFAR-10真实数据形状:{x.shape}") # (32, 3, 32, 32) break
二、文本类(真实文本数据集,测试 NLP 模块)
-
NLTK 内置真实语料库 (提前下载后离线用)
NLTK 包含大量真实文本语料(如《傲慢与偏见》《爱丽丝梦游仙境》、总统就职演说等),只需在能联网时执行一次下载,之后永久离线使用。
联网时下载语料(仅需一次):pythonimport nltk nltk.download('gutenberg') # 古腾堡真实书籍语料 nltk.download('inaugural') # 美国总统就职演说真实文本 from nltk.corpus import gutenberg, inaugural # 读取《傲慢与偏见》真实文本 pride_prejudice = gutenberg.raw('austen-persuasion.txt') print("《傲慢与偏见》片段:") print(pride_prejudice[:500]) # 打印前500字符 # 读取2021年美国总统就职演说真实文本 biden_speech = inaugural.raw('2021-Biden.txt') print("\n拜登就职演说片段:") print(biden_speech[:500]) -
实用补充:spaCy (带真实文本的预训练模型)
下载模型(仅需一次)bash# 终端执行(推荐小模型,约100M,包含真实文本标注数据) pip install spacy python -m spacy download en_core_web_sm # 英文小模型 # python -m spacy download zh_core_web_sm # 中文小模离线使用真实文本分析(模型下载后无需外网)
pythonimport spacy # 加载离线模型(已下载的模型) nlp = spacy.load("en_core_web_sm") # 英文 # nlp = spacy.load("zh_core_web_sm") # 中文 # 1. 模型内置的真实文本示例分析 real_text = "Apple announced the launch of the new iPhone in California on September 12, 2023." doc = nlp(real_text) # 提取真实文本中的实体(公司、地点、日期等) print("真实文本实体识别:") for ent in doc.ents: print(f"文本片段:{ent.text} | 类型:{ent.label_}") # 2. 结合NLTK真实语料做深度分析 from nltk.corpus import gutenberg # 读取《白鲸记》真实文本并做句法分析 moby_dick = gutenberg.raw('melville-moby_dick.txt')[:1000] # 取前1000字符 doc_moby = nlp(moby_dick) print("\n《白鲸记》真实文本句法分析(名词短语):") for chunk in doc_moby.noun_chunks: print(chunk.text)大型的文本预训练数据,还是要在Hugging Face上找。
三.时序类数据
-
statsmodels (计量统计库,自带经典时序数据)
pythonimport statsmodels.api as sm import pandas as pd # 1. 航空乘客数据集(1949-1960年月度航班乘客数,最经典的时序预测数据集) # 自动下载并加载 air_passengers = sm.datasets.get_rdataset("AirPassengers", "datasets") ap_df = air_passengers.data # 添加时间索引(匹配数据的时间范围) ap_df["date"] = pd.date_range(start="1949-01-01", periods=len(ap_df), freq="M") ap_df = ap_df.set_index("date") print("航空乘客数据集(真实时序数据):") print(ap_df.head()) print(f"数据时间范围:{ap_df.index.min()} 到 {ap_df.index.max()}") # 2. CO2浓度数据集(1958-2001年月度大气CO2浓度观测值) # 自动下载并加载 co2_data = sm.datasets.co2.load_pandas() co2_df = co2_data.data print("\nCO2浓度数据集(真实时序数据):") print(co2_df.head()) print(f"数据时间范围:{co2_df.index.min()} 到 {co2_df.index.max()}") # 3. 英国汽车产量数据集(月度汽车产量) car_production = sm.datasets.get_rdataset("UKDriverDeaths", "datasets") car_df = car_production.data car_df["date"] = pd.date_range(start="1969-01-01", periods=len(car_df), freq="M") car_df = car_df.set_index("date") print("\n英国汽车产量数据集(真实时序数据):") print(car_df.head()) -
sktime (专业时序机器学习库,内置数据集)
pythonfrom sktime.datasets import load_airline, load_lynx, load_UCR_UEA_dataset # 1. 加载航空乘客数据集(和statsmodels一致,专为时序分析优化) y = load_airline() print("sktime航空乘客时序数据:") print(y.head()) # 2. 加载加拿大山猫捕获量数据集(年度时序数据) lynx = load_lynx() print("\n加拿大山猫捕获量时序数据:") print(lynx.head()) # 3. 加载UCR/UEA时序数据集(国际经典时序数据集,自动下载) # 示例:下载冰箱温度时序数据集(真实传感器数据) X, y = load_UCR_UEA_dataset("FreezerRegularTrain") print("\n冰箱温度时序数据集(真实传感器数据):") print(X.head())
这里只是简单记录一些可用的测试数据集。仅供参考。