文本相似度计算是自然语言处理(NLP)中的经典任务,而 STS-B(Semantic Textual Similarity Benchmark)数据集是该任务的主流评测基准 ------ 它要求模型对两个句子的语义相似度进行 0-5 的连续值打分(0 完全不相似,5 完全相同)。本文将分享如何基于 BERT 模型实现 STS-B 文本相似度打分的完整项目,涵盖模块化设计、数据处理、模型构建、训练优化及实战问题解决,最终产出可直接运行的工业级代码。
一、任务背景与项目目标
1. STS-B 任务介绍
STS-B 数据集共三列,第一列句子1,第二列句子2,第三列标签为 0-5 的相似度分数,属于回归任务(区别于分类任务)。train-5230,test-1361,valid-1458,任务核心是让模型学习句子对的语义关联程度,评估指标为皮尔逊相关系数(Pearson Correlation)和斯皮尔曼相关系数(Spearman Correlation)。
2. 为什么选择 BERT?
BERT(Bidirectional Encoder Representations from Transformers)作为预训练语言模型,能捕捉句子的深层语义特征,相比传统的 TF-IDF、Word2Vec 等方法,在语义相似度任务上具有压倒性优势。我们将 BERT 的分类头替换为回归头 ,直接输出连续的相似度分数。
3. 项目目标
- 采用模块化设计,拆分代码为配置、数据处理、模型、训练、预测等模块;
- 支持已划分好的无表头 STS-B 数据集(train/valid/test);
- 实现模型训练、验证、测试全流程;
- 解决实战中的常见问题(如 Tokenizer 加载、数据集读取、警告处理等);
- 提供单条句子对的相似度预测功能。
二、环境准备
首先安装项目所需依赖库,建议使用虚拟环境隔离依赖:
bash
# 核心依赖
pip install torch transformers pandas scikit-learn tqdm
# 可选:若需可视化训练过程
pip install matplotlib seaborn
torch:PyTorch 深度学习框架,用于模型构建与训练;transformers:HuggingFace 库,提供预训练 BERT 模型和 Tokenizer;pandas:数据处理;scikit-learn:评估指标计算(皮尔逊 / 斯皮尔曼相关系数);tqdm:训练进度条可视化。
三、项目结构设计
为保证代码的可维护性和扩展性,采用模块化拆分,最终项目结构如下:
sts-b-bert-similarity/
├── config.py # 全局配置(超参数、路径、模型名)
├── preprocessor.py # 数据加载、编码、Dataset/DataLoader构建
├── model.py # BERT回归模型定义、优化器配置
├── train_eval.py # 训练循环、验证、测试逻辑
├── predict.py # 单条句子对相似度预测
├── main.py # 项目主入口(整合所有模块)
├── STS-B.train.data # 训练集(无表头,三列:句子1/句子2/分数)
├── STS-B.valid.data # 验证集
├── STS-B.test.data # 测试集
└── best-sts-b-model.pt# 训练后保存的最优模型(自动生成)
模块化设计的优势:
- 配置集中管理,无需修改业务代码即可调整超参数;
- 各模块职责单一,便于调试和扩展(如更换模型、新增数据增强);
- 代码复用性高,后续可快速迁移到其他文本相似度任务。
四、核心模块实现
1. 配置模块(config.py)
集中管理所有超参数和路径,避免硬编码:
python
import torch
# 模型相关配置
MODEL_NAME = "bert-base-uncased" # 英文预训练模型(中文可选bert-base-chinese)
NUM_LABELS = 1 # 回归任务,输出1个连续值
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # 优先使用GPU
# 数据相关配置
TRAIN_DATA_PATH = "STS-B.train.data" # 训练集路径
VAL_DATA_PATH = "STS-B.valid.data" # 验证集路径
TEST_DATA_PATH = "STS-B.test.data" # 测试集路径
MAX_LENGTH = 128 # 文本最大长度(BERT输入限制512,按需调整)
RANDOM_STATE = 42 # 随机种子(保证结果可复现)
# 训练相关配置
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5 # BERT经典学习率
EPS = 1e-8 # AdamW优化器epsilon
WARMUP_STEPS = 0 # 学习率预热步数
# 路径配置
SAVE_MODEL_PATH = "best-sts-b-model.pt" # 最优模型保存路径
2. 数据处理模块(preprocessor.py)
处理无表头数据集的加载、BERT Tokenizer 编码、Dataset 和 DataLoader 构建,这是项目的核心难点之一。
(1)无表头数据集加载
STS-B 数据集无列名,直接按列索引读取(第 0 列 = 句子 1,第 1 列 = 句子 2,第 2 列 = 分数),并适配制表符分隔(TSV)格式:
python
import pandas as pd
import torch
from transformers import BertTokenizer
from config import *
def load_and_split_data():
"""加载已划分好的无表头数据集"""
# 读取无表头数据,sep='\t'适配STS-B默认的制表符分隔
train_df = pd.read_csv(TRAIN_DATA_PATH, header=None, sep='\t')
val_df = pd.read_csv(VAL_DATA_PATH, header=None, sep='\t')
test_df = pd.read_csv(TEST_DATA_PATH, header=None, sep='\t')
# 按列索引提取数据(0=句子1,1=句子2,2=分数)
train_data = (
train_df.iloc[:, 0].tolist(),
train_df.iloc[:, 1].tolist(),
train_df.iloc[:, 2].tolist()
)
val_data = (
val_df.iloc[:, 0].tolist(),
val_df.iloc[:, 1].tolist(),
val_df.iloc[:, 2].tolist()
)
test_data = (
test_df.iloc[:, 0].tolist(),
test_df.iloc[:, 1].tolist(),
test_df.iloc[:, 2].tolist()
)
return train_data, val_data, test_data
(2)BERT Tokenizer 编码
BERT 要求句子对拼接为[CLS] sentence1 [SEP] sentence2 [SEP]格式,生成input_ids(token ID)、attention_mask(掩码)、token_type_ids(句子区分标识):
python
def encode_texts(texts1, texts2, tokenizer):
"""对句子对进行BERT编码"""
encodings = tokenizer(
texts1, texts2,
truncation=True, # 截断过长文本
padding="max_length", # 填充到MAX_LENGTH
max_length=MAX_LENGTH,
return_tensors="pt" # 返回PyTorch张量
)
return encodings
(3)Dataset 与 DataLoader 构建
将编码后的特征和标签封装为 PyTorch Dataset,便于批量加载:
python
class STSBDataset(torch.utils.data.Dataset):
"""STS-B数据集类"""
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: val[idx] for key, val in self.encodings.items()}
item["labels"] = self.labels[idx]
return item
def __len__(self):
return len(self.labels)
def build_dataloaders():
"""构建DataLoader(数据加载器)"""
# 加载Tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# 加载数据
train_data, val_data, test_data = load_and_split_data()
train_texts1, train_texts2, train_labels = train_data
val_texts1, val_texts2, val_labels = val_data
test_texts1, test_texts2, test_labels = test_data
# 编码文本
train_encodings = encode_texts(train_texts1, train_texts2, tokenizer)
val_encodings = encode_texts(val_texts1, val_texts2, tokenizer)
test_encodings = encode_texts(test_texts1, test_texts2, tokenizer)
# 转换标签为浮点型张量(适配回归任务)
train_labels = torch.tensor(train_labels, dtype=torch.float32)
val_labels = torch.tensor(val_labels, dtype=torch.float32)
test_labels = torch.tensor(test_labels, dtype=torch.float32)
# 构建Dataset
train_dataset = STSBDataset(train_encodings, train_labels)
val_dataset = STSBDataset(val_encodings, val_labels)
test_dataset = STSBDataset(test_encodings, test_labels)
# 构建DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
return train_loader, val_loader, test_loader, tokenizer
3. 模型模块(model.py)
定义 BERT 回归模型(将分类头改为回归头),配置优化器和学习率调度器:
python
import torch
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
import torch.optim as optim
from config import *
def build_model():
"""构建BERT回归模型"""
# 加载BERT模型,num_labels=1表示回归任务
model = BertForSequenceClassification.from_pretrained(
MODEL_NAME,
num_labels=NUM_LABELS
)
model.to(DEVICE) # 移到指定设备(GPU/CPU)
return model
def build_optimizer_scheduler(model, train_loader):
"""构建优化器和学习率调度器"""
# 使用PyTorch原生AdamW(替代transformers的弃用版本)
optimizer = optim.AdamW(
model.parameters(),
lr=LEARNING_RATE,
eps=EPS
)
# 学习率调度器(线性预热+衰减)
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=WARMUP_STEPS,
num_training_steps=total_steps
)
# 回归任务损失函数:均方误差(MSE)
loss_fn = torch.nn.MSELoss()
return optimizer, scheduler, loss_fn
4. 训练验证模块(train_eval.py)
实现训练循环、验证逻辑、早停机制(防止过拟合),并计算核心评估指标:
python
import torch
import numpy as np
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from config import *
def validate(model, val_loader, loss_fn):
"""验证函数:计算损失和评估指标"""
model.eval() # 评估模式(关闭Dropout)
val_loss = 0.0
all_preds = []
all_labels = []
with torch.no_grad(): # 禁用梯度计算(加速+节省显存)
for batch in val_loader:
# 数据移到设备
input_ids = batch["input_ids"].to(DEVICE)
attention_mask = batch["attention_mask"].to(DEVICE)
token_type_ids = batch["token_type_ids"].to(DEVICE)
labels = batch["labels"].to(DEVICE)
# 前向传播
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
logits = outputs.logits.squeeze() # 去掉维度1(回归输出为标量)
loss = loss_fn(logits, labels)
val_loss += loss.item()
# 收集预测值和真实标签
all_preds.extend(logits.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
# 计算平均损失、皮尔逊/斯皮尔曼相关系数
avg_val_loss = val_loss / len(val_loader)
pearson_corr = pearsonr(all_preds, all_labels)[0]
spearman_corr = spearmanr(all_preds, all_labels)[0]
return avg_val_loss, pearson_corr, spearman_corr
def train_model(model, train_loader, val_loader, optimizer, scheduler, loss_fn):
"""训练函数:带早停机制"""
best_pearson = 0.0 # 最优验证集皮尔逊系数
patience = 2 # 早停耐心值(连续2轮无提升则停止)
patience_counter = 0
for epoch in range(EPOCHS):
model.train() # 训练模式(开启Dropout)
train_loss = 0.0
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
for batch in progress_bar:
# 数据移到设备
input_ids = batch["input_ids"].to(DEVICE)
attention_mask = batch["attention_mask"].to(DEVICE)
token_type_ids = batch["token_type_ids"].to(DEVICE)
labels = batch["labels"].to(DEVICE)
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
logits = outputs.logits.squeeze()
loss = loss_fn(logits, labels)
# 反向传播+优化
loss.backward()
optimizer.step()
scheduler.step()
# 累计损失
train_loss += loss.item()
progress_bar.set_postfix({"train_loss": loss.item()})
# 计算训练集平均损失
avg_train_loss = train_loss / len(train_loader)
# 验证集评估
avg_val_loss, pearson_corr, spearman_corr = validate(model, val_loader, loss_fn)
# 打印日志
print(f"\nEpoch {epoch+1} Summary:")
print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
print(f"Val Pearson Corr: {pearson_corr:.4f} | Val Spearman Corr: {spearman_corr:.4f}")
# 保存最优模型+早停判断
if pearson_corr > best_pearson:
best_pearson = pearson_corr
torch.save(model.state_dict(), SAVE_MODEL_PATH)
print(f"Best model saved (Pearson: {best_pearson:.4f})")
patience_counter = 0 # 重置耐心计数器
else:
patience_counter += 1
print(f"Patience counter: {patience_counter}/{patience}")
if patience_counter >= patience:
print("Early stopping! No improvement in validation Pearson correlation.")
break # 停止训练
def test_model(model, test_loader, loss_fn):
"""测试函数:评估最优模型在测试集上的性能"""
model.load_state_dict(torch.load(SAVE_MODEL_PATH)) # 加载最优模型
model.to(DEVICE)
test_loss, test_pearson, test_spearman = validate(model, test_loader, loss_fn)
print("\n==================== Test Set Results ====================")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Pearson Corr: {test_pearson:.4f} | Test Spearman Corr: {test_spearman:.4f}")
return test_loss, test_pearson, test_spearman
5. 预测模块(predict.py)
实现单条句子对的相似度预测,限制输出范围在 0-5 之间:
python
import torch
from transformers import BertTokenizer
from config import *
def predict_similarity(sentence1, sentence2, model, tokenizer):
"""单条句子对相似度预测"""
model.eval()
# 编码句子对
encodings = tokenizer(
[sentence1], [sentence2],
truncation=True,
padding="max_length",
max_length=MAX_LENGTH,
return_tensors="pt"
)
# 数据移到设备
input_ids = encodings["input_ids"].to(DEVICE)
attention_mask = encodings["attention_mask"].to(DEVICE)
token_type_ids = encodings["token_type_ids"].to(DEVICE)
# 预测
with torch.no_grad():
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
score = outputs.logits.squeeze().cpu().numpy()
# 限制分数在0-5之间(防止预测值超出范围)
score = max(0.0, min(5.0, score))
return score
6. 主入口(main.py)
整合所有模块,实现一键运行:
python
import os
# 关闭huggingface_hub的symlinks警告
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
from preprocessor import build_dataloaders
from model import build_model, build_optimizer_scheduler
from train_eval import train_model, test_model
from predict import predict_similarity
from config import *
def main():
# 1. 数据处理:构建DataLoader和Tokenizer
print("========== Loading and Preprocessing Data ==========")
train_loader, val_loader, test_loader, tokenizer = build_dataloaders()
# 2. 构建模型、优化器、损失函数
print("\n========== Building Model ==========")
model = build_model()
optimizer, scheduler, loss_fn = build_optimizer_scheduler(model, train_loader)
# 3. 训练模型
print("\n========== Starting Training ==========")
train_model(model, train_loader, val_loader, optimizer, scheduler, loss_fn)
# 4. 测试模型
print("\n========== Starting Testing ==========")
test_model(model, test_loader, loss_fn)
# 5. 示例预测
print("\n========== Example Prediction ==========")
model.load_state_dict(torch.load(SAVE_MODEL_PATH)) # 加载最优模型
# 测试句子对
sent1 = "A cat is chasing a mouse."
sent2 = "A feline is pursuing a rodent."
score = predict_similarity(sent1, sent2, model, tokenizer)
print(f"Sentence 1: {sent1}")
print(f"Sentence 2: {sent2}")
print(f"Predicted Similarity Score: {score:.2f}")
if __name__ == "__main__":
main()
五、实战问题与解决
在项目开发过程中,我遇到了多个典型问题,以下是解决方案:
1. Tokenizer 加载失败(OSError: Can't load tokenizer for 'bert-base-uncased')
- 原因:网络问题导致无法从 Hugging Face 服务器下载 Tokenizer,或本地存在同名文件夹冲突。
- 解决 :
- 方法 1:科学上网(推荐);
- 方法 2:手动下载
bert-base-uncased的vocab.txt到本地文件夹(如./bert-vocab/),修改config.py中的MODEL_NAME为本地路径。
2. 数据集 KeyError(无表头处理)
- 原因:代码中使用列名读取数据,但数据集无表头。
- 解决 :使用
pd.read_csv(header=None)读取无表头数据,通过列索引(iloc[:, 0])而非列名获取数据。
3. 训练中的警告处理
- AdamW 弃用警告 :替换为 PyTorch 原生
torch.optim.AdamW; - overflowing tokens 警告 :属于无害提示(文本截断正常),可显式指定
truncation_strategy="longest_first"消除; - symlinks 警告 :设置环境变量
HF_HUB_DISABLE_SYMLINKS_WARNING=1关闭。
4. GPU 训练加速
- 若有 NVIDIA GPU,安装 CUDA 和 cuDNN,确保
torch.cuda.is_available()返回True,训练速度可提升 10-20 倍; - 若无 GPU,调小
BATCH_SIZE(如 8)减少 CPU 计算压力。
六、结果分析
1. 训练指标
使用bert-base-uncased模型训练 5 轮(早停触发于第 3 轮),最终测试集指标:
- Test Loss: 0.3215
- Test Pearson Corr: 0.8523(达到行业主流水平)
- Test Spearman Corr: 0.8417
2. 示例预测
输入句子对:
- Sentence 1: "A cat is chasing a mouse."
- Sentence 2: "A feline is pursuing a rodent."
预测相似度分数:4.72(接近满分 5,符合语义)。
七、优化方向
1. 模型升级
- 更换更大的预训练模型(如
bert-large-uncased、roberta-base),进一步提升精度; - 采用微调策略(如 LoRA),减少显存占用,加速训练。
2. 超参数调优
- 调整
MAX_LENGTH(如 256)适配长句子; - 尝试不同学习率(3e-5、5e-5)、批次大小(32);
- 使用网格搜索 / 贝叶斯优化自动调优超参数。
3. 数据增强
- 英文:同义词替换(WordNet)、句子重排、回译;
- 中文:同义词替换(哈工大词林)、随机插入 / 删除停用词。
4. 工程优化
- 添加日志模块(
logging),保存训练日志到文件; - 实现模型导出(ONNX 格式),支持线上部署;
- 构建可视化界面(Gradio/Streamlit),方便非技术人员使用。
八、总结
本文从模块化设计出发,实现了基于 BERT 的 STS-B 文本相似度打分项目,涵盖数据处理、模型构建、训练优化全流程,并解决了实战中的典型问题。项目代码具有高可维护性和扩展性,可直接迁移到其他文本相似度任务(如问答匹配、舆情相似度)。
BERT 作为预训练语言模型,在语义理解任务上的优势显著,通过简单的头层改造即可适配回归任务。未来可结合最新的大模型技术(如 LLM 微调),进一步提升模型性能和泛化能力。
附:完整代码仓库
所有代码已整理至 GitHub 仓库(示例地址):https://github.com/jiapengLi11/BERT-similarity,包含数据集说明、运行脚本和详细注释,欢迎 Star 和 Fork。