# 基于BERT的文本分类

基于BERT的文本分类项目的实现

一、项目背景

该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例:


二、数据集介绍

2.1 数据集基本信息

数据集 自定义
类型 二分类(正面/负面)
样本量 训练集 + 验证集 + 测试集
文本长度 平均x字(最大x字)
领域 商品评论、影视评论
python 复制代码
# 加载数据集
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# 输出:{'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}

2.2 数据分析

2.2.1 句子长度分布
python 复制代码
import matplotlib.pyplot as plt

def analyze_length(texts):
    lengths = [len(t) for t in texts]
    plt.figure(figsize=(12,5))
    plt.hist(lengths, bins=30, range=(0,256), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()

analyze_length(dataset['train']['text'])
2.2.2 标签分布
python 复制代码
import pandas as pd

pd.Series(dataset['train']['label']).value_counts().plot(
    kind='pie',
    autopct='%1.1f%%',
    title='类别分布(0-负面 1-正面)'
)
plt.show()
2.2.3 类别平衡处理
python 复制代码
from torch.utils.data import WeightedRandomSampler

# 计算类别权重
labels = dataset['train']['label']
class_weights = 1 / torch.Tensor([len(labels)-sum(labels), sum(labels)])
sampler = WeightedRandomSampler(
    weights=[class_weights[label] for label in labels],
    num_samples=len(labels),
    replacement=True
)

三、数据处理

3.1 BERT分词器

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

def collate_fn(batch):
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # BERT编码
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': torch.LongTensor(labels)
    }

3.2 数据加载器

python 复制代码
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler
)

val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn
)

四、模型构建

4.1 BERT分类模型

python 复制代码
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(768, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)

4.2 模型配置

python 复制代码
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

五、模型训练与验证

5.1 训练流程

python 复制代码
from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    total_loss = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

5.2 验证流程

python 复制代码
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

六、实验结果

6.1 评估指标

Epoch 训练Loss 验证准确率 测试准确率
python 复制代码
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(loader):
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()

plot_confusion_matrix(test_loader)

6.2 学习曲线

python 复制代码
# 记录训练过程
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(3):
    train_loss = train_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)

七、流程架构图

原始文本 分词编码 BERT特征提取 全连接分类 损失计算 反向传播 模型评估


相关推荐
lucky_lyovo5 分钟前
卷积神经网络--网络性能提升
人工智能·神经网络·cnn
liliangcsdn9 分钟前
smolagents - 如何在mac用agents做简单算术题
人工智能·macos·prompt
nju_spy13 分钟前
周志华《机器学习导论》第8章 集成学习 Ensemble Learning
人工智能·随机森林·机器学习·集成学习·boosting·bagging·南京大学
静心问道37 分钟前
TrOCR: 基于Transformer的光学字符识别方法,使用预训练模型
人工智能·深度学习·transformer·多模态
说私域39 分钟前
基于开源AI大模型、AI智能名片与S2B2C商城小程序源码的用户价值引导与核心用户沉淀策略研究
人工智能·开源
亲持红叶40 分钟前
GLU 变种:ReGLU 、 GEGLU 、 SwiGLU
人工智能·深度学习·神经网络·激活函数
说私域40 分钟前
线上协同办公时代:以开源AI大模型等工具培养网感,拥抱职业变革
人工智能·开源
群联云防护小杜42 分钟前
深度隐匿源IP:高防+群联AI云防护防绕过实战
运维·服务器·前端·网络·人工智能·网络协议·tcp/ip
摘星编程1 小时前
构建智能客服Agent:从需求分析到生产部署
人工智能·需求分析·智能客服·agent开发·生产部署
不爱学习的YY酱1 小时前
信息检索革命:Perplexica+cpolar打造你的专属智能搜索中枢
人工智能