用 TripletLoss 优化bert ranking

下面是 用 TripletLoss 优化bert ranking 的demo

python 复制代码
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from sklearn.metrics.pairwise import pairwise_distances

class TripletRankingDataset(Dataset):
    def __init__(self, queries, positive_docs, negative_docs, tokenizer, max_length):
        self.input_ids_q = []
        self.attention_masks_q = []
        self.input_ids_p = []
        self.attention_masks_p = []
        self.input_ids_n = []
        self.attention_masks_n = []
        
        for query, pos_doc, neg_doc in zip(queries, positive_docs, negative_docs):
            encoded_query = tokenizer.encode_plus(query, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
            encoded_pos_doc = tokenizer.encode_plus(pos_doc, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
            encoded_neg_doc = tokenizer.encode_plus(neg_doc, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
            
            self.input_ids_q.append(encoded_query['input_ids'])
            self.attention_masks_q.append(encoded_query['attention_mask'])
            self.input_ids_p.append(encoded_pos_doc['input_ids'])
            self.attention_masks_p.append(encoded_pos_doc['attention_mask'])
            self.input_ids_n.append(encoded_neg_doc['input_ids'])
            self.attention_masks_n.append(encoded_neg_doc['attention_mask'])
        
        self.input_ids_q = torch.cat(self.input_ids_q, dim=0)
        self.attention_masks_q = torch.cat(self.attention_masks_q, dim=0)
        self.input_ids_p = torch.cat(self.input_ids_p, dim=0)
        self.attention_masks_p = torch.cat(self.attention_masks_p, dim=0)
        self.input_ids_n = torch.cat(self.input_ids_n, dim=0)
        self.attention_masks_n = torch.cat(self.attention_masks_n, dim=0)
        
    def __len__(self):
        return len(self.input_ids_q)
    
    def __getitem__(self, idx):
        input_ids_q = self.input_ids_q[idx]
        attention_mask_q = self.attention_masks_q[idx]
        input_ids_p = self.input_ids_p[idx]
        attention_mask_p = self.attention_masks_p[idx]
        input_ids_n = self.input_ids_n[idx]
        attention_mask_n = self.attention_masks_n[idx]
        return input_ids_q, attention_mask_q, input_ids_p, attention_mask_p, input_ids_n, attention_mask_n

class BERTTripletRankingModel(torch.nn.Module):
    def __init__(self, bert_model_name, hidden_size):
        super(BERTTripletRankingModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(hidden_size, 1)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        return logits.squeeze()

def triplet_loss(anchor, positive, negative, margin):
    distance_positive = torch.nn.functional.pairwise_distance(anchor, positive)
    distance_negative = torch.nn.functional.pairwise_distance(anchor, negative)
    losses = torch.relu(distance_positive - distance_negative + margin)
    return torch.mean(losses)

# 初始化BERT模型和分词器
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# 示例输入数据
queries = ['I like cats', 'The sun is shining']
positive_docs = ['I like dogs', 'The weather is beautiful']
negative_docs = ['Snakes are dangerous', 'It is raining']

# 超参数
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5
margin = 1.0

# 创建数据集和数据加载器
dataset = TripletRankingDataset(queries, positive_docs, negative_docs, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型并加载预训练权重
model = BERTTripletRankingModel(bert_model_name, hidden_size=model.bert.config.hidden_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# 训练模型
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    
    for input_ids_q, attention_masks_q, input_ids_p, attention_masks_p, input_ids_n, attention_masks_n in dataloader:
        optimizer.zero_grad()
        
        embeddings_q = model(inputids_q, attention_masks_q)
        embeddings_p = model(input_ids_p, attention_masks_p)
        embeddings_n = model(input_ids_n, attention_masks_n)
        
        loss = triplet_loss(embeddings_q, embeddings_p, embeddings_n, margin)
        
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

# 推断模型
model.eval()

with torch.no_grad():
    embeddings = model.bert.embeddings.word_embeddings(dataset.input_ids_q)
    pairwise_distances = pairwise_distances(embeddings.numpy())

# 输出结果
for i, query in enumerate(queries):
    print(f"Query: {query}")
    print("Documents:")
    
    for j, doc in enumerate(positive_docs):
        doc_idx = pairwise_distances[0][i * len(positive_docs) + j]
        doc_dist = pairwise_distances[1][i * len(positive_docs) + j]
        
        print(f"Document index: {doc_idx}, Distance: {doc_dist:.4f}")
        print(f"Document: {doc}")
        print("")

    print("---------")
相关推荐
Amo Xiang15 分钟前
2024 Python3.10 系统入门+进阶(十五):文件及目录操作
开发语言·python
liangbm324 分钟前
数学建模笔记——动态规划
笔记·python·算法·数学建模·动态规划·背包问题·优化问题
B站计算机毕业设计超人36 分钟前
计算机毕业设计Python+Flask微博情感分析 微博舆情预测 微博爬虫 微博大数据 舆情分析系统 大数据毕业设计 NLP文本分类 机器学习 深度学习 AI
爬虫·python·深度学习·算法·机器学习·自然语言处理·数据可视化
羊小猪~~40 分钟前
深度学习基础案例5--VGG16人脸识别(体验学习的痛苦与乐趣)
人工智能·python·深度学习·学习·算法·机器学习·cnn
waterHBO3 小时前
python 爬虫 selenium 笔记
爬虫·python·selenium
编程零零七4 小时前
Python数据分析工具(三):pymssql的用法
开发语言·前端·数据库·python·oracle·数据分析·pymssql
AI大模型知识分享4 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
AIAdvocate6 小时前
Pandas_数据结构详解
数据结构·python·pandas
小言从不摸鱼6 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
C-SDN花园GGbond6 小时前
【探索数据结构与算法】插入排序:原理、实现与分析(图文详解)
c语言·开发语言·数据结构·排序算法