llamaindex 使用火山embedding模型

火山doubao-embedding-vision-xx模型url不是以:/embeddings结尾,导致llama-index-embeddings-openai-like用不了,需要从写一下。

python 复制代码
import httpx
from typing import Any, List, Dict, Optional
from llama_index.embeddings.openai import OpenAIEmbedding


class ExtinctVolcengineEmbedding(OpenAIEmbedding):
    """
    终极版火山引擎多模态 Embedding 类。

    特点:
    1. 继承 OpenAIEmbedding 以满足类型检查。
    2. 【全覆盖】重写了所有父类方法(单条/批量/同步/异步),彻底屏蔽 OpenAI 官方客户端。
    3. 强制使用用户传入的原始 URL,绝不自动拼接 /embeddings。
    4. 自动转换 Payload 为 [{"type": "text"}] 格式。
    """

    def __init__(
            self,
            model_name: str,
            api_key: str,
            api_base: str,
            **kwargs: Any,
    ):
        # 初始化父类(仅为了混个脸熟,满足 isinstance 检查)
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs
        )

        # 保存真正要请求的 URL
        self._volc_endpoint = api_base

        # 初始化自己的客户端
        timeout = kwargs.get("timeout", 60.0)
        self._client_sync = httpx.Client(timeout=timeout)
        self._client_async = httpx.AsyncClient(timeout=timeout)

    def _construct_payload(self, texts: List[str]) -> Dict:
        """构造火山专用 Payload"""
        formatted_input = []
        for text in texts:
            formatted_input.append({"type": "text", "text": text})

        return {
            "model": self.model_name,
            "input": formatted_input,
            "encoding_format": "float"
        }

    def _do_request_sync(self, texts: List[str]) -> List[List[float]]:
        """统一的同步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        # 打印调试信息,确认走的是这里
        # print(f"DEBUG: [VolcEngine] POST {self._volc_endpoint}")

        response = self._client_sync.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    async def _do_request_async(self, texts: List[str]) -> List[List[float]]:
        """统一的异步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        response = await self._client_async.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    def _parse_response(self, data: Dict) -> List[List[float]]:
        """统一解析返回结果"""
        raw_data = data.get("data", [])

        if isinstance(raw_data, list):
            sorted_data = sorted(raw_data, key=lambda x: x.get("index", 0))
            return [item["embedding"] for item in sorted_data]
        elif isinstance(raw_data, dict) and "embedding" in raw_data:
            return [raw_data["embedding"]]
        return []

    # =================================================================
    #  核心覆盖区:以下方法必须全部重写,防止漏网之鱼调用父类
    # =================================================================

    def _get_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (同步)"""
        return self._do_request_sync([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (同步)"""
        return self._do_request_sync([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (同步)"""
        return self._do_request_sync(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (异步)"""
        result = await self._do_request_async([query])
        return result[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (异步)"""
        result = await self._do_request_async([text])
        return result[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (异步)"""
        return await self._do_request_async(texts)

调用

python 复制代码
async def main():
    """Test VolcEngine embedding with correct endpoint URL."""
    embed_model = VolcEngineMultimodalEmbedding(
        model_name=settings.embedding_model,
        api_key=settings.volcengine_api_key,
        # 直接填完整的 URL,不需要为了凑路径而修改
        api_base="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
    )

    # 外部系统检查
    # isinstance(embed_model, OpenAIEmbedding)  # -> True

    # 调用
    result = embed_model.get_text_embedding("测试一下")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
相关推荐
老毛肚7 小时前
jeecg-boot-base-core 02 day
javascript·python
yaoxin5211237 小时前
434. Java 日期时间 API - Period 基于日期的时间段
java·开发语言·python
岁月宁静8 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
caimouse8 小时前
Reactos 第 8 章 结构化异常处理 — 8.2 系统空间的结构化异常处理
windows
JaydenAI8 小时前
[对比学习LangChain和MAF-07]如何引入人机交互的审批流程
python·ai·langchain·c#·agent·hitl·maf
caimouse9 小时前
Reactos 第 7 章 视窗报文 — 7.3 Win32k 的用户空间回调机制
windows
caimouse9 小时前
Reactos 第 9 章 设备驱动 — 9.5 一组PnP设备驱动模块的实例
网络·windows
神成19 小时前
vmware 上 win7 系统按照 vmware tool
windows
神奇元创9 小时前
商用级光路加速卡:大模型推理的极速落地方案
python·神经网络·fpga开发·dsp开发