火山引擎 doubao-embedding-vision-xx 模型的 URL 不是以 /embeddings 结尾，导致 llama-index-embeddings-openai-like 无法使用，因此需要重写一个 Embedding 类。
python
import asyncio
from typing import Any, List, Dict, Optional

import httpx
from llama_index.embeddings.openai import OpenAIEmbedding
class ExtinctVolcengineEmbedding(OpenAIEmbedding):
    """Volcengine (Ark) multimodal embedding client exposed as an ``OpenAIEmbedding``.

    Why this exists: the Ark multimodal endpoint URL does not end in
    ``/embeddings``, so OpenAI-compatible clients that append that suffix
    cannot reach it.

    Features:
    1. Subclasses ``OpenAIEmbedding`` so ``isinstance`` checks elsewhere pass.
    2. Overrides every embedding entry point (single/batch, sync/async) so
       requests go through this class's own httpx clients rather than the
       official OpenAI client.
    3. POSTs to exactly the ``api_base`` URL passed in; ``/embeddings`` is
       never appended.
    4. Sends each input as ``[{"type": "text", "text": ...}]``.
    5. Batch methods issue one request per text, because a response whose
       ``data`` is a single object carries only one embedding, however many
       inputs were sent.
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        api_base: str,
        **kwargs: Any,
    ):
        # Initialise the parent only so this object is a genuine OpenAIEmbedding
        # (isinstance checks, shared fields such as model_name / api_key).
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs,
        )
        # The exact URL to POST to; it is used verbatim.
        self._volc_endpoint = api_base
        # Dedicated HTTP clients; ``timeout`` (if given) is also forwarded to the parent above.
        timeout = kwargs.get("timeout", 60.0)
        self._client_sync = httpx.Client(timeout=timeout)
        self._client_async = httpx.AsyncClient(timeout=timeout)

    def _construct_payload(self, texts: List[str]) -> Dict[str, Any]:
        """Build the Volcengine request body, wrapping each text as a typed text part."""
        return {
            "model": self.model_name,
            "input": [{"type": "text", "text": text} for text in texts],
            "encoding_format": "float",
        }

    def _build_headers(self) -> Dict[str, str]:
        """Return the JSON + Bearer-token headers shared by sync and async requests."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _handle_response(self, response: httpx.Response, expected: int) -> List[List[float]]:
        """Validate an HTTP response and return its embeddings.

        Args:
            response: The raw response from the Volcengine endpoint.
            expected: How many embeddings the request should have produced.

        Returns:
            The embeddings, ordered as returned by ``_parse_response``.

        Raises:
            ValueError: On a non-200 status, or when the number of embeddings
                differs from ``expected`` (e.g. an empty or unrecognised body).
        """
        if response.status_code != 200:
            raise ValueError(f"VolcEngine Error ({response.status_code}): {response.text}")
        embeddings = self._parse_response(response.json())
        if len(embeddings) != expected:
            # Fail loudly instead of returning a list that no longer lines up
            # with the inputs (or raising a bare IndexError in the callers).
            raise ValueError(
                f"VolcEngine Error: expected {expected} embedding(s), "
                f"got {len(embeddings)}: {response.text[:500]}"
            )
        return embeddings

    def _do_request_sync(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` synchronously and return one embedding per text."""
        response = self._client_sync.post(
            self._volc_endpoint,
            headers=self._build_headers(),
            json=self._construct_payload(texts),
        )
        return self._handle_response(response, expected=len(texts))

    async def _do_request_async(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` asynchronously and return one embedding per text."""
        response = await self._client_async.post(
            self._volc_endpoint,
            headers=self._build_headers(),
            json=self._construct_payload(texts),
        )
        return self._handle_response(response, expected=len(texts))

    def _parse_response(self, data: Dict) -> List[List[float]]:
        """Extract embeddings from a response body.

        Accepts both shapes seen in ``data["data"]``:
        - a list of ``{"index": i, "embedding": [...]}`` items (sorted by index);
        - a single ``{"embedding": [...]}`` object (yields exactly one embedding).
        Any other shape yields an empty list.
        """
        raw_data = data.get("data", [])
        if isinstance(raw_data, list):
            sorted_data = sorted(raw_data, key=lambda x: x.get("index", 0))
            return [item["embedding"] for item in sorted_data]
        if isinstance(raw_data, dict) and "embedding" in raw_data:
            return [raw_data["embedding"]]
        return []

    # =================================================================
    # Overrides: every parent embedding entry point is replaced so no
    # request path falls through to the parent implementation.
    # =================================================================
    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query (sync)."""
        return self._do_request_sync([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single text (sync)."""
        return self._do_request_sync([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts (sync), one request per text.

        Sending the whole batch in one request is unsafe: a single-object
        ``data`` response holds only one embedding for all inputs.
        """
        return [self._get_text_embedding(text) for text in texts]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Embed a single query (async)."""
        result = await self._do_request_async([query])
        return result[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Embed a single text (async)."""
        result = await self._do_request_async([text])
        return result[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts (async), one request per text.

        Requests are awaited sequentially to keep ordering simple and avoid
        firing a whole batch of concurrent requests at the endpoint.
        """
        return [await self._aget_text_embedding(text) for text in texts]
调用示例:
python
async def main():
    """Smoke-test the Volcengine embedding class against the full multimodal endpoint URL."""
    # NOTE(review): ``settings`` is not defined in this snippet; presumably it is the
    # project's config object exposing ``embedding_model`` and ``volcengine_api_key``.
    # Confirm it is imported wherever this script runs.
    embed_model = ExtinctVolcengineEmbedding(
        model_name=settings.embedding_model,
        api_key=settings.volcengine_api_key,
        # Pass the complete endpoint URL; the class uses it verbatim.
        api_base="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal",
    )
    # Code elsewhere that type-checks the model still sees an OpenAIEmbedding:
    # isinstance(embed_model, OpenAIEmbedding)  # -> True
    result = embed_model.get_text_embedding("测试一下")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())