NLP:微调BERT进行文本分类

本篇博客的重点在于BERT的使用。

transformers包版本:4.44.2

1. 微调BERT进行文本分类

这里我们使用stanford大学的SST2数据集来演示BERT模型的微调过程。SST-2数据集(Stanford Sentiment Treebank 2)是一个用于情感分类的经典数据集,常用于自然语言处理(NLP)领域的情感分析任务。

  • 第1步: 下载数据。其代码如下:
python 复制代码
import pandas as pd
from transformers import BertTokenizer
from datasets import DatasetDict, Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

splits = {'train': 'data/train-00000-of-00001.parquet', 
          'validation': 'data/validation-00000-of-00001.parquet', 
          'test': 'data/test-00000-of-00001.parquet'}
train = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["train"])
validation = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["validation"])
test = pd.read_parquet("hf://datasets/stanfordnlp/sst2/" + splits["test"])
dataset = DatasetDict({'train': Dataset.from_pandas(train), 
                       'validation': Dataset.from_pandas(validation), 
                       'test': Dataset.from_pandas(test)})

要注意一下,这里并没有使用datasets包从hugging face上直接下载数据集的方式来获取数据,这是因为使用load_datesets方法获取数据时仍然会提示:NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported

  • 第2步: 构造训练集、验证集和测试集

SST2数据集中训练集(train)共67349条,验证集(validation)共872条,而测试集(test)共1821条。因为训练集数目较大微调会比较耗时,所以从这三个数据集分别抽取出了1000条、200条、200条进行后续的任务。具体代码如下:

python 复制代码
dataset['train'] = dataset['train'].shuffle(seed=42).select(range(1000))
dataset['validation'] = dataset['validation'].shuffle(seed=42).select(range(200))
dataset['test'] = dataset['test'].shuffle(seed=42).select(range(200))
print(dataset)

其输出结果如下:

bash 复制代码
Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 1000
})
Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 200
})
Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 200
})
  • 第3步:从bert中提取嵌入

训练集、验证集及测试集生成后,接着需要将这些语料全都转化成embedding向量。具体代码如下:

python 复制代码
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True)
dataset =dataset.map(tokenize_function, batched=True)
dataset=dataset.remove_columns(['sentence',"idx"])
dataset=dataset.rename_column("label","labels")
dataset.set_format("torch")
train_dataset=dataset['train']
eval_dataset=dataset['validation']
test_dataset=dataset['test']
  • 第4步:模型训练。 具体代码如下:
python 复制代码
model=BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
training_args = TrainingArguments(
    output_dir='results',         
    per_device_train_batch_size=8,  
    per_device_eval_batch_size=8,      
    num_train_epochs=1,
)
trainer = Trainer(
    model=model,                         
    args=training_args,                 
    train_dataset=train_dataset,        
    eval_dataset=eval_dataset,
)
trainer.train()
trainer.evaluate()
trainer.save_model("results")

关于上述代码,有以下几点需要说明:

  • 训练模型的选择: tranformers库中有多个分类模型,其中BertForSequenceClassification类适用于序列分类任务,比如情感分析和文本分类;而BertForTokenClassification类适用于token级的分类任务,比如命名实体识别。
  • TrainingArguments方法中的主要参数及其作用如下表所示:
参数名 作用
output_dir 指定模型和训练日志保存的记录;
num_train_epochs 设置训练的周期数(即遍历整个训练数据集的次数,指的是整个训练集将被遍历多少次以进行训练);
per_device_train_batch_size 设置每个设备(如GPU)上的训练批次大小,训练批次是指在一次训练迭代中,模型同时处理的数据样本数量;
per_device_eval_batch_size 设置每个设备上的评估批次大小;
logging_dir 指定训练日志的保存目录;
evaluation_strategy 设置评估策略。可以是 'no'(不评估)、'steps'(每隔一定步数评估)或 'epoch'(每个周期评估);
save_total_limit 设置保存模型检查点的总数限制,超过限制的检查点会被删除;
fp16 启用半精度浮点数(FP16)训练,以减少显存使用并加速训练(需要支持 FP16 的硬件);

参考资料

相关推荐
合作小小程序员小小店8 小时前
web网页,在线%抖音,舆情,线性回归%分析系统demo,基于python+web+echart+nlp+线性回归,训练,数据库mysql
python·自然语言处理·回归·nlp·线性回归
Teacher.chenchong9 小时前
GEE云端林业遥感:贯通森林分类、森林砍伐与退化监测、火灾评估、森林扰动监测、森林关键生理参数(树高/生物量/碳储量)反演等
人工智能·分类·数据挖掘
Jay200211111 小时前
【机器学习】7-9 分类任务 & 逻辑回归的成本函数 & 逻辑回归的梯度下降
笔记·机器学习·分类
WWZZ202513 小时前
快速上手大模型:深度学习13(文本预处理、语言模型、RNN、GRU、LSTM、seq2seq)
人工智能·深度学习·算法·语言模型·自然语言处理·大模型·具身智能
老友@14 小时前
RAG 的诞生:为了让 AI 不再“乱编”
人工智能·搜索引擎·ai·语言模型·自然语言处理·rag
Ma04071315 小时前
【论文阅读19】-用于PHM的大型语言模型:优化技术与应用综述
人工智能·语言模型·自然语言处理
斯外戈的小白1 天前
【NLP】基础概念+RNN架构
rnn·自然语言处理·分类
F***c3251 天前
React自然语言处理应用
前端·react.js·自然语言处理
MicroTech20251 天前
MLGO微算法科技时空卷积与双重注意机制驱动的脑信号多任务分类算法
科技·算法·分类
智算菩萨2 天前
走向通用智能的大语言模型:具身、符号落地、因果与记忆的统一认知视角
人工智能·语言模型·自然语言处理