下面是使用 BERT 训练 pairwise ranking(成对排序)模型的示例代码:
python
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from sklearn.metrics import pairwise_distances_argmin_min
class PairwiseRankingDataset(Dataset):
    """Tokenised sentences for pairwise ranking, stored one row per index.

    Every element of ``sentence_pairs`` is handed to ``tokenizer`` unchanged
    and the returned ``input_ids`` / ``attention_mask`` tensors are
    concatenated along dimension 0, in input order.

    NOTE(review): with a Hugging Face tokenizer, a tuple of two strings is
    presumably encoded as a batch of two separate sentences, so each pair
    contributes two consecutive rows (first sentence, then second) -- confirm
    against the tokenizer in use.

    Args:
        sentence_pairs: Iterable of sentence pairs to tokenise.
        tokenizer: Callable accepting ``(pair, padding=..., truncation=...,
            max_length=..., return_tensors=...)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sequence is padded / truncated to.
    """

    def __init__(self, sentence_pairs, tokenizer, max_length):
        input_ids = []
        attention_masks = []
        for pair in sentence_pairs:
            encoded_pair = tokenizer(
                pair,
                padding='max_length',
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )
            input_ids.append(encoded_pair['input_ids'])
            attention_masks.append(encoded_pair['attention_mask'])

        if input_ids:
            self.input_ids = torch.cat(input_ids, dim=0)
            self.attention_masks = torch.cat(attention_masks, dim=0)
        else:
            # torch.cat rejects an empty list, so build explicit empty tensors
            # to let an empty dataset be constructed and report length 0.
            self.input_ids = torch.empty((0, max_length), dtype=torch.long)
            self.attention_masks = torch.empty((0, max_length), dtype=torch.long)

    def __len__(self):
        """Return the number of stored rows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` row at position ``idx``."""
        return self.input_ids[idx], self.attention_masks[idx]
class BERTPairwiseRankingModel(torch.nn.Module):
    """BERT encoder followed by dropout and a linear layer producing one score.

    Args:
        bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_model_name):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Score each input sequence.

        Args:
            input_ids: Token id tensor forwarded to the BERT encoder.
            attention_mask: Attention mask tensor forwarded to the BERT encoder.

        Returns:
            A 1-D tensor with one score per input row.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is used as the sentence
        # representation (named "pooled output" here).
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        # Drop only the trailing size-1 feature dimension. A bare squeeze()
        # would also remove the batch dimension when the batch holds a single
        # row, yielding a 0-dim tensor that cannot be sliced by callers.
        return logits.squeeze(-1)
# Load the tokenizer for the chosen BERT checkpoint.
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Example input data. In every pair the first sentence is treated as the
# positive example and the second as the negative example by the loss below.
sentence_pairs = [
    ('I like cats', 'I like dogs'),
    ('The sun is shining', 'It is raining'),
    ('Apple is a fruit', 'Car is a vehicle'),
]

# Hyperparameters.
batch_size = 8  # number of sentences (not pairs) per optimisation step
max_length = 128
learning_rate = 1e-5
num_epochs = 5

# Build the dataset and data loader.
dataset = PairwiseRankingDataset(sentence_pairs, tokenizer, max_length)

# The dataset stores sentences as interleaved rows (positive, negative,
# positive, negative, ...). Shuffling those rows individually would break the
# positive/negative pairing the loss relies on, so group the rows into
# (num_pairs, 2, max_length) and shuffle / batch whole pairs instead.
# The explicit num_pairs makes the reshape fail loudly if the dataset does not
# hold exactly two rows per pair.
num_pairs = len(sentence_pairs)
pair_dataset = torch.utils.data.TensorDataset(
    dataset.input_ids.reshape(num_pairs, 2, max_length),
    dataset.attention_masks.reshape(num_pairs, 2, max_length),
)
# batch_size counts sentences, so each batch holds batch_size // 2 pairs.
dataloader = DataLoader(pair_dataset, batch_size=max(1, batch_size // 2), shuffle=True)

# Create the model (its encoder loads the pretrained weights) and optimizer.
model = BERTPairwiseRankingModel(bert_model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Train the model.
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for pair_input_ids, pair_attention_masks in dataloader:
        optimizer.zero_grad()
        # Flatten (pairs, 2, max_length) -> (2 * pairs, max_length) so that
        # both sentences of every pair go through the model in one pass.
        input_ids = pair_input_ids.reshape(-1, max_length)
        attention_masks = pair_attention_masks.reshape(-1, max_length)
        # Regroup the scores per pair: column 0 = positive, column 1 = negative.
        logits = model(input_ids, attention_masks).reshape(-1, 2)
        pos_scores = logits[:, 0]
        neg_scores = logits[:, 1]
        # Margin ranking (hinge) loss with margin 1: penalise pairs whose
        # positive score does not exceed the negative score by at least 1.
        loss = torch.relu(1 - pos_scores + neg_scores).mean()
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    # total_loss is the sum of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")
# Inference: for every sentence, find the closest *other* sentence in the
# model's (fine-tuned) word-embedding space.
model.eval()
with torch.no_grad():
    # Token-level embeddings from the embedding table.
    # assumes dataset.input_ids is (num_sentences, seq_len) -- TODO confirm
    token_embeddings = model.bert.embeddings.word_embeddings(dataset.input_ids)
    # Mean-pool over tokens, using the attention mask so padding positions do
    # not contribute, to get one 2-D vector per sentence.
    token_mask = dataset.attention_masks.unsqueeze(-1).to(token_embeddings.dtype)
    sentence_embeddings = (token_embeddings * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1)
    # All-pairs Euclidean distances; exclude each sentence's distance to
    # itself (always 0) so the minimum is the nearest other sentence.
    distance_matrix = torch.cdist(sentence_embeddings, sentence_embeddings)
    distance_matrix.fill_diagonal_(float('inf'))
    nearest_distances, nearest_indices = distance_matrix.min(dim=1)
# Index 0 holds nearest-neighbour indices, index 1 the matching distances.
pairwise_distances = (nearest_indices.tolist(), nearest_distances.tolist())

# Print the results; rows 2*i and 2*i + 1 are the two sentences of pair i.
for i, pair in enumerate(sentence_pairs):
    pos_idx = pairwise_distances[0][2 * i]
    neg_idx = pairwise_distances[0][2 * i + 1]
    pos_dist = pairwise_distances[1][2 * i]
    neg_dist = pairwise_distances[1][2 * i + 1]
    print(f"Pair: {pair}")
    print(f"Positive example index: {pos_idx}, Distance: {pos_dist:.4f}")
    print(f"Negative example index: {neg_idx}, Distance: {neg_dist:.4f}")
    print()