NLP之Bert实现文本多分类

文章目录

代码

python 复制代码
from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list
import tensorflow as tf
from transformers import TFBertForSequenceClassification

bert_model = "bert-base-chinese"

model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits)
model.summary()
result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1)
print(result.history)
# 保存模型(模型保存的本质就是保存训练的参数,而对于深度学习而言还保存神经网络结构)
model.save_weights('../data/model.h5')

model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
model.load_weights('../data/model.h5')
result = model.predict(val_train_list[:12])  # 预测值
print(result)
result = tf.nn.sigmoid(result)
print(result)
result = tf.cast(tf.greater_equal(result, 0.5), tf.float32)
print(result)

代码整体流程解读

这段代码的目的是利用TensorFlow和transformers库来进行文本序列的分类任务。下面是整体流程的概述和逐步计划:

  1. 导入必要的库和数据:

    • 从一个叫做 pypro.chapters03.demo03_数据获取与处理 的模块中导入了四个列表:train_list, label_list, val_train_list, val_label_list。这些列表分别包含训练数据、训练标签、验证数据和验证标签。
    • 导入TensorFlow和transformers库。
  2. 初始化预训练的BERT模型:

    • 使用 bert-base-chinese 模型初始化一个用于序列分类的BERT模型。
    • 模型被配置为对32个不同的标签进行分类。
  3. 编译模型:

    • 使用sigmoid交叉熵作为损失函数,并跟踪准确度作为性能指标。
  4. 模型摘要:

    • 输出模型的概要信息,包括每一层的名称、类型、输出形状和参数数量。
  5. 训练模型:

    • 使用提供的训练数据和标签(仅取前24个样本)来训练模型。
    • 批量大小设置为12,训练仅进行1个时代(epoch),这意味着数据将通过模型传递一次。
  6. 输出训练结果:

    • 打印训练过程中记录的历史数据,通常包括损失值和准确度。
  7. 保存模型权重:

    • 将训练后的模型权重保存到本地文件 model.h5
  8. 加载模型权重:

    • 初始化一个新的模型结构,并加载之前保存的权重。
  9. 模型预测:

    • 使用验证数据(仅取前12个样本)进行预测。
  10. 激活函数处理:

    • 将预测结果通过sigmoid函数处理,转换成0到1之间的值。
  11. 转换预测结果:

    • 通过比较预测值是否大于或等于0.5来将概率转换为二进制分类结果。

debug上面的代码

下面逐行解释上述代码:

  1. from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list

    这行代码从demo03_数据获取与处理模块中导入四个列表。这些列表包含训练数据和标签(train_list, label_list),以及验证数据和标签(val_train_list, val_label_list)。这是数据准备步骤的一部分。

  2. import tensorflow as tf

    这行代码导入了TensorFlow库,它是一个广泛用于机器学习和深度学习任务的开源库。

  3. from transformers import TFBertForSequenceClassification

    这里导入了transformers库中的TFBertForSequenceClassification类。transformers库包含了许多预训练模型,用于NLP任务,这里特别导入的是适用于TensorFlow的BERT模型,用于序列分类任务。

  4. bert_model = "bert-base-chinese"

    定义一个字符串变量bert_model,它保存了预训练模型的名称。在这里,我们将使用中文BERT基础模型。

  5. model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)

    使用bert-base-chinese模型和TFBertForSequenceClassification类创建一个新的序列分类模型实例。num_labels=32表明有32个不同的类别用于分类。

  6. model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits)

    编译模型,设置度量为准确度(accuracy),并使用sigmoid_cross_entropy_with_logits作为损失函数,这通常用于二分类问题,但在这里,由于是多标签分类(32个类别),可能是对每个标签进行二分类。

  7. model.summary()

    输出模型的摘要信息,包括模型中的层,每层的输出形状和参数数量等详细信息。

  8. result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1)

    开始训练模型,仅使用前24个样本作为训练数据和标签。批处理大小设置为12,意味着每次梯度更新将基于12个样本。epochs=1表示整个数据集只通过模型训练一次。

  9. print(result.history)

    打印出训练过程中的历史数据,如损失和准确度。

  10. model.save_weights('../data/model.h5')

    保存训练好的模型权重到本地文件model.h5

  11. model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)

    再次初始化一个模型,用于演示如何从头加载一个模型。

  12. model.load_weights('../data/model.h5')

    加载先前保存的模型权重。

  13. result = model.predict(val_train_list[:12]) # 预测值

    使用验证数据集中的前12个样本进行预测,得到模型的输出。

  14. print(result)

    打印出预测结果。

  15. result = tf.nn.sigmoid(result)

    将模型的原始输出通过sigmoid函数转换,得到一个在0到1之间的值,表示属于每个类别的概率。

  16. print(result)

    再次打印经过sigmoid激活函数处理后的预测结果。

  17. result = tf.cast(tf.greater_equal(result, 0.5), tf.float32)

    将sigmoid输出的概率转换为二分类结果。对于每个标签,如果概率大于或等于0.5,则认为该样本属于该标签(转换为1),否则不属于(转换为0)。

  18. `print

(result)`

复制代码
最后,打印出转换后的分类结果。

整体而言,这段代码展示了使用预训练的BERT模型在一个多标签文本分类任务上的训练、保存、加载和预测的完整过程。

相关推荐
热心不起来的市民小周10 分钟前
True or False? 基于 BERT 学生数学问题误解检测
深度学习·nlp·bert
不会计算机的g_c__b1 小时前
搜索引擎评估革命:用户行为模型如何颠覆传统指标?
人工智能·自然语言处理·机器翻译
roman_日积跬步-终至千里1 小时前
【机器学习】非线性分类算法详解(下):决策树(最佳分裂特征选择的艺术)与支持向量机(最大间隔和核技巧)
决策树·机器学习·分类
萑澈2 小时前
大语言模型提示词工程详尽实战指南
人工智能·语言模型·自然语言处理
顾默@7 小时前
个人电脑部署私有化大语言模型LLM
人工智能·语言模型·自然语言处理
roman_日积跬步-终至千里8 小时前
【机器学习】两大线性分类算法:逻辑回归与线性判别分析:找到分界线的艺术
算法·机器学习·分类
向左转, 向右走ˉ8 小时前
层归一化(LayerNorm)与Batch归一化(BatchNorm):从原理到实践的深度对比
人工智能·深度学习·机器学习·分类
SugarPPig13 小时前
(二)LoRA微调BERT:为何在单分类任务中表现优异,而在多分类任务中效果不佳?
人工智能·分类·bert
Fine姐13 小时前
数据挖掘2.1&2.2 分类和线性判别器&确定线性可分性
人工智能·分类·数据挖掘
java1234_小锋16 小时前
【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博舆情数据可视化分析-热词情感趋势树形图
python·信息可视化·自然语言处理