# 基于BERT的文本分类

基于BERT的文本分类项目的实现

一、项目背景

该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例:


二、数据集介绍

2.1 数据集基本信息

数据集 自定义
类型 二分类(正面/负面)
样本量 训练集 + 验证集 + 测试集
文本长度 平均x字(最大x字)
领域 商品评论、影视评论
python 复制代码
# 加载数据集
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# 输出:{'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}

2.2 数据分析

2.2.1 句子长度分布
python 复制代码
import matplotlib.pyplot as plt

def analyze_length(texts):
    lengths = [len(t) for t in texts]
    plt.figure(figsize=(12,5))
    plt.hist(lengths, bins=30, range=(0,256), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()

analyze_length(dataset['train']['text'])
2.2.2 标签分布
python 复制代码
import pandas as pd

pd.Series(dataset['train']['label']).value_counts().plot(
    kind='pie',
    autopct='%1.1f%%',
    title='类别分布(0-负面 1-正面)'
)
plt.show()
2.2.3 类别平衡处理
python 复制代码
from torch.utils.data import WeightedRandomSampler

# 计算类别权重
labels = dataset['train']['label']
class_weights = 1 / torch.Tensor([len(labels)-sum(labels), sum(labels)])
sampler = WeightedRandomSampler(
    weights=[class_weights[label] for label in labels],
    num_samples=len(labels),
    replacement=True
)

三、数据处理

3.1 BERT分词器

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

def collate_fn(batch):
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # BERT编码
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': torch.LongTensor(labels)
    }

3.2 数据加载器

python 复制代码
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler
)

val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn
)

四、模型构建

4.1 BERT分类模型

python 复制代码
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(768, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)

4.2 模型配置

python 复制代码
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

五、模型训练与验证

5.1 训练流程

python 复制代码
from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    total_loss = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

5.2 验证流程

python 复制代码
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

六、实验结果

6.1 评估指标

Epoch 训练Loss 验证准确率 测试准确率
python 复制代码
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(loader):
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()

plot_confusion_matrix(test_loader)

6.2 学习曲线

python 复制代码
# 记录训练过程
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(3):
    train_loss = train_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)

七、流程架构图

原始文本 分词编码 BERT特征提取 全连接分类 损失计算 反向传播 模型评估


相关推荐
riri191918 分钟前
机器学习:支持向量机(SVM)原理解析及垃圾邮件过滤实战
人工智能·机器学习·支持向量机
从零开始学习人工智能24 分钟前
深入解析支撑向量机(SVM):原理、推导与实现
人工智能·机器学习·支持向量机
小猪猪_129 分钟前
神经网络与深度学习(第一章)
人工智能·深度学习·神经网络
土豆宝31 分钟前
AI玩游戏的一点尝试(5)—— 多样化的数字识别
人工智能·游戏
deephub34 分钟前
BayesFlow:基于神经网络的摊销贝叶斯推断框架
人工智能·python·深度学习·神经网络·机器学习·贝叶斯
DFminer1 小时前
【仿生机器人】机器人情绪系统的深度解析
人工智能·机器人
superior tigre1 小时前
神经网络基础:从单个神经元到多层网络(superior哥AI系列第3期)
网络·人工智能·神经网络
IT_陈寒1 小时前
开发者必看!5个VSCode隐藏技巧让你的编码效率提升200% 🚀
前端·人工智能·后端
老司机的新赛道1 小时前
吴恩达:构建自动化评估并不需要大量投入,从一些简单快速的示例入手,然后逐步迭代!
人工智能·ai·agent·智能体
东临碣石821 小时前
【AI论文】论文转海报:迈向从科学论文到多模态海报的自动化生成
运维·人工智能·自动化