llamaindex 使用火山embedding模型

火山doubao-embedding-vision-xx模型url不是以:/embeddings结尾,导致llama-index-embeddings-openai-like用不了,需要从写一下。

python 复制代码
import httpx
from typing import Any, List, Dict, Optional
from llama_index.embeddings.openai import OpenAIEmbedding


class ExtinctVolcengineEmbedding(OpenAIEmbedding):
    """
    终极版火山引擎多模态 Embedding 类。

    特点:
    1. 继承 OpenAIEmbedding 以满足类型检查。
    2. 【全覆盖】重写了所有父类方法(单条/批量/同步/异步),彻底屏蔽 OpenAI 官方客户端。
    3. 强制使用用户传入的原始 URL,绝不自动拼接 /embeddings。
    4. 自动转换 Payload 为 [{"type": "text"}] 格式。
    """

    def __init__(
            self,
            model_name: str,
            api_key: str,
            api_base: str,
            **kwargs: Any,
    ):
        # 初始化父类(仅为了混个脸熟,满足 isinstance 检查)
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs
        )

        # 保存真正要请求的 URL
        self._volc_endpoint = api_base

        # 初始化自己的客户端
        timeout = kwargs.get("timeout", 60.0)
        self._client_sync = httpx.Client(timeout=timeout)
        self._client_async = httpx.AsyncClient(timeout=timeout)

    def _construct_payload(self, texts: List[str]) -> Dict:
        """构造火山专用 Payload"""
        formatted_input = []
        for text in texts:
            formatted_input.append({"type": "text", "text": text})

        return {
            "model": self.model_name,
            "input": formatted_input,
            "encoding_format": "float"
        }

    def _do_request_sync(self, texts: List[str]) -> List[List[float]]:
        """统一的同步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        # 打印调试信息,确认走的是这里
        # print(f"DEBUG: [VolcEngine] POST {self._volc_endpoint}")

        response = self._client_sync.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    async def _do_request_async(self, texts: List[str]) -> List[List[float]]:
        """统一的异步请求处理函数"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = self._construct_payload(texts)

        response = await self._client_async.post(
            self._volc_endpoint,
            headers=headers,
            json=payload
        )

        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")

        return self._parse_response(response.json())

    def _parse_response(self, data: Dict) -> List[List[float]]:
        """统一解析返回结果"""
        raw_data = data.get("data", [])

        if isinstance(raw_data, list):
            sorted_data = sorted(raw_data, key=lambda x: x.get("index", 0))
            return [item["embedding"] for item in sorted_data]
        elif isinstance(raw_data, dict) and "embedding" in raw_data:
            return [raw_data["embedding"]]
        return []

    # =================================================================
    #  核心覆盖区:以下方法必须全部重写,防止漏网之鱼调用父类
    # =================================================================

    def _get_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (同步)"""
        return self._do_request_sync([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (同步)"""
        return self._do_request_sync([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (同步)"""
        return self._do_request_sync(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """覆盖:单条 Query (异步)"""
        result = await self._do_request_async([query])
        return result[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """覆盖:单条 Text (异步)"""
        result = await self._do_request_async([text])
        return result[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """覆盖:批量 Text (异步)"""
        return await self._do_request_async(texts)

调用

python 复制代码
async def main():
    """Test VolcEngine embedding with correct endpoint URL."""
    embed_model = VolcEngineMultimodalEmbedding(
        model_name=settings.embedding_model,
        api_key=settings.volcengine_api_key,
        # 直接填完整的 URL,不需要为了凑路径而修改
        api_base="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
    )

    # 外部系统检查
    # isinstance(embed_model, OpenAIEmbedding)  # -> True

    # 调用
    result = embed_model.get_text_embedding("测试一下")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
相关推荐
X56613 小时前
如何在 Laravel 中正确保存嵌套动态表单数据(主服务与子服务)
jvm·数据库·python
ZhengEnCi3 小时前
03ab-PyTorch安装教程 📚
python
狐狐生风4 小时前
LangChain 向量存储:Chroma、FAISS
人工智能·python·学习·langchain·faiss·agentai
狐狐生风4 小时前
LangChain RAG 基础
人工智能·python·学习·langchain·rag·agentai
老前端的功夫4 小时前
【Java从入门到入土】28:Stream API:告别for循环的新时代
java·开发语言·python
蚰蜒螟5 小时前
深入 Linux 内核同步机制:从 futex 到 spinlock 的完整旅程
linux·windows·microsoft
yaoxin5211235 小时前
397. Java 文件操作基础 - 创建常规文件与临时文件
java·开发语言·python
dFObBIMmai5 小时前
MySQL主从同步中大事务导致的延迟_如何拆分大事务优化同步
jvm·数据库·python
szccyw05 小时前
mysql如何限制特定存储过程执行权限_MySQL存储过程安全访问
jvm·数据库·python
小白学大数据5 小时前
Python 自动化爬取网易云音乐歌手歌词实战教程
爬虫·python·okhttp·自动化