基于Python的自然语言处理系列(45):Sentence-BERT句子相似度计算

在本篇博文中,我们将探讨如何使用Sentence-BERT (SBERT) 来进行句子相似度计算。SBERT 不仅可以用于自然语言推理任务,还可以生成句子嵌入,从而计算两个句子之间的语义相似度。

1. 数据加载

我们首先加载两个数据集:SNLI (斯坦福自然语言推理数据集) 和 MNLI (多体裁自然语言推理数据集)。这些数据集提供了前提-假设句对,用于训练模型以处理推理任务。以下是数据加载的代码示例:

python 复制代码
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')

2. 数据预处理

为了将数据集转换为模型可以处理的格式,我们首先需要对句子进行分词和编码。使用BERT的分词器可以轻松地对句子进行分词,并生成输入ID和注意力掩码。以下代码展示了如何预处理数据:

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_function(examples):
    max_seq_length = 128
    premise_result = tokenizer(
        examples['premise'], padding='max_length', max_length=max_seq_length, truncation=True)
    hypothesis_result = tokenizer(
        examples['hypothesis'], padding='max_length', max_length=max_seq_length, truncation=True)
    
    labels = examples["label"]
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels": labels
    }

tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)

3. 数据加载器

接下来,我们需要使用数据加载器将预处理后的数据传递给模型进行训练。我们使用PyTorch的数据加载器来创建训练集、验证集和测试集的加载器:

python 复制代码
from torch.utils.data import DataLoader

batch_size = 32
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_datasets['validation'], batch_size=batch_size)
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size)

4. Sentence-BERT模型

SBERT通过在预训练的BERT模型后加入一个池化层来获取句子的固定大小嵌入表示。以下是加载预训练BERT模型并实现池化操作的代码:

python 复制代码
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')

def mean_pool(token_embeds, attention_mask):
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool

5. 损失函数与优化器

我们采用交叉熵损失函数用于分类任务,并使用AdamW优化器来优化模型的权重。以下是损失函数与优化器的定义:

python 复制代码
import torch.optim as optim

classifier_head = torch.nn.Linear(768 * 3, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

6. 消融实验

在实验中,我们通过不同的模型配置来测试模型的表现,并进行消融实验。消融实验可以帮助我们了解每个组件对模型性能的贡献。

7. 模型训练

模型训练分为多个epoch,每个epoch都会遍历训练集。模型从BERT层提取句子嵌入,经过分类器头部计算预测值并计算损失。以下是训练代码:

python 复制代码
for epoch in range(num_epoch):
    model.train()
    classifier_head.train()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        u = model(batch['premise_input_ids'].to(device), attention_mask=batch['premise_attention_mask'].to(device))
        v = model(batch['hypothesis_input_ids'].to(device), attention_mask=batch['hypothesis_attention_mask'].to(device))
        
        loss = criterion(classifier_head(torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)), batch['labels'].to(device))
        
        loss.backward()
        optimizer.step()

8. 推理与相似度计算

训练完成后,我们可以使用训练好的SBERT模型来计算两个句子的相似度。以下代码展示了如何输入两个句子并计算它们的余弦相似度:

python 复制代码
sentence_a = '你做出的贡献帮助我们提供了优质的教育。'
sentence_b = '你的贡献对我们的学生教育毫无帮助。'

similarity = calculate_similarity(model, tokenizer, sentence_a, sentence_b, device)
print(f"余弦相似度: {similarity:.4f}")

结语

在本篇文章中,我们详细介绍了Sentence-BERT (SBERT) 的工作原理、数据预处理、模型训练以及句子相似度的计算。通过引入池化操作,SBERT能够生成固定大小的句子嵌入,从而有效计算句子间的语义相似度。我们还通过训练模型处理了SNLI和MNLI数据集,展示了如何结合这些数据进行模型优化与推理。

SBERT的优势在于它能够为各种自然语言处理任务提供语义丰富的句子嵌入表示,并且在分类和相似度计算任务中表现出色。未来,SBERT还可以应用于更加复杂的任务,如多语言理解、文本生成等领域,进一步提高模型的泛化能力和实际应用效果。

通过本次实验,相信你已经掌握了如何使用SBERT进行句子嵌入计算的核心步骤。在日常的NLP项目中,SBERT的这种高效处理方式能够大大提高文本相似度任务的处理效率。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关推荐
小众AI2 小时前
AI-on-the-edge-device - 将“旧”设备接入智能世界
人工智能·开源·ai编程
舟寒、2 小时前
【论文分享】Ultra-AV: 一个规范化自动驾驶汽车纵向轨迹数据集
人工智能·自动驾驶·汽车
梦云澜5 小时前
论文阅读(十二):全基因组关联研究中生物通路的图形建模
论文阅读·人工智能·深度学习
远洋录5 小时前
构建一个数据分析Agent:提升分析效率的实践
人工智能·ai·ai agent
IT古董6 小时前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师7 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)8 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui9 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20259 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥10 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技