基于Python的自然语言处理系列(45):Sentence-BERT句子相似度计算

在本篇博文中,我们将探讨如何使用Sentence-BERT (SBERT) 来进行句子相似度计算。SBERT 不仅可以用于自然语言推理任务,还可以生成句子嵌入,从而计算两个句子之间的语义相似度。

1. 数据加载

我们首先加载两个数据集:SNLI (斯坦福自然语言推理数据集) 和 MNLI (多体裁自然语言推理数据集)。这些数据集提供了前提-假设句对,用于训练模型以处理推理任务。以下是数据加载的代码示例:

python 复制代码
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')

2. 数据预处理

为了将数据集转换为模型可以处理的格式,我们首先需要对句子进行分词和编码。使用BERT的分词器可以轻松地对句子进行分词,并生成输入ID和注意力掩码。以下代码展示了如何预处理数据:

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_function(examples):
    max_seq_length = 128
    premise_result = tokenizer(
        examples['premise'], padding='max_length', max_length=max_seq_length, truncation=True)
    hypothesis_result = tokenizer(
        examples['hypothesis'], padding='max_length', max_length=max_seq_length, truncation=True)
    
    labels = examples["label"]
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels": labels
    }

tokenized_datasets = raw_dataset.map(preprocess_function, batched=True)

3. 数据加载器

接下来,我们需要使用数据加载器将预处理后的数据传递给模型进行训练。我们使用PyTorch的数据加载器来创建训练集、验证集和测试集的加载器:

python 复制代码
from torch.utils.data import DataLoader

batch_size = 32
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_datasets['validation'], batch_size=batch_size)
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size)

4. Sentence-BERT模型

SBERT通过在预训练的BERT模型后加入一个池化层来获取句子的固定大小嵌入表示。以下是加载预训练BERT模型并实现池化操作的代码:

python 复制代码
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')

def mean_pool(token_embeds, attention_mask):
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool

5. 损失函数与优化器

我们采用交叉熵损失函数用于分类任务,并使用AdamW优化器来优化模型的权重。以下是损失函数与优化器的定义:

python 复制代码
import torch.optim as optim

classifier_head = torch.nn.Linear(768 * 3, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

6. 消融实验

在实验中,我们通过不同的模型配置来测试模型的表现,并进行消融实验。消融实验可以帮助我们了解每个组件对模型性能的贡献。

7. 模型训练

模型训练分为多个epoch,每个epoch都会遍历训练集。模型从BERT层提取句子嵌入,经过分类器头部计算预测值并计算损失。以下是训练代码:

python 复制代码
for epoch in range(num_epoch):
    model.train()
    classifier_head.train()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        u = model(batch['premise_input_ids'].to(device), attention_mask=batch['premise_attention_mask'].to(device))
        v = model(batch['hypothesis_input_ids'].to(device), attention_mask=batch['hypothesis_attention_mask'].to(device))
        
        loss = criterion(classifier_head(torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)), batch['labels'].to(device))
        
        loss.backward()
        optimizer.step()

8. 推理与相似度计算

训练完成后,我们可以使用训练好的SBERT模型来计算两个句子的相似度。以下代码展示了如何输入两个句子并计算它们的余弦相似度:

python 复制代码
sentence_a = '你做出的贡献帮助我们提供了优质的教育。'
sentence_b = '你的贡献对我们的学生教育毫无帮助。'

similarity = calculate_similarity(model, tokenizer, sentence_a, sentence_b, device)
print(f"余弦相似度: {similarity:.4f}")

结语

在本篇文章中,我们详细介绍了Sentence-BERT (SBERT) 的工作原理、数据预处理、模型训练以及句子相似度的计算。通过引入池化操作,SBERT能够生成固定大小的句子嵌入,从而有效计算句子间的语义相似度。我们还通过训练模型处理了SNLI和MNLI数据集,展示了如何结合这些数据进行模型优化与推理。

SBERT的优势在于它能够为各种自然语言处理任务提供语义丰富的句子嵌入表示,并且在分类和相似度计算任务中表现出色。未来,SBERT还可以应用于更加复杂的任务,如多语言理解、文本生成等领域,进一步提高模型的泛化能力和实际应用效果。

通过本次实验,相信你已经掌握了如何使用SBERT进行句子嵌入计算的核心步骤。在日常的NLP项目中,SBERT的这种高效处理方式能够大大提高文本相似度任务的处理效率。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关推荐
XinZong1 小时前
【AI开源项目】OneAPI -核心概念、特性、优缺点以及如何在本地和服务器上进行部署!
人工智能·开源
机器之心1 小时前
Runway CEO:AI公司的时代已经结束了
人工智能·后端
T0uken2 小时前
【机器学习】过拟合与欠拟合
人工智能·机器学习
即兴小索奇2 小时前
GPT-4V 是什么?
人工智能
机器学习之心3 小时前
GCN+BiLSTM多特征输入时间序列预测(Pytorch)
人工智能·pytorch·python·gcn+bilstm
码农-阿甘3 小时前
小牛视频翻译 ( 视频翻译 字幕翻译 字幕转语音 人声分离)
人工智能
黑龙江亿林等级保护测评3 小时前
等保行业如何选择核实的安全防御技术
网络·人工智能·python·安全·web安全·智能路由器·ddos
ai产品老杨3 小时前
深度学习模型量化原理
开发语言·人工智能·python·深度学习·安全·音视频
马甲是掉不了一点的<.<3 小时前
计算机视觉常用数据集Cityscapes的介绍、下载、转为YOLO格式进行训练
人工智能·yolo·目标检测·计算机视觉·计算机视觉数据集
weixin_eng020483 小时前
清仓和斩仓有什么不一样?
人工智能·金融·区块链