本文实现了递归排名融合(RRF)算法,用于合并多个排序列表并生成一个新的排序结果。具体步骤如下:
初始化一个空字典 scores。
遍历每个排序列表 ranking,对每个文档计算其 RRF 分数并存储在 scores 中。
对所有文档的 RRF 分数进行排序,返回按分数从高到低排序的文档列表。
使用 RRF(Reciprocal Rank Fusion)算法 合并新旧搜索结果,可提升排序鲁棒性。
python
import os
import re
from tqdm import tqdm
import weaviate
from weaviate import Client
from weaviate.util import generate_uuid5
from dotenv import load_dotenv
load_dotenv()
weaviate_url = os.environ.get("WEAVIATE_URL", "")
weaviate_key = os.environ.get("WEAVIATE_API_KEY", "")
def rrf(rankings, k=60):
scores = dict()
for ranking in rankings:
for i, doc in enumerate(ranking):
doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
if doc_id not in scores:
scores[doc_id] = (0, doc)
scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
return [item[1] for item in sorted_scores]
class HotelDB:
def __init__(self, ip="localhost", port=8080):
header = {"X-OpenAI-Api-Key": os.getenv("OPENAI_API_KEY")}
url = weaviate_url
auth_config = weaviate.AuthApiKey(api_key=weaviate_key)
self.client = weaviate.Client(
url=url,
additional_headers=header,
auth_client_secret=auth_config,
)
# try:
# url = os.getenv("WEAVIATE_URL")
# # url = f"http://{ip}:{port}"
# self.client = Client(
# url=url, additional_headers=header, timeout_config=(3, 10)
# )
# except Exception:
# ip = "weaviate"
# url = f"http://{ip}:{port}"
# self.client = Client(
# url=url, additional_headers=header, timeout_config=(3, 10)
# )
def create(self, name="Hotel"):
schema = {
"classes": [
{
"class": name,
"description": "hotel meta data",
"properties": [
{
"dataType": ["number"],
"description": "id of hotel",
"name": "hotel_id",
},
{
"dataType": ["text"],
"description": "name of hotel",
"name": "_name", # 分词过用于搜索的
"indexSearchable": True,
"tokenization": "whitespace",
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "type of hotel",
"name": "name",
"indexSearchable": False,
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "type of hotel",
"name": "type",
"indexSearchable": False,
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "address of hotel",
"name": "_address", # 分词过用于搜索的
"indexSearchable": True,
"tokenization": "whitespace",
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "type of hotel",
"name": "address",
"indexSearchable": False,
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "nearby subway",
"name": "subway",
"indexSearchable": False,
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["text"],
"description": "phone of hotel",
"name": "phone",
"indexSearchable": False,
"moduleConfig": {
"text2vec-contextionary": {"skip": True}
},
},
{
"dataType": ["number"],
"description": "price of hotel",
"name": "price",
},
{
"dataType": ["number"],
"description": "rating of hotel",
"name": "rating",
},
{
"dataType": ["text"],
"description": "facilities provided",
"name": "facilities",
"indexSearchable": True,
"moduleConfig": {
"text2vec-contextionary": {"skip": False}
},
},
],
"vectorizer": "text2vec-openai",
"moduleConfig": {
"text2vec-openai": {
"vectorizeClassName": False,
"model": "ada",
"modelVersion": "002",
"type": "text",
},
},
}
]
}
self.client.schema.create(schema)
# 单class创建也可用client.schema.create_class(schema)
def delete(self, name="Hotel"):
self.client.schema.delete_class(name)
def insert(self, data, name="Hotel", batch=4):
self.client.batch.configure(batch_size=batch, dynamic=True)
for item in tqdm(data):
self.client.batch.add_data_object(
data_object=item,
class_name=name,
uuid=generate_uuid5(item, name),
)
self.client.batch.flush()
def search(
self,
dsl,
name="Hotel",
output_fields=["hotel_id", "name", "type", "rating", "price"],
limit=10,
):
candidates = []
if not dsl:
return []
if "hotel_id" not in output_fields: # rrf排序中使用hotel_id
output_fields.append("hotel_id")
# ===================== assemble filters ========================= #
filters = []
keys = [
"type",
"price.range.low",
"price.range.high",
"rating.range.low",
"rating.range.hight",
]
if any(key in dsl for key in keys):
if "type" in dsl:
filters.append(
{
"path": ["type"],
"operator": "Equal",
"valueString": dsl["type"],
}
)
if "price.range.low" in dsl:
filters.append(
{
"path": ["price"],
"operator": "GreaterThan",
"valueNumber": dsl["price.range.low"],
}
)
if "price.range.high" in dsl:
filters.append(
{
"path": ["price"],
"operator": "LessThan",
"valueNumber": dsl["price.range.high"],
}
)
if "rating.range.low" in dsl:
filters.append(
{
"path": ["rating"],
"operator": "GreaterThan",
"valueNumber": dsl["rating.range.low"],
}
)
if "rating.range.high" in dsl:
filters.append(
{
"path": ["rating"],
"operator": "LessThan",
"valueNumber": dsl["rating.range.high"],
}
)
# 补丁,过滤掉未给价格的-1值
filters.append(
{
"path": ["price"],
"operator": "GreaterThan",
"valueNumber": 0,
}
)
if (len(filters)) == 1:
filters = filters[0]
elif len(filters) > 1:
filters = {"operator": "And", "operands": filters}
# ===================== vector search ============================= #
if "facilities" in dsl:
query = self.client.query.get(name, output_fields)
query = query.with_near_text(
{"concepts": [f'酒店提供:{dsl["facilities"]}']}
)
if filters:
query = query.with_where(filters)
query = query.with_limit(limit)
result = query.do()
candidates = rrf([candidates, result["data"]["Get"][name]])
# ===================== keyword search ============================ #
if "name" in dsl:
text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["name"]))
query = self.client.query.get(name, output_fields)
query = query.with_bm25(query=text, properties=["_name"])
if filters:
query = query.with_where(filters)
query = query.with_limit(limit)
result = query.do()
candidates = rrf([candidates, result["data"]["Get"][name]])
if "address" in dsl:
text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["address"]))
query = self.client.query.get(name, output_fields)
query = query.with_bm25(query=text, properties=["_address"])
if filters:
query = query.with_where(filters)
query = query.with_limit(limit)
result = query.do()
candidates = rrf([candidates, result["data"]["Get"][name]])
# ====================== condition search ========================== #
if not candidates:
query = self.client.query.get(name, output_fields)
if filters:
query = query.with_where(filters)
query = query.with_limit(limit)
result = query.do()
candidates = result["data"]["Get"][name]
# ========================== sort ================================= #
if "sort.slot" in dsl:
if dsl["sort.ordering"] == "descend":
candidates = sorted(
candidates, key=lambda x: x[dsl["sort.slot"]], reverse=True
)
else:
candidates = sorted(
candidates, key=lambda x: x[dsl["sort.slot"]]
)
return candidates
以下是这段代码的逐层解析及其核心作用说明:
一、代码功能概述
这是一个 Reciprocal Rank Fusion (RRF) 算法 的实现,用于对多个排序列表 (如搜索结果、推荐列表)进行融合重排。通过聚合不同来源的排序结果,生成更鲁棒的综合排序。
二、参数说明
参数 | 类型 | 作用 |
---|---|---|
rankings |
List[List] |
多个排序列表,例如:[[doc1, doc2], [doc3, doc1]] (每个列表已排序) |
k |
int |
平滑系数,用于控制排名位置对分数的影响强度(默认60) |
三、分步代码解析
1. 初始化分数字典
python
scores = dict() # 存储文档ID到(总分数,文档对象)的映射
2. 遍历每个排序列表
python
for ranking in rankings: # 遍历每个来源的排序结果(如不同推荐模型的结果)
for i, doc in enumerate(ranking): # 遍历当前列表中的每个文档及其位置i
# 提取文档唯一标识(假设文档是字典且包含"hotel_id",否则直接使用文档本身)
doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
# 初始化分数或更新分数
if doc_id not in scores:
scores[doc_id] = (0, doc) # 首次出现时初始化为(0, 文档对象)
# RRF核心公式:累加 1/(k + 当前排名位置i)
scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
3. 按总分排序并返回文档
python
sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
return [item[1] for item in sorted_scores] # 按分数从高到低返回文档对象
四、核心算法逻辑
1. RRF 公式
每个文档在第 i 个位置 的得分贡献为:
[ \text{score} = \frac{1}{k + i} ]
- k 的作用 :防止分母过小(当 i=0 时,k 避免除零错误),同时控制位置影响的衰减速度。
- k 值越大,排名位置的影响越平缓。
- k 值越小,排名靠前的文档得分差异越显著。
2. 分数聚合
- 同一文档在不同列表中出现的次数越多、位置越靠前(i 越小),总分越高。
- 例如:
- 文档A在列表1的第1位(i=0)、列表2的第3位(i=2):
[ \text{总分} = \frac{1}{60+0} + \frac{1}{60+2} = 0.0167 + 0.0161 = 0.0328 ] - 文档B仅在列表1的第2位(i=1):
[ \text{总分} = \frac{1}{60+1} = 0.0164 ] - 结果:文档A排在文档B之前。
- 文档A在列表1的第1位(i=0)、列表2的第3位(i=2):
五、使用场景示例
1. 酒店推荐系统
-
输入 :3个推荐模型的排序结果:
pythonmodel1_ranking = [{"hotel_id": 101}, {"hotel_id": 102}] model2_ranking = [{"hotel_id": 103}, {"hotel_id": 101}] rankings = [model1_ranking, model2_ranking]
-
调用 :
pythonfinal_ranking = rrf(rankings, k=60)
-
输出 :
- 总分计算:
- 101号酒店:1/(60+0) + 1/(60+1) ≈ 0.0167 + 0.0164 = 0.0331
- 102号酒店:1/(60+1) = 0.0164
- 103号酒店:1/(60+0) = 0.0167
- 最终排序:
[101, 103, 102]
- 总分计算:
2. 搜索引擎结果融合
- 合并 Bing、Google 的搜索结果,提升结果相关性。
六、关键设计分析
设计点 | 优势 | 潜在问题 |
---|---|---|
使用 k 平滑系数 |
避免排名靠前位置权重过高,增强鲁棒性 | 需根据数据调整 k (经验值60) |
支持字典或纯ID格式 | 灵活适配不同数据结构 | 需确保所有文档有唯一标识(如 hotel_id ) |
时间复杂度 O(N*M) | 适合中小规模数据(N=文档数,M=列表数) | 大数据量需优化(如分块处理) |
七、与其他算法的对比
算法 | 核心思想 | 适用场景 |
---|---|---|
RRF | 倒数排名加权融合 | 多来源排序结果聚合 |
Borda Count | 线性加权排名求和 | 简单快速,但对噪声敏感 |
MAP | 平均精度(需标注数据) | 有监督学习的评估场景 |
八、扩展优化建议
- 动态调整
k
值 :根据数据分布自动优化(如高频冲突时增大k
)。 - 去重逻辑:同一列表内重复文档的处理(当前代码允许多次出现)。
- 并行计算:对大规模数据使用多线程/进程加速。
通过这段代码,可以有效整合多个排序来源,生成更鲁棒的综合排序结果。