【Weaviate】使用递归排名融合RRF推荐酒店

本文实现了递归排名融合(RRF)算法,用于合并多个排序列表并生成一个新的排序结果。具体步骤如下:

初始化一个空字典 scores。

遍历每个排序列表 ranking,对每个文档计算其 RRF 分数并存储在 scores 中。

对所有文档的 RRF 分数进行排序,返回按分数从高到低排序的文档列表。

使用 RRF(Reciprocal Rank Fusion)算法 合并新旧搜索结果,可提升排序鲁棒性。

python 复制代码
import os
import re

from tqdm import tqdm
import weaviate
from weaviate import Client
from weaviate.util import generate_uuid5
from dotenv import load_dotenv
load_dotenv()

weaviate_url = os.environ.get("WEAVIATE_URL", "")
weaviate_key = os.environ.get("WEAVIATE_API_KEY", "")

def rrf(rankings, k=60):
    scores = dict()
    for ranking in rankings:
        for i, doc in enumerate(ranking):
            doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
            if doc_id not in scores:
                scores[doc_id] = (0, doc)
            scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
    sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
    return [item[1] for item in sorted_scores]


class HotelDB:
    def __init__(self, ip="localhost", port=8080):
        header = {"X-OpenAI-Api-Key": os.getenv("OPENAI_API_KEY")}

        url = weaviate_url
        auth_config = weaviate.AuthApiKey(api_key=weaviate_key)

        self.client = weaviate.Client(
            url=url,
            additional_headers=header,
            auth_client_secret=auth_config,
        )

        # try:
        #     url = os.getenv("WEAVIATE_URL")
        #     # url = f"http://{ip}:{port}"
        #     self.client = Client(
        #         url=url, additional_headers=header, timeout_config=(3, 10)
        #     )
        # except Exception:
        #     ip = "weaviate"
        #     url = f"http://{ip}:{port}"
        #     self.client = Client(
        #         url=url, additional_headers=header, timeout_config=(3, 10)
        #     )

    def create(self, name="Hotel"):
        schema = {
            "classes": [
                {
                    "class": name,
                    "description": "hotel meta data",
                    "properties": [
                        {
                            "dataType": ["number"],
                            "description": "id of hotel",
                            "name": "hotel_id",
                        },
                        {
                            "dataType": ["text"],
                            "description": "name of hotel",
                            "name": "_name",  # 分词过用于搜索的
                            "indexSearchable": True,
                            "tokenization": "whitespace",
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "name",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "type",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "address of hotel",
                            "name": "_address",  # 分词过用于搜索的
                            "indexSearchable": True,
                            "tokenization": "whitespace",
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "address",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "nearby subway",
                            "name": "subway",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "phone of hotel",
                            "name": "phone",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["number"],
                            "description": "price of hotel",
                            "name": "price",
                        },
                        {
                            "dataType": ["number"],
                            "description": "rating of hotel",
                            "name": "rating",
                        },
                        {
                            "dataType": ["text"],
                            "description": "facilities provided",
                            "name": "facilities",
                            "indexSearchable": True,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": False}
                            },
                        },
                    ],
                    "vectorizer": "text2vec-openai",
                    "moduleConfig": {
                        "text2vec-openai": {
                            "vectorizeClassName": False,
                            "model": "ada",
                            "modelVersion": "002",
                            "type": "text",
                        },
                    },
                }
            ]
        }
        self.client.schema.create(schema)
        # 单class创建也可用client.schema.create_class(schema)

    def delete(self, name="Hotel"):
        self.client.schema.delete_class(name)

    def insert(self, data, name="Hotel", batch=4):
        self.client.batch.configure(batch_size=batch, dynamic=True)
        for item in tqdm(data):
            self.client.batch.add_data_object(
                data_object=item,
                class_name=name,
                uuid=generate_uuid5(item, name),
            )
        self.client.batch.flush()

    def search(
        self,
        dsl,
        name="Hotel",
        output_fields=["hotel_id", "name", "type", "rating", "price"],
        limit=10,
    ):
        candidates = []
        if not dsl:
            return []
        if "hotel_id" not in output_fields:  # rrf排序中使用hotel_id
            output_fields.append("hotel_id")
        # ===================== assemble filters ========================= #
        filters = []
        keys = [
            "type",
            "price.range.low",
            "price.range.high",
            "rating.range.low",
            "rating.range.hight",
        ]
        if any(key in dsl for key in keys):
            if "type" in dsl:
                filters.append(
                    {
                        "path": ["type"],
                        "operator": "Equal",
                        "valueString": dsl["type"],
                    }
                )
            if "price.range.low" in dsl:
                filters.append(
                    {
                        "path": ["price"],
                        "operator": "GreaterThan",
                        "valueNumber": dsl["price.range.low"],
                    }
                )
            if "price.range.high" in dsl:
                filters.append(
                    {
                        "path": ["price"],
                        "operator": "LessThan",
                        "valueNumber": dsl["price.range.high"],
                    }
                )
            if "rating.range.low" in dsl:
                filters.append(
                    {
                        "path": ["rating"],
                        "operator": "GreaterThan",
                        "valueNumber": dsl["rating.range.low"],
                    }
                )
            if "rating.range.high" in dsl:
                filters.append(
                    {
                        "path": ["rating"],
                        "operator": "LessThan",
                        "valueNumber": dsl["rating.range.high"],
                    }
                )
            # 补丁,过滤掉未给价格的-1值
            filters.append(
                {
                    "path": ["price"],
                    "operator": "GreaterThan",
                    "valueNumber": 0,
                }
            )
        if (len(filters)) == 1:
            filters = filters[0]
        elif len(filters) > 1:
            filters = {"operator": "And", "operands": filters}
        # ===================== vector search ============================= #
        if "facilities" in dsl:
            query = self.client.query.get(name, output_fields)
            query = query.with_near_text(
                {"concepts": [f'酒店提供:{dsl["facilities"]}']}
            )
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        # ===================== keyword search ============================ #
        if "name" in dsl:
            text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["name"]))
            query = self.client.query.get(name, output_fields)
            query = query.with_bm25(query=text, properties=["_name"])
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        if "address" in dsl:
            text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["address"]))
            query = self.client.query.get(name, output_fields)
            query = query.with_bm25(query=text, properties=["_address"])
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        # ====================== condition search ========================== #
        if not candidates:
            query = self.client.query.get(name, output_fields)
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = result["data"]["Get"][name]
        # ========================== sort ================================= #
        if "sort.slot" in dsl:
            if dsl["sort.ordering"] == "descend":
                candidates = sorted(
                    candidates, key=lambda x: x[dsl["sort.slot"]], reverse=True
                )
            else:
                candidates = sorted(
                    candidates, key=lambda x: x[dsl["sort.slot"]]
                )
        return candidates

以下是这段代码的逐层解析及其核心作用说明:


一、代码功能概述

这是一个 Reciprocal Rank Fusion (RRF) 算法 的实现,用于对多个排序列表 (如搜索结果、推荐列表)进行融合重排。通过聚合不同来源的排序结果,生成更鲁棒的综合排序。


二、参数说明

参数 类型 作用
rankings List[List] 多个排序列表,例如:[[doc1, doc2], [doc3, doc1]](每个列表已排序)
k int 平滑系数,用于控制排名位置对分数的影响强度(默认60)

三、分步代码解析

1. 初始化分数字典
python 复制代码
scores = dict()  # 存储文档ID到(总分数,文档对象)的映射
2. 遍历每个排序列表
python 复制代码
for ranking in rankings:  # 遍历每个来源的排序结果(如不同推荐模型的结果)
    for i, doc in enumerate(ranking):  # 遍历当前列表中的每个文档及其位置i
        # 提取文档唯一标识(假设文档是字典且包含"hotel_id",否则直接使用文档本身)
        doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
        # 初始化分数或更新分数
        if doc_id not in scores:
            scores[doc_id] = (0, doc)  # 首次出现时初始化为(0, 文档对象)
        # RRF核心公式:累加 1/(k + 当前排名位置i)
        scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
3. 按总分排序并返回文档
python 复制代码
sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
return [item[1] for item in sorted_scores]  # 按分数从高到低返回文档对象

四、核心算法逻辑

1. RRF 公式

每个文档在第 i 个位置 的得分贡献为:

\\text{score} = \\frac{1}{k + i}

  • k 的作用 :防止分母过小(当 i=0 时,k 避免除零错误),同时控制位置影响的衰减速度。
    • k 值越大,排名位置的影响越平缓。
    • k 值越小,排名靠前的文档得分差异越显著。
2. 分数聚合
  • 同一文档在不同列表中出现的次数越多、位置越靠前(i 越小),总分越高。
  • 例如:
    • 文档A在列表1的第1位(i=0)、列表2的第3位(i=2):

      \\text{总分} = \\frac{1}{60+0} + \\frac{1}{60+2} = 0.0167 + 0.0161 = 0.0328

    • 文档B仅在列表1的第2位(i=1):

      \\text{总分} = \\frac{1}{60+1} = 0.0164

    • 结果:文档A排在文档B之前。

五、使用场景示例

1. 酒店推荐系统
  • 输入 :3个推荐模型的排序结果:

    python 复制代码
    model1_ranking = [{"hotel_id": 101}, {"hotel_id": 102}]
    model2_ranking = [{"hotel_id": 103}, {"hotel_id": 101}]
    rankings = [model1_ranking, model2_ranking]
  • 调用

    python 复制代码
    final_ranking = rrf(rankings, k=60)
  • 输出

    • 总分计算:
      • 101号酒店:1/(60+0) + 1/(60+1) ≈ 0.0167 + 0.0164 = 0.0331
      • 102号酒店:1/(60+1) = 0.0164
      • 103号酒店:1/(60+0) = 0.0167
    • 最终排序:[101, 103, 102]
2. 搜索引擎结果融合
  • 合并 Bing、Google 的搜索结果,提升结果相关性。

六、关键设计分析

设计点 优势 潜在问题
使用 k 平滑系数 避免排名靠前位置权重过高,增强鲁棒性 需根据数据调整 k(经验值60)
支持字典或纯ID格式 灵活适配不同数据结构 需确保所有文档有唯一标识(如 hotel_id
时间复杂度 O(N*M) 适合中小规模数据(N=文档数,M=列表数) 大数据量需优化(如分块处理)

七、与其他算法的对比

算法 核心思想 适用场景
RRF 倒数排名加权融合 多来源排序结果聚合
Borda Count 线性加权排名求和 简单快速,但对噪声敏感
MAP 平均精度(需标注数据) 有监督学习的评估场景

八、扩展优化建议

  1. 动态调整 k :根据数据分布自动优化(如高频冲突时增大 k)。
  2. 去重逻辑:同一列表内重复文档的处理(当前代码允许多次出现)。
  3. 并行计算:对大规模数据使用多线程/进程加速。

通过这段代码,可以有效整合多个排序来源,生成更鲁棒的综合排序结果。

相关推荐
暴龙胡乱写博客2 小时前
机器学习 --- 数据集
人工智能·机器学习
SunsPlanter2 小时前
快速入门机器学习的专有名词
人工智能·机器学习
鸿蒙布道师5 小时前
宇树科技安全漏洞揭示智能机器人行业隐忧
运维·网络·科技·安全·机器学习·计算机视觉·机器人
陈苏同学5 小时前
MPC控制器从入门到进阶(小车动态避障变道仿真 - Python)
人工智能·python·机器学习·数学建模·机器人·自动驾驶
yzx9910136 小时前
支持向量机案例
算法·机器学习·支持向量机
IT古董15 小时前
【漫话机器学习系列】249.Word2Vec自然语言训练模型
机器学习·自然语言处理·word2vec
白光白光16 小时前
大语言模型训练的两个阶段
人工智能·机器学习·语言模型
BioRunYiXue16 小时前
一文了解氨基酸的分类、代谢和应用
人工智能·深度学习·算法·机器学习·分类·数据挖掘·代谢组学
IT古董17 小时前
【漫话机器学习系列】255.独立同分布(Independent and Identically Distributed,简称 IID)
人工智能·机器学习
fytianlan17 小时前
机器学习 day6 -线性回归练习
人工智能·机器学习·线性回归