Keras内置数据集

目录

1、MNIST数字分类数据集

2、CIFAR10小图像分类数据集

3、CIFAR100小图像分类数据集

4、IMDB电影评论情感分类数据集

参数说明

imdb_word_index.json

示例

5、路透社新闻专线分类数据集

reuters_word_index.json

[6、Fashion MNIST数据集](#6、Fashion MNIST数据集)

7、加州房价回归数据集

参数说明


1、MNIST数字分类数据集

包含60000个10位数的28x28灰度图像的数据集,以及10000个图像的测试集

1、加载本地mnist.npz格式数据

复制代码
keras.datasets.mnist.load_data(path="mnist.npz")

2、 使用keras.datasets.mnist.load_data()函数加载MNIST数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

2、CIFAR10小图像分类数据集

这是一个由50000张32x32彩色训练图像和10000张测试图像组成的数据集,标记为10个类别。

标签 类别
0 airplane
1 automobile
2 bird
3 cat
4 deer
5 dog
6 frog
7 horse
8 ship
9 truck

使用 keras.datasets.cifar10.load_data()加载数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

3、CIFAR100小图像分类数据集

该数据集与 CIFAR-10 类似,不同之处在于它有 100 个类,每个类包含 600 张图像。每类有 500 张训练图像和 100 张测试图像。CIFAR-100 中的 100 个类分为 20 个超类。每个图像都带有一个"精细"标签(它所属的类)和一个"粗略"标签(它所属的超类)

|--------------------------------|-------------------------------------------------------|
| 超类 | 类别 |
| aquatic mammals | beaver, dolphin, otter, seal, whale |
| fish | aquarium fish, flatfish, ray, shark, trout |
| flowers | orchids, poppies, roses, sunflowers, tulips |
| food containers | bottles, bowls, cans, cups, plates |
| fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers |
| household electrical devices | clock, computer keyboard, lamp, telephone, television |
| household furniture | bed, chair, couch, table, wardrobe |
| insects | bee, beetle, butterfly, caterpillar, cockroach |
| large carnivores | bear, leopard, lion, tiger, wolf |
| large man-made outdoor things | bridge, castle, house, road, skyscraper |
| large natural outdoor scenes | cloud, forest, mountain, plain, sea |
| large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
| medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
| non-insect invertebrates | crab, lobster, snail, spider, worm |
| people | baby, boy, girl, man, woman |
| reptiles | crocodile, dinosaur, lizard, snake, turtle |
| small mammals | hamster, mouse, rabbit, shrew, squirrel |
| trees | maple, oak, palm, pine, willow |
| vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
| vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |

使用 keras.datasets.cifar100.load_data()加载数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

4、IMDB电影评论情感分类数据集

这是来自IMDB的25000条电影评论的数据集,按情绪(积极/消极)进行标记。评论已经过预处理,每个评论都被编码为单词索引(整数)列表。

python 复制代码
keras.datasets.imdb.load_data(
    path="imdb.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)
参数说明
  • **path:**数据存储的位置。
  • **num_words:**integer或None。单词根据它们出现的频率(在训练集中)进行排名,并且只保留最频繁的num_Words单词。任何不太频繁的单词都将在序列数据中显示为oov_char值。如果"无",则保留所有单词。默认为"无"。
  • **skip_top:**跳过前N个最频繁出现的单词(可能没有信息)。这些单词将在数据集中显示为oov_char值。当为0时,不跳过任何单词。默认值为0。
  • **maxlen:**int或None。最大序列长度。任何较长的序列都将被截断。无,意味着没有截断。默认为"无"。
  • **seed:**int,用于可再现数据混洗的种子。
  • **start_char:**int。序列的开头将用这个字符标记。0通常是填充字符。默认值为1。
  • **oov_char:**int,词汇表外的字符。由于num_Words或skip_top限制而被剪切掉的单词将被替换为此字符。
  • **index_from:**int,使用此索引或更高的索引实际单词。
imdb_word_index.json

单词索引词典。键是字符串,值是它们的索引

使用keras.datasets.imdb.get_word_index函数加载imdb_word_index.json

python 复制代码
keras.datasets.imdb.get_word_index(path="imdb_word_index.json")
示例
python 复制代码
# 导入Keras库中的IMDB数据集
import keras.datasets.imdb

# 设置起始字符的索引为1
start_char = 1

# 设置未知字符的索引为2
oov_char = 2

# 设置索引从3开始
index_from = 3

# 使用默认参数加载IMDB数据集的训练数据,并只获取训练序列(不获取测试序列)
(x_train, _), _ = keras.datasets.imdb.load_data(
    start_char=start_char, oov_char=oov_char, index_from=index_from
)

# 获取单词到索引的映射文件
word_index = keras.datasets.imdb.get_word_index()

# 反转单词索引,得到一个将索引映射到单词的字典
# 并将`index_from`添加到索引中,以与`x_train`同步
inverted_word_index = dict(
    (i + index_from, word) for (word, i) in word_index.items()
)

# 更新`inverted_word_index`,包含`start_char`和`oov_char`
inverted_word_index[start_char] = "[START]"
inverted_word_index[oov_char] = "[OOV]"

# 解码数据集中的第一个序列
decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0])

5、路透社新闻专线分类数据集

这是一个由路透社11228条新闻专线组成的数据集,标签超过46个主题。

python 复制代码
keras.datasets.reuters.load_data(
    path="reuters.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

参数说明

  • **path:**指定了保存数据的npz文件路径,这里设置为"reuters.npz"。
  • **num_words:**用于指定要保留的单词数量,设置为None表示保留所有单词。
  • skip_top用于指定要跳过的最常见的单词数量,设置为0表示不跳过任何单词。
  • **maxlen:**用于指定每个输入序列的最大长度,设置为None表示使用默认值。
  • **test_split:**参数用于指定测试集所占的比例,设置为0.2表示测试集占20%。
  • **seed:**参数用于指定随机数生成器的种子,设置为113以确保结果可重复。
  • **start_charoov_char:**分别用于指定未知单词的起始字符和未知单词的输出字符,设置为1和2。
  • **index_from:**参数用于指定索引的起始值,设置为3表示从3开始编号。
reuters_word_index.json

检索一个dict,将单词映射到路透社数据集中的索引。实际的单词索引从3开始,保留了3个索引:0(填充)、1(开始)、2(oov)。例如,"the"的单词索引为1,但在实际的训练数据中,"the"的索引将为1+3=4。反之亦然,要使用此映射将训练数据中的单词索引翻译回单词,索引需要减去3。

使用keras.datasets.reuters.get_word_index加载imdb_word_index.json

python 复制代码
keras.datasets.reuters.get_word_index(path="reuters_word_index.json")

6、Fashion MNIST数据集

这是一个由10个时尚类别的60000张28x28灰度图像组成的数据集,以及10000张图像的测试集

标签 类别
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot

使用fashion_mnist.load_data()加载

python 复制代码
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

#检查测试集和训练集
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

7、加州房价回归数据集

这是一个连续回归数据集,包含20640个样本,每个样本有8个特征。目标变量是一个标量:加利福尼亚地区的房屋中值,单位为美元。

使用keras.datasets.california_housing.load_data加载

python 复制代码
keras.datasets.california_housing.load_data(
    version="large", path="california_housing.npz", test_split=0.2, seed=113
)
参数说明
  • version:"小"或"大"。小版本包含600个样本,大版本包含20640个样本。
  • **path:**本地数据集的路径。
  • **testsplit:**作为测试集保留的数据的一部分。
  • **seed:**在计算测试分割之前对数据进行混洗的随机种子。
相关推荐
新知图书3 分钟前
OpenCV界面编程
人工智能·opencv·计算机视觉
小杨4045 分钟前
python入门系列十五(asyncio)
人工智能·python·pycharm
hanniuniu137 分钟前
技术驱动革新,强力巨彩LED软模组助力创意显示
人工智能
xcLeigh7 分钟前
计算机视觉图像处理基础系列:滤波、边缘检测与形态学操作
图像处理·人工智能·计算机视觉·ai
程序猿阿伟13 分钟前
《打破SQL与AI框架对接壁垒,解锁融合新路径》
数据库·人工智能·sql
Helios@14 分钟前
CNN 中感受野/权值共享是什么意思?
人工智能·深度学习·计算机视觉
冰蓝蓝37 分钟前
TensorBoard
人工智能·深度学习
搞程序的心海41 分钟前
神经网络入门:生动解读机器学习的“神经元”
人工智能·神经网络·机器学习
AI浩43 分钟前
OverLoCK:一种采用“先总体把握再初步审视继而深入观察”架构的卷积神经网络(ConvNet),融合了上下文信息的动态卷积核
人工智能·神经网络·cnn
视觉AI1 小时前
研究下适合部署在jeston上的深度学习类单目标跟踪算法
深度学习·算法·目标跟踪