llamaindex 使用火山embedding模型

火山doubao-embedding-vision-xx模型url不是以:/embeddings结尾,导致llama-index-embeddings-openai-like用不了,需要从写一下。

python 复制代码
import httpx
from typing import Any, List, Dict, Optional
from llama_index.embeddings.openai import OpenAIEmbedding


class ExtinctVolcengineEmbedding(OpenAIEmbedding):
    """
    终极版火山引擎多模态 Embedding 类。

    特点:
    1. 继承 OpenAIEmbedding 以满足类型检查。
    2. 【全覆盖】重写了所有父类方法(单条/批量/同步/异步),彻底屏蔽 OpenAI 官方客户端。
    3. 强制使用用户传入的原始 URL,绝不自动拼接 /embeddings。
    4. 自动转换 Payload 为 [{"type": "text"}] 格式。
    """

    def __init__(
            self,
            model_name: str,
            api_key: str,
            api_base: str,
            **kwargs: Any,
    ):
        # 初始化父类(仅为了混个脸熟,满足 isinstance 检查)
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs
        )

        # 保存真正要请求的 URL
        self._volc_endpoint = api_base

        # 初始化自己的客户端
        timeout = kwargs.get("timeout", 60.0)
        self._client_sync = httpx.Client(timeout=timeout)
        self._client_async = httpx.AsyncClient(timeout=timeout)

    def _construct_payload(self, texts: List[str]) -> Dict:
        """构造火山专用 Payload"""
        formatted_input = []
        for text in texts:
            formatted_input.append({"type": "text", "text": text})

        return {
            "model": self.model_name,
            "input": formatted_input,
            "encoding_format": "float"
        }

    def _do_request_sync(self, texts: List[str]) -> List[List[float]]:
        """统一的同步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        # 打印调试信息,确认走的是这里
        # print(f"DEBUG: [VolcEngine] POST {self._volc_endpoint}")

        response = self._client_sync.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    async def _do_request_async(self, texts: List[str]) -> List[List[float]]:
        """统一的异步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        response = await self._client_async.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    def _parse_response(self, data: Dict) -> List[List[float]]:
        """统一解析返回结果"""
        raw_data = data.get("data", [])

        if isinstance(raw_data, list):
            sorted_data = sorted(raw_data, key=lambda x: x.get("index", 0))
            return [item["embedding"] for item in sorted_data]
        elif isinstance(raw_data, dict) and "embedding" in raw_data:
            return [raw_data["embedding"]]
        return []

    # =================================================================
    #  核心覆盖区:以下方法必须全部重写,防止漏网之鱼调用父类
    # =================================================================

    def _get_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (同步)"""
        return self._do_request_sync([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (同步)"""
        return self._do_request_sync([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (同步)"""
        return self._do_request_sync(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (异步)"""
        result = await self._do_request_async([query])
        return result[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (异步)"""
        result = await self._do_request_async([text])
        return result[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (异步)"""
        return await self._do_request_async(texts)

调用

python 复制代码
async def main():
    """Test VolcEngine embedding with correct endpoint URL."""
    embed_model = VolcEngineMultimodalEmbedding(
        model_name=settings.embedding_model,
        api_key=settings.volcengine_api_key,
        # 直接填完整的 URL,不需要为了凑路径而修改
        api_base="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
    )

    # 外部系统检查
    # isinstance(embed_model, OpenAIEmbedding)  # -> True

    # 调用
    result = embed_model.get_text_embedding("测试一下")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
相关推荐
DevnullCoffe1 天前
Open Claw × 跨境电商:5个最有价值的 AI Agent 应用场景深度拆解
python·api
有点傻的小可爱1 天前
【MATLAB】新安装并口如何实现能通过PTB启用?
开发语言·windows·经验分享·matlab
zh路西法1 天前
【宇树机器人强化学习】(六):TensorBoard图表与手柄遥控go2测试
python·深度学习·机器学习·机器人
szcsun51 天前
关于在pycharm中新建项目创建虚拟化环境venv
ide·python·pycharm
码路飞1 天前
体验完阿里「悟空」之后,我花 2 小时用 Python 撸了个 AI Agent 🔥
python·aigc
万里沧海寄云帆1 天前
pytorch+cpu版本对Intel Ultra 9 275HX性能的影响
人工智能·pytorch·python
java资料站1 天前
python爬虫入门
python
Drone_xjw1 天前
【环境搭建】Windows 10上使用Docker搭建本地Git仓库(Gitea)完整教程
windows·git·docker
1941s1 天前
Google Agent Development Kit (ADK) 指南 第二章:环境搭建与快速开始
人工智能·python·adk·google agent
深蓝轨迹1 天前
彻底删除VMware虚拟机并清理残留,解决虚拟网卡消失问题
windows·ubuntu·centos