基于BERT的文本分类项目的实现
一、项目背景
本项目是一个基于BERT的情感分析任务,属于二分类问题(正面/负面)。以下是大致流程及部分代码示例:
二、数据集介绍
2.1 数据集基本信息
| 数据集 | 自定义 |
| --- | --- |
| 类型 | 二分类(正面/负面) |
| 样本量 | 训练集 + 验证集 + 测试集 |
| 文本长度 | 平均x字(最大x字) |
| 领域 | 商品评论、影视评论 |
python
# Load the dataset from a tab-separated text file.
# NOTE(review): pd.read_csv returns a single DataFrame, so dataset['train']
# is a column lookup rather than a split lookup. The later snippets index
# `dataset` by split name ('train', 'validation') and then by row or column,
# which looks like a Hugging Face DatasetDict -- TODO confirm the intended
# loader. Also, `pd` is only imported in a later snippet of this document.
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# Expected output: {'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}
2.2 数据分析
2.2.1 句子长度分布
python
import matplotlib.pyplot as plt
def analyze_length(texts):
    """Show a histogram of how many characters each text contains.

    Args:
        texts: Iterable of strings (any items that support ``len``).

    The histogram uses 30 bins over the range 0-256; lengths outside that
    range are not drawn. The figure is displayed with ``plt.show()`` and
    nothing is returned.
    """
    char_counts = list(map(len, texts))

    plt.figure(figsize=(12, 5))
    plt.hist(char_counts, bins=30, range=(0, 256), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()
# Plot the length distribution of the training-split texts.
analyze_length(dataset['train']['text'])
2.2.2 标签分布
python
import pandas as pd
# Count the samples per label, then show each label's share as a pie chart.
label_counts = pd.Series(dataset['train']['label']).value_counts()
label_counts.plot(kind='pie', autopct='%1.1f%%', title='类别分布(0-负面 1-正面)')
plt.show()
2.2.3 类别平衡处理
python
from torch.utils.data import WeightedRandomSampler
# torch is used below; import it here because the document's own
# `import torch` only appears in a later snippet.
import torch

# Weight every sample by the inverse frequency of its class so that the
# sampler draws both classes roughly equally often.
labels = dataset['train']['label']
label_ids = torch.as_tensor(list(labels), dtype=torch.long)

# bincount counts each class id directly; minlength=2 keeps an entry for
# both binary classes even if one of them is absent from the split.
class_counts = torch.bincount(label_ids, minlength=2)
class_weights = 1.0 / class_counts.float()

sampler = WeightedRandomSampler(
    # Index the per-class weights by label id to get one weight per sample.
    weights=class_weights[label_ids],
    num_samples=len(label_ids),
    replacement=True,
)
三、数据处理
3.1 BERT分词器
python
from transformers import BertTokenizer
# Load the pretrained tokenizer of the 'bert-base-chinese' checkpoint;
# collate_fn uses this module-level object to encode each batch.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
def collate_fn(batch):
    """Turn a list of samples into one tokenized, padded batch.

    Args:
        batch: Sequence of mappings, each providing a 'text' and a 'label'
            entry.

    Returns:
        A dict with the tokenizer's 'input_ids' and 'attention_mask' tensors
        and a 'labels' LongTensor built from the sample labels.
    """
    texts = [sample['text'] for sample in batch]
    labels = [sample['label'] for sample in batch]

    # Pad to the longest text in the batch and truncate at 256 tokens.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt',
    )

    batch_tensors = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
    batch_tensors['labels'] = torch.LongTensor(labels)
    return batch_tensors
3.2 数据加载器
python
from torch.utils.data import DataLoader
# Training batches are drawn through the weighted sampler, so `shuffle`
# is intentionally not set (DataLoader does not allow both).
train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler,
)

# Validation batches are read in dataset order.
val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn,
)

# The confusion-matrix snippet calls plot_confusion_matrix(test_loader),
# so the test loader has to exist as well.
# Assumes the held-out split is keyed 'test' -- TODO confirm.
test_loader = DataLoader(
    dataset['test'],
    batch_size=32,
    collate_fn=collate_fn,
)
四、模型构建
4.1 BERT分类模型
python
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear classifier.

    Args:
        num_classes: Number of output logits (default 2, the binary
            sentiment task).
        dropout: Dropout probability applied to the pooled output.
        pretrained_name: Checkpoint name or path passed to
            ``BertModel.from_pretrained``.

    Calling the class with no arguments builds the same architecture as
    before: 'bert-base-chinese', dropout 0.1, two output classes.
    """

    def __init__(self, num_classes=2, dropout=0.1,
                 pretrained_name='bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        # Take the encoder width from the checkpoint config instead of
        # hard-coding 768, so other BERT sizes work unchanged.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits for a batch.

        Args:
            input_ids: Token id tensor produced by the tokenizer.
            attention_mask: Mask tensor produced by the tokenizer.

        Returns:
            The output of the linear head applied to the dropped-out
            pooled BERT output.
        """
        # Keyword arguments guard against positional-order mix-ups.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)
4.2 模型配置
python
import torch
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier on the chosen device, then the optimizer and loss
# that the training and evaluation functions read as module-level names.
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()
五、模型训练与验证
5.1 训练流程
python
from tqdm import tqdm
def train_epoch(model, loader):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch in tqdm(loader):
        optimizer.zero_grad()

        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        logits = model(ids, mask)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)
5.2 验证流程
python
def evaluate(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    Runs in eval mode with gradients disabled and uses the module-level
    ``device``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.

    Returns:
        Number of correct predictions divided by the number of labels seen.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch in loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(ids, mask)
            predictions = logits.argmax(dim=1)

            n_correct += (predictions == targets).sum().item()
            n_seen += len(targets)
    return n_correct / n_seen
六、实验结果
6.1 评估指标
| Epoch | 训练Loss | 验证准确率 | 测试准确率 |
| --- | --- | --- | --- |
python
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(loader):
    """Predict every batch in ``loader`` and draw the confusion matrix.

    Uses the module-level ``model`` and ``device``. The matrix is rendered
    as an annotated heatmap and shown with ``plt.show()``; nothing is
    returned.

    Args:
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.
    """
    true_labels, predicted_labels = [], []

    model.eval()
    with torch.no_grad():
        for batch in loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(ids, mask)
            predictions = torch.argmax(logits, dim=1)

            # Move to the CPU before converting to NumPy.
            true_labels.extend(targets.cpu().numpy())
            predicted_labels.extend(predictions.cpu().numpy())

    matrix = confusion_matrix(true_labels, predicted_labels)
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()
# Draw the confusion matrix for the batches served by test_loader.
plot_confusion_matrix(test_loader)
6.2 学习曲线
python
# 记录训练过程
from torch.utils.tensorboard import SummaryWriter
# Log the training loss and validation accuracy of each epoch to TensorBoard.
writer = SummaryWriter()
try:
    for epoch in range(3):
        train_loss = train_epoch(model, train_loader)
        val_acc = evaluate(model, val_loader)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_acc, epoch)
finally:
    # Closing flushes buffered events to disk, even if training is
    # interrupted by an exception.
    writer.close()
七、流程架构图
原始文本 → 分词编码 → BERT特征提取 → 全连接分类 → 损失计算 → 反向传播 → 模型评估