基于Python的自然语言处理系列(45):Sentence-BERT句子相似度计算

在本篇博文中,我们将探讨如何使用Sentence-BERT (SBERT) 来进行句子相似度计算。SBERT 不仅可以用于自然语言推理任务,还可以生成句子嵌入,从而计算两个句子之间的语义相似度。

1. 数据加载

我们首先加载两个数据集:SNLI (斯坦福自然语言推理数据集) 和 MNLI (多体裁自然语言推理数据集)。这些数据集提供了前提-假设句对,用于训练模型以处理推理任务。以下是数据加载的代码示例:

python 复制代码
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')

2. 数据预处理

为了将数据集转换为模型可以处理的格式,我们首先需要对句子进行分词和编码。使用BERT的分词器可以轻松地对句子进行分词,并生成输入ID和注意力掩码。以下代码展示了如何预处理数据:

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_function(examples):
    max_seq_length = 128
    premise_result = tokenizer(
        examples['premise'], padding='max_length', max_length=max_seq_length, truncation=True)
    hypothesis_result = tokenizer(
        examples['hypothesis'], padding='max_length', max_length=max_seq_length, truncation=True)
    
    labels = examples["label"]
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels": labels
    }

tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)

3. 数据加载器

接下来,我们需要使用数据加载器将预处理后的数据传递给模型进行训练。我们使用PyTorch的数据加载器来创建训练集、验证集和测试集的加载器:

python 复制代码
from torch.utils.data import DataLoader

batch_size = 32
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_datasets['validation'], batch_size=batch_size)
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size)

4. Sentence-BERT模型

SBERT通过在预训练的BERT模型后加入一个池化层来获取句子的固定大小嵌入表示。以下是加载预训练BERT模型并实现池化操作的代码:

python 复制代码
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')

def mean_pool(token_embeds, attention_mask):
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool

5. 损失函数与优化器

我们采用交叉熵损失函数用于分类任务,并使用AdamW优化器来优化模型的权重。以下是损失函数与优化器的定义:

python 复制代码
import torch.optim as optim

classifier_head = torch.nn.Linear(768 * 3, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

6. 消融实验

在实验中,我们通过不同的模型配置来测试模型的表现,并进行消融实验。消融实验可以帮助我们了解每个组件对模型性能的贡献。

7. 模型训练

模型训练分为多个epoch,每个epoch都会遍历训练集。模型从BERT层提取句子嵌入,经过分类器头部计算预测值并计算损失。以下是训练代码:

python 复制代码
for epoch in range(num_epoch):
    model.train()
    classifier_head.train()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        u = model(batch['premise_input_ids'].to(device), attention_mask=batch['premise_attention_mask'].to(device))
        v = model(batch['hypothesis_input_ids'].to(device), attention_mask=batch['hypothesis_attention_mask'].to(device))
        
        loss = criterion(classifier_head(torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)), batch['labels'].to(device))
        
        loss.backward()
        optimizer.step()

8. 推理与相似度计算

训练完成后,我们可以使用训练好的SBERT模型来计算两个句子的相似度。以下代码展示了如何输入两个句子并计算它们的余弦相似度:

python 复制代码
sentence_a = '你做出的贡献帮助我们提供了优质的教育。'
sentence_b = '你的贡献对我们的学生教育毫无帮助。'

similarity = calculate_similarity(model, tokenizer, sentence_a, sentence_b, device)
print(f"余弦相似度: {similarity:.4f}")

结语

在本篇文章中,我们详细介绍了Sentence-BERT (SBERT) 的工作原理、数据预处理、模型训练以及句子相似度的计算。通过引入池化操作,SBERT能够生成固定大小的句子嵌入,从而有效计算句子间的语义相似度。我们还通过训练模型处理了SNLI和MNLI数据集,展示了如何结合这些数据进行模型优化与推理。

SBERT的优势在于它能够为各种自然语言处理任务提供语义丰富的句子嵌入表示,并且在分类和相似度计算任务中表现出色。未来,SBERT还可以应用于更加复杂的任务,如多语言理解、文本生成等领域,进一步提高模型的泛化能力和实际应用效果。

通过本次实验,相信你已经掌握了如何使用SBERT进行句子嵌入计算的核心步骤。在日常的NLP项目中,SBERT的这种高效处理方式能够大大提高文本相似度任务的处理效率。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关推荐
NAGNIP7 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab8 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab8 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP12 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年12 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼12 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS12 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区13 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈13 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang14 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx