Keras内置数据集

目录

1、MNIST数字分类数据集

2、CIFAR10小图像分类数据集

3、CIFAR100小图像分类数据集

4、IMDB电影评论情感分类数据集

参数说明

imdb_word_index.json

示例

5、路透社新闻专线分类数据集

reuters_word_index.json

[6、Fashion MNIST数据集](#6、Fashion MNIST数据集)

7、加州房价回归数据集

参数说明


1、MNIST数字分类数据集

包含60000个10位数的28x28灰度图像的数据集,以及10000个图像的测试集

1、加载本地mnist.npz格式数据

keras.datasets.mnist.load_data(path="mnist.npz")

2、 使用keras.datasets.mnist.load_data()函数加载MNIST数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

2、CIFAR10小图像分类数据集

这是一个由50000张32x32彩色训练图像和10000张测试图像组成的数据集,标记为10个类别。

标签 类别
0 airplane
1 automobile
2 bird
3 cat
4 deer
5 dog
6 frog
7 horse
8 ship
9 truck

使用 keras.datasets.cifar10.load_data()加载数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

3、CIFAR100小图像分类数据集

该数据集与 CIFAR-10 类似,不同之处在于它有 100 个类,每个类包含 600 张图像。每类有 500 张训练图像和 100 张测试图像。CIFAR-100 中的 100 个类分为 20 个超类。每个图像都带有一个"精细"标签(它所属的类)和一个"粗略"标签(它所属的超类)

|--------------------------------|-------------------------------------------------------|
| 超类 | 类别 |
| aquatic mammals | beaver, dolphin, otter, seal, whale |
| fish | aquarium fish, flatfish, ray, shark, trout |
| flowers | orchids, poppies, roses, sunflowers, tulips |
| food containers | bottles, bowls, cans, cups, plates |
| fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers |
| household electrical devices | clock, computer keyboard, lamp, telephone, television |
| household furniture | bed, chair, couch, table, wardrobe |
| insects | bee, beetle, butterfly, caterpillar, cockroach |
| large carnivores | bear, leopard, lion, tiger, wolf |
| large man-made outdoor things | bridge, castle, house, road, skyscraper |
| large natural outdoor scenes | cloud, forest, mountain, plain, sea |
| large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
| medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
| non-insect invertebrates | crab, lobster, snail, spider, worm |
| people | baby, boy, girl, man, woman |
| reptiles | crocodile, dinosaur, lizard, snake, turtle |
| small mammals | hamster, mouse, rabbit, shrew, squirrel |
| trees | maple, oak, palm, pine, willow |
| vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
| vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |

使用 keras.datasets.cifar100.load_data()加载数据集

python 复制代码
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

#检查训练集和测试集的形状
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

4、IMDB电影评论情感分类数据集

这是来自IMDB的25000条电影评论的数据集,按情绪(积极/消极)进行标记。评论已经过预处理,每个评论都被编码为单词索引(整数)列表。

python 复制代码
keras.datasets.imdb.load_data(
    path="imdb.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)
参数说明
  • **path:**数据存储的位置。
  • **num_words:**integer或None。单词根据它们出现的频率(在训练集中)进行排名,并且只保留最频繁的num_Words单词。任何不太频繁的单词都将在序列数据中显示为oov_char值。如果"无",则保留所有单词。默认为"无"。
  • **skip_top:**跳过前N个最频繁出现的单词(可能没有信息)。这些单词将在数据集中显示为oov_char值。当为0时,不跳过任何单词。默认值为0。
  • **maxlen:**int或None。最大序列长度。任何较长的序列都将被截断。无,意味着没有截断。默认为"无"。
  • **seed:**int,用于可再现数据混洗的种子。
  • **start_char:**int。序列的开头将用这个字符标记。0通常是填充字符。默认值为1。
  • **oov_char:**int,词汇表外的字符。由于num_Words或skip_top限制而被剪切掉的单词将被替换为此字符。
  • **index_from:**int,使用此索引或更高的索引实际单词。
imdb_word_index.json

单词索引词典。键是字符串,值是它们的索引

使用keras.datasets.imdb.get_word_index函数加载imdb_word_index.json

python 复制代码
keras.datasets.imdb.get_word_index(path="imdb_word_index.json")
示例
python 复制代码
# 导入Keras库中的IMDB数据集
import keras.datasets.imdb

# 设置起始字符的索引为1
start_char = 1

# 设置未知字符的索引为2
oov_char = 2

# 设置索引从3开始
index_from = 3

# 使用默认参数加载IMDB数据集的训练数据,并只获取训练序列(不获取测试序列)
(x_train, _), _ = keras.datasets.imdb.load_data(
    start_char=start_char, oov_char=oov_char, index_from=index_from
)

# 获取单词到索引的映射文件
word_index = keras.datasets.imdb.get_word_index()

# 反转单词索引,得到一个将索引映射到单词的字典
# 并将`index_from`添加到索引中,以与`x_train`同步
inverted_word_index = dict(
    (i + index_from, word) for (word, i) in word_index.items()
)

# 更新`inverted_word_index`,包含`start_char`和`oov_char`
inverted_word_index[start_char] = "[START]"
inverted_word_index[oov_char] = "[OOV]"

# 解码数据集中的第一个序列
decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0])

5、路透社新闻专线分类数据集

这是一个由路透社11228条新闻专线组成的数据集,标签超过46个主题。

python 复制代码
keras.datasets.reuters.load_data(
    path="reuters.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

参数说明

  • **path:**指定了保存数据的npz文件路径,这里设置为"reuters.npz"。
  • **num_words:**用于指定要保留的单词数量,设置为None表示保留所有单词。
  • skip_top用于指定要跳过的最常见的单词数量,设置为0表示不跳过任何单词。
  • **maxlen:**用于指定每个输入序列的最大长度,设置为None表示使用默认值。
  • **test_split:**参数用于指定测试集所占的比例,设置为0.2表示测试集占20%。
  • **seed:**参数用于指定随机数生成器的种子,设置为113以确保结果可重复。
  • **start_charoov_char:**分别用于指定未知单词的起始字符和未知单词的输出字符,设置为1和2。
  • **index_from:**参数用于指定索引的起始值,设置为3表示从3开始编号。
reuters_word_index.json

检索一个dict,将单词映射到路透社数据集中的索引。实际的单词索引从3开始,保留了3个索引:0(填充)、1(开始)、2(oov)。例如,"the"的单词索引为1,但在实际的训练数据中,"the"的索引将为1+3=4。反之亦然,要使用此映射将训练数据中的单词索引翻译回单词,索引需要减去3。

使用keras.datasets.reuters.get_word_index加载imdb_word_index.json

python 复制代码
keras.datasets.reuters.get_word_index(path="reuters_word_index.json")

6、Fashion MNIST数据集

这是一个由10个时尚类别的60000张28x28灰度图像组成的数据集,以及10000张图像的测试集

标签 类别
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot

使用fashion_mnist.load_data()加载

python 复制代码
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

#检查测试集和训练集
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

7、加州房价回归数据集

这是一个连续回归数据集,包含20640个样本,每个样本有8个特征。目标变量是一个标量:加利福尼亚地区的房屋中值,单位为美元。

使用keras.datasets.california_housing.load_data加载

python 复制代码
keras.datasets.california_housing.load_data(
    version="large", path="california_housing.npz", test_split=0.2, seed=113
)
参数说明
  • version:"小"或"大"。小版本包含600个样本,大版本包含20640个样本。
  • **path:**本地数据集的路径。
  • **testsplit:**作为测试集保留的数据的一部分。
  • **seed:**在计算测试分割之前对数据进行混洗的随机种子。
相关推荐
千天夜23 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
大数据面试宝典24 分钟前
用AI来写SQL:让ChatGPT成为你的数据库助手
数据库·人工智能·chatgpt
封步宇AIGC29 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_5236742131 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
HappyAcmen41 分钟前
IDEA部署AI代写插件
java·人工智能·intellij-idea
噜噜噜噜鲁先森1 小时前
看懂本文,入门神经网络Neural Network
人工智能
InheritGuo2 小时前
It’s All About Your Sketch: Democratising Sketch Control in Diffusion Models
人工智能·计算机视觉·sketch
weixin_307779132 小时前
证明存在常数c, C > 0,使得在一系列特定条件下,某个特定投资时刻出现的概率与天数的对数成反比
人工智能·算法·机器学习
封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-3.4.1.6.A股宏观经济数据
人工智能·python·机器学习·数据挖掘
Jack黄从零学c++2 小时前
opencv(c++)图像的灰度转换
c++·人工智能·opencv