【Weaviate】使用递归排名融合RRF推荐酒店

本文实现了递归排名融合(RRF)算法,用于合并多个排序列表并生成一个新的排序结果。具体步骤如下:

初始化一个空字典 scores。

遍历每个排序列表 ranking,对每个文档计算其 RRF 分数并存储在 scores 中。

对所有文档的 RRF 分数进行排序,返回按分数从高到低排序的文档列表。

使用 RRF(Reciprocal Rank Fusion)算法 合并新旧搜索结果,可提升排序鲁棒性。

python 复制代码
import os
import re

from tqdm import tqdm
import weaviate
from weaviate import Client
from weaviate.util import generate_uuid5
from dotenv import load_dotenv
load_dotenv()

weaviate_url = os.environ.get("WEAVIATE_URL", "")
weaviate_key = os.environ.get("WEAVIATE_API_KEY", "")

def rrf(rankings, k=60):
    scores = dict()
    for ranking in rankings:
        for i, doc in enumerate(ranking):
            doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
            if doc_id not in scores:
                scores[doc_id] = (0, doc)
            scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
    sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
    return [item[1] for item in sorted_scores]


class HotelDB:
    def __init__(self, ip="localhost", port=8080):
        header = {"X-OpenAI-Api-Key": os.getenv("OPENAI_API_KEY")}

        url = weaviate_url
        auth_config = weaviate.AuthApiKey(api_key=weaviate_key)

        self.client = weaviate.Client(
            url=url,
            additional_headers=header,
            auth_client_secret=auth_config,
        )

        # try:
        #     url = os.getenv("WEAVIATE_URL")
        #     # url = f"http://{ip}:{port}"
        #     self.client = Client(
        #         url=url, additional_headers=header, timeout_config=(3, 10)
        #     )
        # except Exception:
        #     ip = "weaviate"
        #     url = f"http://{ip}:{port}"
        #     self.client = Client(
        #         url=url, additional_headers=header, timeout_config=(3, 10)
        #     )

    def create(self, name="Hotel"):
        schema = {
            "classes": [
                {
                    "class": name,
                    "description": "hotel meta data",
                    "properties": [
                        {
                            "dataType": ["number"],
                            "description": "id of hotel",
                            "name": "hotel_id",
                        },
                        {
                            "dataType": ["text"],
                            "description": "name of hotel",
                            "name": "_name",  # 分词过用于搜索的
                            "indexSearchable": True,
                            "tokenization": "whitespace",
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "name",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "type",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "address of hotel",
                            "name": "_address",  # 分词过用于搜索的
                            "indexSearchable": True,
                            "tokenization": "whitespace",
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "type of hotel",
                            "name": "address",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "nearby subway",
                            "name": "subway",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["text"],
                            "description": "phone of hotel",
                            "name": "phone",
                            "indexSearchable": False,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": True}
                            },
                        },
                        {
                            "dataType": ["number"],
                            "description": "price of hotel",
                            "name": "price",
                        },
                        {
                            "dataType": ["number"],
                            "description": "rating of hotel",
                            "name": "rating",
                        },
                        {
                            "dataType": ["text"],
                            "description": "facilities provided",
                            "name": "facilities",
                            "indexSearchable": True,
                            "moduleConfig": {
                                "text2vec-contextionary": {"skip": False}
                            },
                        },
                    ],
                    "vectorizer": "text2vec-openai",
                    "moduleConfig": {
                        "text2vec-openai": {
                            "vectorizeClassName": False,
                            "model": "ada",
                            "modelVersion": "002",
                            "type": "text",
                        },
                    },
                }
            ]
        }
        self.client.schema.create(schema)
        # 单class创建也可用client.schema.create_class(schema)

    def delete(self, name="Hotel"):
        self.client.schema.delete_class(name)

    def insert(self, data, name="Hotel", batch=4):
        self.client.batch.configure(batch_size=batch, dynamic=True)
        for item in tqdm(data):
            self.client.batch.add_data_object(
                data_object=item,
                class_name=name,
                uuid=generate_uuid5(item, name),
            )
        self.client.batch.flush()

    def search(
        self,
        dsl,
        name="Hotel",
        output_fields=["hotel_id", "name", "type", "rating", "price"],
        limit=10,
    ):
        candidates = []
        if not dsl:
            return []
        if "hotel_id" not in output_fields:  # rrf排序中使用hotel_id
            output_fields.append("hotel_id")
        # ===================== assemble filters ========================= #
        filters = []
        keys = [
            "type",
            "price.range.low",
            "price.range.high",
            "rating.range.low",
            "rating.range.hight",
        ]
        if any(key in dsl for key in keys):
            if "type" in dsl:
                filters.append(
                    {
                        "path": ["type"],
                        "operator": "Equal",
                        "valueString": dsl["type"],
                    }
                )
            if "price.range.low" in dsl:
                filters.append(
                    {
                        "path": ["price"],
                        "operator": "GreaterThan",
                        "valueNumber": dsl["price.range.low"],
                    }
                )
            if "price.range.high" in dsl:
                filters.append(
                    {
                        "path": ["price"],
                        "operator": "LessThan",
                        "valueNumber": dsl["price.range.high"],
                    }
                )
            if "rating.range.low" in dsl:
                filters.append(
                    {
                        "path": ["rating"],
                        "operator": "GreaterThan",
                        "valueNumber": dsl["rating.range.low"],
                    }
                )
            if "rating.range.high" in dsl:
                filters.append(
                    {
                        "path": ["rating"],
                        "operator": "LessThan",
                        "valueNumber": dsl["rating.range.high"],
                    }
                )
            # 补丁,过滤掉未给价格的-1值
            filters.append(
                {
                    "path": ["price"],
                    "operator": "GreaterThan",
                    "valueNumber": 0,
                }
            )
        if (len(filters)) == 1:
            filters = filters[0]
        elif len(filters) > 1:
            filters = {"operator": "And", "operands": filters}
        # ===================== vector search ============================= #
        if "facilities" in dsl:
            query = self.client.query.get(name, output_fields)
            query = query.with_near_text(
                {"concepts": [f'酒店提供:{dsl["facilities"]}']}
            )
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        # ===================== keyword search ============================ #
        if "name" in dsl:
            text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["name"]))
            query = self.client.query.get(name, output_fields)
            query = query.with_bm25(query=text, properties=["_name"])
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        if "address" in dsl:
            text = " ".join(re.findall(r"[\dA-Za-z\-]+|\w", dsl["address"]))
            query = self.client.query.get(name, output_fields)
            query = query.with_bm25(query=text, properties=["_address"])
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = rrf([candidates, result["data"]["Get"][name]])
        # ====================== condition search ========================== #
        if not candidates:
            query = self.client.query.get(name, output_fields)
            if filters:
                query = query.with_where(filters)
            query = query.with_limit(limit)
            result = query.do()
            candidates = result["data"]["Get"][name]
        # ========================== sort ================================= #
        if "sort.slot" in dsl:
            if dsl["sort.ordering"] == "descend":
                candidates = sorted(
                    candidates, key=lambda x: x[dsl["sort.slot"]], reverse=True
                )
            else:
                candidates = sorted(
                    candidates, key=lambda x: x[dsl["sort.slot"]]
                )
        return candidates

以下是这段代码的逐层解析及其核心作用说明:


一、代码功能概述

这是一个 Reciprocal Rank Fusion (RRF) 算法 的实现,用于对多个排序列表 (如搜索结果、推荐列表)进行融合重排。通过聚合不同来源的排序结果,生成更鲁棒的综合排序。


二、参数说明

参数 类型 作用
rankings List[List] 多个排序列表,例如:[[doc1, doc2], [doc3, doc1]](每个列表已排序)
k int 平滑系数,用于控制排名位置对分数的影响强度(默认60)

三、分步代码解析

1. 初始化分数字典
python 复制代码
scores = dict()  # 存储文档ID到(总分数,文档对象)的映射
2. 遍历每个排序列表
python 复制代码
for ranking in rankings:  # 遍历每个来源的排序结果(如不同推荐模型的结果)
    for i, doc in enumerate(ranking):  # 遍历当前列表中的每个文档及其位置i
        # 提取文档唯一标识(假设文档是字典且包含"hotel_id",否则直接使用文档本身)
        doc_id = doc["hotel_id"] if isinstance(doc, dict) else doc
        # 初始化分数或更新分数
        if doc_id not in scores:
            scores[doc_id] = (0, doc)  # 首次出现时初始化为(0, 文档对象)
        # RRF核心公式:累加 1/(k + 当前排名位置i)
        scores[doc_id] = (scores[doc_id][0] + 1 / (k + i), doc)
3. 按总分排序并返回文档
python 复制代码
sorted_scores = sorted(scores.values(), key=lambda x: x[0], reverse=True)
return [item[1] for item in sorted_scores]  # 按分数从高到低返回文档对象

四、核心算法逻辑

1. RRF 公式

每个文档在第 i 个位置 的得分贡献为:

[ \text{score} = \frac{1}{k + i} ]

  • k 的作用 :防止分母过小(当 i=0 时,k 避免除零错误),同时控制位置影响的衰减速度。
    • k 值越大,排名位置的影响越平缓。
    • k 值越小,排名靠前的文档得分差异越显著。
2. 分数聚合
  • 同一文档在不同列表中出现的次数越多、位置越靠前(i 越小),总分越高。
  • 例如:
    • 文档A在列表1的第1位(i=0)、列表2的第3位(i=2):
      [ \text{总分} = \frac{1}{60+0} + \frac{1}{60+2} = 0.0167 + 0.0161 = 0.0328 ]
    • 文档B仅在列表1的第2位(i=1):
      [ \text{总分} = \frac{1}{60+1} = 0.0164 ]
    • 结果:文档A排在文档B之前。

五、使用场景示例

1. 酒店推荐系统
  • 输入 :3个推荐模型的排序结果:

    python 复制代码
    model1_ranking = [{"hotel_id": 101}, {"hotel_id": 102}]
    model2_ranking = [{"hotel_id": 103}, {"hotel_id": 101}]
    rankings = [model1_ranking, model2_ranking]
  • 调用

    python 复制代码
    final_ranking = rrf(rankings, k=60)
  • 输出

    • 总分计算:
      • 101号酒店:1/(60+0) + 1/(60+1) ≈ 0.0167 + 0.0164 = 0.0331
      • 102号酒店:1/(60+1) = 0.0164
      • 103号酒店:1/(60+0) = 0.0167
    • 最终排序:[101, 103, 102]
2. 搜索引擎结果融合
  • 合并 Bing、Google 的搜索结果,提升结果相关性。

六、关键设计分析

设计点 优势 潜在问题
使用 k 平滑系数 避免排名靠前位置权重过高,增强鲁棒性 需根据数据调整 k(经验值60)
支持字典或纯ID格式 灵活适配不同数据结构 需确保所有文档有唯一标识(如 hotel_id
时间复杂度 O(N*M) 适合中小规模数据(N=文档数,M=列表数) 大数据量需优化(如分块处理)

七、与其他算法的对比

算法 核心思想 适用场景
RRF 倒数排名加权融合 多来源排序结果聚合
Borda Count 线性加权排名求和 简单快速,但对噪声敏感
MAP 平均精度(需标注数据) 有监督学习的评估场景

八、扩展优化建议

  1. 动态调整 k :根据数据分布自动优化(如高频冲突时增大 k)。
  2. 去重逻辑:同一列表内重复文档的处理(当前代码允许多次出现)。
  3. 并行计算:对大规模数据使用多线程/进程加速。

通过这段代码,可以有效整合多个排序来源,生成更鲁棒的综合排序结果。

相关推荐
碳基学AI40 分钟前
厦大团队|报告:《读懂大模型概念、技术与应用实践》140 页 PPT(文末附链接下载)
大数据·人工智能·深度学习·机器学习·ai·知识图谱
windyrain4 小时前
AI 学习之路(一)- 重新认识 AI
人工智能·机器学习·aigc
Luis Li 的猫猫5 小时前
机器学习:特征提取
人工智能·目标检测·机器学习·视觉检测
python算法(魔法师版)6 小时前
自动驾驶FSD技术的核心算法与软件实现
人工智能·深度学习·神经网络·算法·机器学习·自动驾驶
云天徽上6 小时前
【目标检测】目标检测中的数据增强终极指南:从原理到实战,用Python解锁模型性能提升密码(附YOLOv5实战代码)
人工智能·python·yolo·目标检测·机器学习·计算机视觉
IT古董6 小时前
【漫话机器学习系列】116.矩阵(Matrices)
人工智能·机器学习·矩阵
SomeB1oody7 小时前
【Python机器学习】1.1. 机器学习(Machine Learning)介绍
开发语言·人工智能·python·机器学习
atwdy7 小时前
【决策树】分类属性的选择
决策树·机器学习·cart
Jason_Orton7 小时前
决策树(Decision Tree):机器学习中的经典算法
人工智能·算法·决策树·随机森林·机器学习
紫雾凌寒7 小时前
计算机视觉|ConvNeXt:CNN 的复兴,Transformer 的新对手
人工智能·神经网络·机器学习·计算机视觉·transformer·动态网络·convnext