一维卷积完成文本情感分类任务

前文介绍

本文主要搭建经典的一维卷积神经网路模型,用于完成 IMDB 电影评论的情感分类预测任务,并在最后对模型进行了改造升级。

数据处理

IMDB 数据中每个样本都有一个电影评论文本情感标签,如下所示展示了两个样本,01 分别表示的是负面情感正面情感

css 复制代码
b"I gave this a four purely out of its historical context. ... If you are a Stooge fan, then stick with the shorts."
0
b'First than anything, I\'m not going to praise I\xc3\xb1arritu\'s short film, ... that is not equal to American People.'
1

首先我们要将样本中的一些无用字符进行剔除,将样本数据清洗干净,然后使用 TextVectorization 来对样本中的评论数据进行向量化,也就是先进行分词,然后将分词映射成对应的整数,最后得到就是将评论转变后的整数数组。关键代码如下:

scss 复制代码
vectorize_layer = TextVectorization(standardize=custom_standardization, max_tokens=max_features, output_mode="int", output_sequence_length=sequence_length)
text_ds = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)
def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label
train_ds = raw_train_ds.map(vectorize_text).cache().prefetch(buffer_size = 10)
val_ds = raw_val_ds.map(vectorize_text).cache().prefetch(buffer_size = 10)
test_ds = raw_test_ds.map(vectorize_text).cache().prefetch(buffer_size = 10)

模型搭建

我们搭建的是一个简单但是经典的文本分类模型,由以下部分构成:

  1. inputs = tf.keras.Input(shape=(None,), dtype="int64"):创建一个可变长度的整数序列输入张量,用于接收上面经过向量化的文本数据。
  2. x = layers.Embedding(max_features, embedding_dim)(inputs):对输入的整数序列进行嵌入操作,将每个整数转换为对应的向量表示。
  3. x = layers.Dropout(0.5)(x):在嵌入层后添加了一个丢弃 50% 的Dropout层,用于在训练过程中随机丢弃一定比例的神经元,以防止过拟合。
  4. x = layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3)(x):添加了一个一维卷积层,用于提取文本特征。
  5. x = layers.GlobalMaxPool1D()(x):应用一维全局最大池化操作,用于提取每个卷积核的最大特征值,以捕获最重要的特征。
  6. x = layers.Dense(128, activation="relu")(x):添加一个包含128个神经元的全连接层,并使用ReLU激活函数。
  7. x = layers.Dropout(0.5)(x):再次添加Dropout层,以进一步减少过拟合的风险。
  8. predictions = layers.Dense(1, activation="sigmoid", name="predictions")(x):最后添加一个输出层,用于二元分类任务的预测,使用 Sigmoid 激活函数输出 0 或者 1 。
  9. model = tf.keras.Model(inputs, predictions):创建一个Keras模型,将输入和输出连接起来。

编译训练

  1. model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]):编译模型,指定损失函数、优化器以及评估指标。
  2. model.fit(train_ds, validation_data=val_ds, epochs=3):使用训练数据集进行模型训练,同时使用验证数据集进行验证,训练 3 个 epochs 。

训练日志打印:

arduino 复制代码
Epoch 1/3
157/157 [==============================] - 9s 26ms/step - loss: 0.6597 - accuracy: 0.5632 - val_loss: 0.4223 - val_accuracy: 0.8080
Epoch 2/3
157/157 [==============================] - 1s 9ms/step - loss: 0.3168 - accuracy: 0.8676 - val_loss: 0.3400 - val_accuracy: 0.8596
Epoch 3/3
157/157 [==============================] - 1s 9ms/step - loss: 0.1406 - accuracy: 0.9501 - val_loss: 0.3697 - val_accuracy: 0.8750

使用测试数据进行模型效果测试,日志打印:

arduino 复制代码
196/196 [==============================] - 4s 18ms/step - loss: 0.3887 - accuracy: 0.8665

模型升级

由于上面的模型的输入只能接收经过向量化处理的整数,而我们平时遇到的都是直接输入字符串,所以需要将上面的模型进行改造,我们在上面已经训练好的模型的前面再添加一个输入层专门用于接受字符串,然后再跟一个 vectorize_layer 层用来将字符串进行向量化,接着传入到我们之前已经训练好的模型中即可。最后使用 Model 模型将输入和输出固定就可以造出一个新的用于接收字符串的模型,我们可以直接输入电影文本评论用于预测情感倾向。

关键代码:

ini 复制代码
inputs = tf.keras.Input(shape=(1,), dtype="string")
indices = vectorize_layer(inputs)
outputs = model(indices)
end_to_end_model = tf.keras.Model(inputs, outputs)
end_to_end_model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
end_to_end_model.evaluate(raw_test_ds)

使用测试数据进行测试,将日志打印出来,可以看出和上面的测试结果是一样的:

arduino 复制代码
196/196 [==============================] - 4s 18ms/step - loss: 0.3887 - accuracy: 0.8665

参考

github.com/wangdayaya/...

相关推荐
看到我,请让我去学习4 小时前
OpenCV 与深度学习:从图像分类到目标检测技术
深度学习·opencv·分类
加油加油的大力4 小时前
入门基于深度学习(以yolov8和unet为例)的计算机视觉领域的学习路线
深度学习·yolo·计算机视觉
IRevers6 小时前
【自动驾驶】经典LSS算法解析——深度估计
人工智能·python·深度学习·算法·机器学习·自动驾驶
学废了wuwu6 小时前
深度学习归一化方法维度参数详解(C/H/W/D完全解析)
c语言·人工智能·深度学习
whabc1006 小时前
和鲸社区深度学习基础训练营2025年关卡4
人工智能·深度学习
whabc1007 小时前
和鲸社区深度学习基础训练营2025年关卡2(3)pytorch
人工智能·深度学习·sklearn
学废了wuwu7 小时前
深度学习中的归一化技术详解:BN、LN、IN、GN
人工智能·深度学习
MUTA️7 小时前
《MAE: Masked Autoencoders Are Scalable Vision Learners》论文精读笔记
人工智能·笔记·深度学习·transformer
Ronin-Lotus7 小时前
深度学习篇---昇腾NPU&CANN 工具包
人工智能·深度学习·npu·昇腾 cann
PNP机器人7 小时前
普林斯顿大学DPPO机器人学习突破:Diffusion Policy Policy Optimization 全新优化扩散策略
人工智能·深度学习·学习·机器人·仿真平台·franka fr3