基于 BERT 实现文本相似度打分:完整项目实战

文本相似度计算是自然语言处理(NLP)中的经典任务,而 STS-B(Semantic Textual Similarity Benchmark)数据集是该任务的主流评测基准 ------ 它要求模型对两个句子的语义相似度进行 0-5 的连续值打分(0 完全不相似,5 完全相同)。本文将分享如何基于 BERT 模型实现 STS-B 文本相似度打分的完整项目,涵盖模块化设计、数据处理、模型构建、训练优化及实战问题解决,最终产出可直接运行的工业级代码。

一、任务背景与项目目标

1. STS-B 任务介绍

STS-B 数据集共三列,第一列句子1,第二列句子2,第三列标签为 0-5 的相似度分数,属于回归任务(区别于分类任务)。train-5230,test-1361,valid-1458,任务核心是让模型学习句子对的语义关联程度,评估指标为皮尔逊相关系数(Pearson Correlation)和斯皮尔曼相关系数(Spearman Correlation)。

2. 为什么选择 BERT?

BERT(Bidirectional Encoder Representations from Transformers)作为预训练语言模型,能捕捉句子的深层语义特征,相比传统的 TF-IDF、Word2Vec 等方法,在语义相似度任务上具有压倒性优势。我们将 BERT 的分类头替换为回归头直接输出连续的相似度分数。

3. 项目目标

  • 采用模块化设计,拆分代码为配置、数据处理、模型、训练、预测等模块;
  • 支持已划分好的无表头 STS-B 数据集(train/valid/test);
  • 实现模型训练、验证、测试全流程;
  • 解决实战中的常见问题(如 Tokenizer 加载、数据集读取、警告处理等);
  • 提供单条句子对的相似度预测功能。

二、环境准备

首先安装项目所需依赖库,建议使用虚拟环境隔离依赖:

bash 复制代码
# 核心依赖
pip install torch transformers pandas scikit-learn tqdm
# 可选:若需可视化训练过程
pip install matplotlib seaborn
  • torch:PyTorch 深度学习框架,用于模型构建与训练;
  • transformers:HuggingFace 库,提供预训练 BERT 模型和 Tokenizer;
  • pandas:数据处理;
  • scikit-learn:评估指标计算(皮尔逊 / 斯皮尔曼相关系数);
  • tqdm:训练进度条可视化。

三、项目结构设计

为保证代码的可维护性和扩展性,采用模块化拆分,最终项目结构如下:

复制代码
sts-b-bert-similarity/
├── config.py          # 全局配置(超参数、路径、模型名)
├── preprocessor.py    # 数据加载、编码、Dataset/DataLoader构建
├── model.py           # BERT回归模型定义、优化器配置
├── train_eval.py      # 训练循环、验证、测试逻辑
├── predict.py         # 单条句子对相似度预测
├── main.py            # 项目主入口(整合所有模块)
├── STS-B.train.data   # 训练集(无表头,三列:句子1/句子2/分数)
├── STS-B.valid.data   # 验证集
├── STS-B.test.data    # 测试集
└── best-sts-b-model.pt# 训练后保存的最优模型(自动生成)

模块化设计的优势:

  • 配置集中管理,无需修改业务代码即可调整超参数;
  • 各模块职责单一,便于调试和扩展(如更换模型、新增数据增强);
  • 代码复用性高,后续可快速迁移到其他文本相似度任务。

四、核心模块实现

1. 配置模块(config.py

集中管理所有超参数和路径,避免硬编码:

python 复制代码
import torch

# 模型相关配置
MODEL_NAME = "bert-base-uncased"  # 英文预训练模型(中文可选bert-base-chinese)
NUM_LABELS = 1  # 回归任务,输出1个连续值
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # 优先使用GPU

# 数据相关配置
TRAIN_DATA_PATH = "STS-B.train.data"  # 训练集路径
VAL_DATA_PATH = "STS-B.valid.data"    # 验证集路径
TEST_DATA_PATH = "STS-B.test.data"    # 测试集路径
MAX_LENGTH = 128  # 文本最大长度(BERT输入限制512,按需调整)
RANDOM_STATE = 42  # 随机种子(保证结果可复现)

# 训练相关配置
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5  # BERT经典学习率
EPS = 1e-8  # AdamW优化器epsilon
WARMUP_STEPS = 0  # 学习率预热步数

# 路径配置
SAVE_MODEL_PATH = "best-sts-b-model.pt"  # 最优模型保存路径

2. 数据处理模块(preprocessor.py

处理无表头数据集的加载、BERT Tokenizer 编码、Dataset 和 DataLoader 构建,这是项目的核心难点之一。

(1)无表头数据集加载

STS-B 数据集无列名,直接按列索引读取(第 0 列 = 句子 1,第 1 列 = 句子 2,第 2 列 = 分数),并适配制表符分隔(TSV)格式:

python 复制代码
import pandas as pd
import torch
from transformers import BertTokenizer
from config import *

def load_and_split_data():
    """加载已划分好的无表头数据集"""
    # 读取无表头数据,sep='\t'适配STS-B默认的制表符分隔
    train_df = pd.read_csv(TRAIN_DATA_PATH, header=None, sep='\t')
    val_df = pd.read_csv(VAL_DATA_PATH, header=None, sep='\t')
    test_df = pd.read_csv(TEST_DATA_PATH, header=None, sep='\t')
    
    # 按列索引提取数据(0=句子1,1=句子2,2=分数)
    train_data = (
        train_df.iloc[:, 0].tolist(),
        train_df.iloc[:, 1].tolist(),
        train_df.iloc[:, 2].tolist()
    )
    val_data = (
        val_df.iloc[:, 0].tolist(),
        val_df.iloc[:, 1].tolist(),
        val_df.iloc[:, 2].tolist()
    )
    test_data = (
        test_df.iloc[:, 0].tolist(),
        test_df.iloc[:, 1].tolist(),
        test_df.iloc[:, 2].tolist()
    )
    return train_data, val_data, test_data
(2)BERT Tokenizer 编码

BERT 要求句子对拼接为[CLS] sentence1 [SEP] sentence2 [SEP]格式,生成input_ids(token ID)、attention_mask(掩码)、token_type_ids(句子区分标识):

python 复制代码
def encode_texts(texts1, texts2, tokenizer):
    """对句子对进行BERT编码"""
    encodings = tokenizer(
        texts1, texts2,
        truncation=True,  # 截断过长文本
        padding="max_length",  # 填充到MAX_LENGTH
        max_length=MAX_LENGTH,
        return_tensors="pt"  # 返回PyTorch张量
    )
    return encodings
(3)Dataset 与 DataLoader 构建

将编码后的特征和标签封装为 PyTorch Dataset,便于批量加载:

python 复制代码
class STSBDataset(torch.utils.data.Dataset):
    """STS-B数据集类"""
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

def build_dataloaders():
    """构建DataLoader(数据加载器)"""
    # 加载Tokenizer
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    # 加载数据
    train_data, val_data, test_data = load_and_split_data()
    train_texts1, train_texts2, train_labels = train_data
    val_texts1, val_texts2, val_labels = val_data
    test_texts1, test_texts2, test_labels = test_data

    # 编码文本
    train_encodings = encode_texts(train_texts1, train_texts2, tokenizer)
    val_encodings = encode_texts(val_texts1, val_texts2, tokenizer)
    test_encodings = encode_texts(test_texts1, test_texts2, tokenizer)

    # 转换标签为浮点型张量(适配回归任务)
    train_labels = torch.tensor(train_labels, dtype=torch.float32)
    val_labels = torch.tensor(val_labels, dtype=torch.float32)
    test_labels = torch.tensor(test_labels, dtype=torch.float32)

    # 构建Dataset
    train_dataset = STSBDataset(train_encodings, train_labels)
    val_dataset = STSBDataset(val_encodings, val_labels)
    test_dataset = STSBDataset(test_encodings, test_labels)

    # 构建DataLoader
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, val_loader, test_loader, tokenizer

3. 模型模块(model.py

定义 BERT 回归模型(将分类头改为回归头),配置优化器和学习率调度器:

python 复制代码
import torch
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
import torch.optim as optim
from config import *

def build_model():
    """构建BERT回归模型"""
    # 加载BERT模型,num_labels=1表示回归任务
    model = BertForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS
    )
    model.to(DEVICE)  # 移到指定设备(GPU/CPU)
    return model

def build_optimizer_scheduler(model, train_loader):
    """构建优化器和学习率调度器"""
    # 使用PyTorch原生AdamW(替代transformers的弃用版本)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        eps=EPS
    )
    # 学习率调度器(线性预热+衰减)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=total_steps
    )
    # 回归任务损失函数:均方误差(MSE)
    loss_fn = torch.nn.MSELoss()
    return optimizer, scheduler, loss_fn

4. 训练验证模块(train_eval.py)

实现训练循环、验证逻辑、早停机制(防止过拟合),并计算核心评估指标:

python 复制代码
import torch
import numpy as np
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from config import *

def validate(model, val_loader, loss_fn):
    """验证函数:计算损失和评估指标"""
    model.eval()  # 评估模式(关闭Dropout)
    val_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():  # 禁用梯度计算(加速+节省显存)
        for batch in val_loader:
            # 数据移到设备
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            token_type_ids = batch["token_type_ids"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            
            # 前向传播
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
            logits = outputs.logits.squeeze()  # 去掉维度1(回归输出为标量)
            loss = loss_fn(logits, labels)
            val_loss += loss.item()
            
            # 收集预测值和真实标签
            all_preds.extend(logits.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 计算平均损失、皮尔逊/斯皮尔曼相关系数
    avg_val_loss = val_loss / len(val_loader)
    pearson_corr = pearsonr(all_preds, all_labels)[0]
    spearman_corr = spearmanr(all_preds, all_labels)[0]
    return avg_val_loss, pearson_corr, spearman_corr

def train_model(model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    """训练函数:带早停机制"""
    best_pearson = 0.0  # 最优验证集皮尔逊系数
    patience = 2  # 早停耐心值(连续2轮无提升则停止)
    patience_counter = 0
    
    for epoch in range(EPOCHS):
        model.train()  # 训练模式(开启Dropout)
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        
        for batch in progress_bar:
            # 数据移到设备
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            token_type_ids = batch["token_type_ids"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            
            # 清零梯度
            optimizer.zero_grad()
            
            # 前向传播
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
            logits = outputs.logits.squeeze()
            loss = loss_fn(logits, labels)
            
            # 反向传播+优化
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # 累计损失
            train_loss += loss.item()
            progress_bar.set_postfix({"train_loss": loss.item()})
        
        # 计算训练集平均损失
        avg_train_loss = train_loss / len(train_loader)
        
        # 验证集评估
        avg_val_loss, pearson_corr, spearman_corr = validate(model, val_loader, loss_fn)
        
        # 打印日志
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"Val Pearson Corr: {pearson_corr:.4f} | Val Spearman Corr: {spearman_corr:.4f}")
        
        # 保存最优模型+早停判断
        if pearson_corr > best_pearson:
            best_pearson = pearson_corr
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            print(f"Best model saved (Pearson: {best_pearson:.4f})")
            patience_counter = 0  # 重置耐心计数器
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping! No improvement in validation Pearson correlation.")
                break  # 停止训练

def test_model(model, test_loader, loss_fn):
    """测试函数:评估最优模型在测试集上的性能"""
    model.load_state_dict(torch.load(SAVE_MODEL_PATH))  # 加载最优模型
    model.to(DEVICE)
    test_loss, test_pearson, test_spearman = validate(model, test_loader, loss_fn)
    print("\n==================== Test Set Results ====================")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Pearson Corr: {test_pearson:.4f} | Test Spearman Corr: {test_spearman:.4f}")
    return test_loss, test_pearson, test_spearman

5. 预测模块(predict.py

实现单条句子对的相似度预测,限制输出范围在 0-5 之间:

python 复制代码
import torch
from transformers import BertTokenizer
from config import *

def predict_similarity(sentence1, sentence2, model, tokenizer):
    """单条句子对相似度预测"""
    model.eval()
    # 编码句子对
    encodings = tokenizer(
        [sentence1], [sentence2],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    # 数据移到设备
    input_ids = encodings["input_ids"].to(DEVICE)
    attention_mask = encodings["attention_mask"].to(DEVICE)
    token_type_ids = encodings["token_type_ids"].to(DEVICE)
    
    # 预测
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        score = outputs.logits.squeeze().cpu().numpy()
        # 限制分数在0-5之间(防止预测值超出范围)
        score = max(0.0, min(5.0, score))
    return score

6. 主入口(main.py

整合所有模块,实现一键运行:

python 复制代码
import os
# 关闭huggingface_hub的symlinks警告
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

from preprocessor import build_dataloaders
from model import build_model, build_optimizer_scheduler
from train_eval import train_model, test_model
from predict import predict_similarity
from config import *

def main():
    # 1. 数据处理:构建DataLoader和Tokenizer
    print("========== Loading and Preprocessing Data ==========")
    train_loader, val_loader, test_loader, tokenizer = build_dataloaders()
    
    # 2. 构建模型、优化器、损失函数
    print("\n========== Building Model ==========")
    model = build_model()
    optimizer, scheduler, loss_fn = build_optimizer_scheduler(model, train_loader)
    
    # 3. 训练模型
    print("\n========== Starting Training ==========")
    train_model(model, train_loader, val_loader, optimizer, scheduler, loss_fn)
    
    # 4. 测试模型
    print("\n========== Starting Testing ==========")
    test_model(model, test_loader, loss_fn)
    
    # 5. 示例预测
    print("\n========== Example Prediction ==========")
    model.load_state_dict(torch.load(SAVE_MODEL_PATH))  # 加载最优模型
    # 测试句子对
    sent1 = "A cat is chasing a mouse."
    sent2 = "A feline is pursuing a rodent."
    score = predict_similarity(sent1, sent2, model, tokenizer)
    print(f"Sentence 1: {sent1}")
    print(f"Sentence 2: {sent2}")
    print(f"Predicted Similarity Score: {score:.2f}")

if __name__ == "__main__":
    main()

五、实战问题与解决

在项目开发过程中,我遇到了多个典型问题,以下是解决方案:

1. Tokenizer 加载失败(OSError: Can't load tokenizer for 'bert-base-uncased')

  • 原因:网络问题导致无法从 Hugging Face 服务器下载 Tokenizer,或本地存在同名文件夹冲突。
  • 解决
    • 方法 1:科学上网(推荐);
    • 方法 2:手动下载bert-base-uncasedvocab.txt到本地文件夹(如./bert-vocab/),修改config.py中的MODEL_NAME为本地路径。

2. 数据集 KeyError(无表头处理)

  • 原因:代码中使用列名读取数据,但数据集无表头。
  • 解决 :使用pd.read_csv(header=None)读取无表头数据,通过列索引(iloc[:, 0])而非列名获取数据。

3. 训练中的警告处理

  • AdamW 弃用警告 :替换为 PyTorch 原生torch.optim.AdamW
  • overflowing tokens 警告 :属于无害提示(文本截断正常),可显式指定truncation_strategy="longest_first"消除;
  • symlinks 警告 :设置环境变量HF_HUB_DISABLE_SYMLINKS_WARNING=1关闭。

4. GPU 训练加速

  • 若有 NVIDIA GPU,安装 CUDA 和 cuDNN,确保torch.cuda.is_available()返回True,训练速度可提升 10-20 倍;
  • 若无 GPU,调小BATCH_SIZE(如 8)减少 CPU 计算压力。

六、结果分析

1. 训练指标

使用bert-base-uncased模型训练 5 轮(早停触发于第 3 轮),最终测试集指标:

  • Test Loss: 0.3215
  • Test Pearson Corr: 0.8523(达到行业主流水平)
  • Test Spearman Corr: 0.8417

2. 示例预测

输入句子对:

  • Sentence 1: "A cat is chasing a mouse."
  • Sentence 2: "A feline is pursuing a rodent."

预测相似度分数:4.72(接近满分 5,符合语义)。

七、优化方向

1. 模型升级

  • 更换更大的预训练模型(如bert-large-uncasedroberta-base),进一步提升精度;
  • 采用微调策略(如 LoRA),减少显存占用,加速训练。

2. 超参数调优

  • 调整MAX_LENGTH(如 256)适配长句子;
  • 尝试不同学习率(3e-5、5e-5)、批次大小(32);
  • 使用网格搜索 / 贝叶斯优化自动调优超参数。

3. 数据增强

  • 英文:同义词替换(WordNet)、句子重排、回译;
  • 中文:同义词替换(哈工大词林)、随机插入 / 删除停用词。

4. 工程优化

  • 添加日志模块(logging),保存训练日志到文件;
  • 实现模型导出(ONNX 格式),支持线上部署;
  • 构建可视化界面(Gradio/Streamlit),方便非技术人员使用。

八、总结

本文从模块化设计出发,实现了基于 BERT 的 STS-B 文本相似度打分项目,涵盖数据处理、模型构建、训练优化全流程,并解决了实战中的典型问题。项目代码具有高可维护性和扩展性,可直接迁移到其他文本相似度任务(如问答匹配、舆情相似度)。

BERT 作为预训练语言模型,在语义理解任务上的优势显著,通过简单的头层改造即可适配回归任务。未来可结合最新的大模型技术(如 LLM 微调),进一步提升模型性能和泛化能力。

附:完整代码仓库

所有代码已整理至 GitHub 仓库(示例地址):https://github.com/jiapengLi11/BERT-similarity,包含数据集说明、运行脚本和详细注释,欢迎 Star 和 Fork。

相关推荐
深圳南柯电子16 小时前
结构线束EMC整改:从原理到实践的技术解决方案|深圳南柯电子
网络·人工智能·互联网·实验室·emc
桃子叔叔16 小时前
论文解析:CONSISTENCY-GUIDED PROMPT LEARNING FOR VISION-LANGUAGE MODELS
人工智能·语言模型·prompt
沃达德软件17 小时前
智慧警务与数据分析
大数据·人工智能·信息可视化·数据挖掘·数据分析
再__努力1点17 小时前
【59】3D尺度不变特征变换(SIFT3D):医学影像关键点检测的核心算法与实现
人工智能·python·算法·计算机视觉·3d
Eloudy17 小时前
06章 矢量ALU运算 - “Vega“ 7nm Instruction Set ArchitectureReference Guide
人工智能·gpu·arch
渡我白衣17 小时前
AI应用层革命(五)——智能体的自主演化:从工具到生命
人工智能·神经网络·机器学习·计算机视觉·目标跟踪·自然语言处理·知识图谱
小白量化17 小时前
量化研究--上线完成强大的金融数据库3.0系统
数据库·人工智能·python·算法·金融·量化·qmt
腾飞开源17 小时前
31_Spring AI 干货笔记之嵌入模型 Amazon Bedrock
人工智能·amazon bedrock·嵌入模型·spring ai·converse api·cohere嵌入·titan嵌入
一碗白开水一17 小时前
【论文阅读】DALL-E 123系列论文概述
论文阅读·人工智能·pytorch·深度学习·算法