# 基于BERT的文本分类

基于BERT的文本分类项目的实现

一、项目背景

该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例:


二、数据集介绍

2.1 数据集基本信息

数据集 自定义
类型 二分类(正面/负面)
样本量 训练集 + 验证集 + 测试集
文本长度 平均x字(最大x字)
领域 商品评论、影视评论
python 复制代码
# 加载数据集
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# 输出:{'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}

2.2 数据分析

2.2.1 句子长度分布
python 复制代码
import matplotlib.pyplot as plt

def analyze_length(texts):
    lengths = [len(t) for t in texts]
    plt.figure(figsize=(12,5))
    plt.hist(lengths, bins=30, range=(0,256), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()

analyze_length(dataset['train']['text'])
2.2.2 标签分布
python 复制代码
import pandas as pd

pd.Series(dataset['train']['label']).value_counts().plot(
    kind='pie',
    autopct='%1.1f%%',
    title='类别分布(0-负面 1-正面)'
)
plt.show()
2.2.3 类别平衡处理
python 复制代码
from torch.utils.data import WeightedRandomSampler

# 计算类别权重
labels = dataset['train']['label']
class_weights = 1 / torch.Tensor([len(labels)-sum(labels), sum(labels)])
sampler = WeightedRandomSampler(
    weights=[class_weights[label] for label in labels],
    num_samples=len(labels),
    replacement=True
)

三、数据处理

3.1 BERT分词器

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

def collate_fn(batch):
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # BERT编码
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': torch.LongTensor(labels)
    }

3.2 数据加载器

python 复制代码
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler
)

val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn
)

四、模型构建

4.1 BERT分类模型

python 复制代码
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(768, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)

4.2 模型配置

python 复制代码
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

五、模型训练与验证

5.1 训练流程

python 复制代码
from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    total_loss = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

5.2 验证流程

python 复制代码
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

六、实验结果

6.1 评估指标

Epoch 训练Loss 验证准确率 测试准确率
python 复制代码
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(loader):
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()

plot_confusion_matrix(test_loader)

6.2 学习曲线

python 复制代码
# 记录训练过程
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(3):
    train_loss = train_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)

七、流程架构图

原始文本 分词编码 BERT特征提取 全连接分类 损失计算 反向传播 模型评估


相关推荐
风筝超冷19 分钟前
Seq2Seq - 编码器(Encoder)和解码器(Decoder)
人工智能·深度学习·seq2seq
uncle_ll22 分钟前
李宏毅NLP-3-语音识别part2-LAS
人工智能·自然语言处理·语音识别·las
helloworld工程师25 分钟前
Spring AI应用:利用DeepSeek+嵌入模型+Milvus向量数据库实现检索增强生成--RAG应用(超详细)
人工智能·spring·milvus
終不似少年遊*2 小时前
【NLP解析】多头注意力+掩码机制+位置编码:Transformer三大核心技术详解
人工智能·自然语言处理·大模型·nlp·transformer·注意力机制
清岚_lxn5 小时前
原生SSE实现AI智能问答+Vue3前端打字机流效果
前端·javascript·人工智能·vue·ai问答
_一条咸鱼_7 小时前
大厂AI 大模型面试:注意力机制原理深度剖析
人工智能·深度学习·机器学习
FIT2CLOUD飞致云7 小时前
四月月报丨MaxKB正在被能源、交通、金属矿产等行业企业广泛采纳
人工智能·开源
_一条咸鱼_7 小时前
大厂AI大模型面试:泛化能力原理
人工智能·深度学习·机器学习
Amor风信子7 小时前
【大模型微调】如何解决llamaFactory微调效果与vllm部署效果不一致如何解决
人工智能·学习·vllm
Jamence7 小时前
多模态大语言模型arxiv论文略读(十五)
人工智能·语言模型·自然语言处理