llamaindex 使用火山embedding模型

火山doubao-embedding-vision-xx模型url不是以:/embeddings结尾,导致llama-index-embeddings-openai-like用不了,需要从写一下。

python 复制代码
import httpx
from typing import Any, List, Dict, Optional
from llama_index.embeddings.openai import OpenAIEmbedding


class ExtinctVolcengineEmbedding(OpenAIEmbedding):
    """
    终极版火山引擎多模态 Embedding 类。

    特点:
    1. 继承 OpenAIEmbedding 以满足类型检查。
    2. 【全覆盖】重写了所有父类方法(单条/批量/同步/异步),彻底屏蔽 OpenAI 官方客户端。
    3. 强制使用用户传入的原始 URL,绝不自动拼接 /embeddings。
    4. 自动转换 Payload 为 [{"type": "text"}] 格式。
    """

    def __init__(
            self,
            model_name: str,
            api_key: str,
            api_base: str,
            **kwargs: Any,
    ):
        # 初始化父类(仅为了混个脸熟,满足 isinstance 检查)
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs
        )

        # 保存真正要请求的 URL
        self._volc_endpoint = api_base

        # 初始化自己的客户端
        timeout = kwargs.get("timeout", 60.0)
        self._client_sync = httpx.Client(timeout=timeout)
        self._client_async = httpx.AsyncClient(timeout=timeout)

    def _construct_payload(self, texts: List[str]) -> Dict:
        """构造火山专用 Payload"""
        formatted_input = []
        for text in texts:
            formatted_input.append({"type": "text", "text": text})

        return {
            "model": self.model_name,
            "input": formatted_input,
            "encoding_format": "float"
        }

    def _do_request_sync(self, texts: List[str]) -> List[List[float]]:
        """统一的同步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        # 打印调试信息,确认走的是这里
        # print(f"DEBUG: [VolcEngine] POST {self._volc_endpoint}")

        response = self._client_sync.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    async def _do_request_async(self, texts: List[str]) -> List[List[float]]:
        """统一的异步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        response = await self._client_async.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    def _parse_response(self, data: Dict) -> List[List[float]]:
        """统一解析返回结果"""
        raw_data = data.get("data", [])

        if isinstance(raw_data, list):
            sorted_data = sorted(raw_data, key=lambda x: x.get("index", 0))
            return [item["embedding"] for item in sorted_data]
        elif isinstance(raw_data, dict) and "embedding" in raw_data:
            return [raw_data["embedding"]]
        return []

    # =================================================================
    #  核心覆盖区:以下方法必须全部重写,防止漏网之鱼调用父类
    # =================================================================

    def _get_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (同步)"""
        return self._do_request_sync([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (同步)"""
        return self._do_request_sync([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (同步)"""
        return self._do_request_sync(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (异步)"""
        result = await self._do_request_async([query])
        return result[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (异步)"""
        result = await self._do_request_async([text])
        return result[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (异步)"""
        return await self._do_request_async(texts)

调用

python 复制代码
async def main():
    """Test VolcEngine embedding with correct endpoint URL."""
    embed_model = VolcEngineMultimodalEmbedding(
        model_name=settings.embedding_model,
        api_key=settings.volcengine_api_key,
        # 直接填完整的 URL,不需要为了凑路径而修改
        api_base="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
    )

    # 外部系统检查
    # isinstance(embed_model, OpenAIEmbedding)  # -> True

    # 调用
    result = embed_model.get_text_embedding("测试一下")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
相关推荐
前端摸鱼匠3 小时前
YOLOv8 环境配置全攻略:Python、PyTorch 与 CUDA 的和谐共生
人工智能·pytorch·python·yolo·目标检测
WangYaolove13143 小时前
基于python的在线水果销售系统(源码+文档)
python·mysql·django·毕业设计·源码
AALoveTouch3 小时前
大麦网协议分析
javascript·python
ZH15455891313 小时前
Flutter for OpenHarmony Python学习助手实战:自动化脚本开发的实现
python·学习·flutter
xcLeigh3 小时前
Python入门:Python3 requests模块全面学习教程
开发语言·python·学习·模块·python3·requests
xcLeigh3 小时前
Python入门:Python3 statistics模块全面学习教程
开发语言·python·学习·模块·python3·statistics
YongCheng_Liang4 小时前
从零开始学 Python:自动化 / 运维开发实战(核心库 + 3 大实战场景)
python·自动化·运维开发
鸽芷咕4 小时前
为什么越来越多开发者转向 CANN 仓库中的 Python 自动化方案?
python·microsoft·自动化·cann
秋邱4 小时前
用 Python 写出 C++ 的性能?用CANN中PyPTO 算子开发硬核上手指南
开发语言·c++·python
wazmlp0018873695 小时前
python第三次作业
开发语言·python