Embedding Gemma,谷歌发布的小而精向量模型,仅需0.3B|附RAG实战代码

1. 简介

EmbeddingGemma是Google发布的开源小规模多语言文本嵌入模型,旨在常见设备(如手机、笔记本、台式机)上高效运行,同时在 MTEB / MMTEB 等评测任务中保持与同类模型竞争的性能。其核心价值在于:支持离线运行、保护隐私、内存占用小、兼容多种推理框架,并提供灵活的输出维度。

2. 关键特性

模型规模与上下文长度:参数量约为 308M,支持最长 2K tokens 的上下文窗口,默认输出向量维度为 768。

多语言支持:训练语料涵盖 100 多种语言。

MRL(嵌套表示学习):支持将 768 维的嵌入向量截断为 512/256/128 等更短维度,兼顾效率与效果。

工程优化:支持量化感知训练(QAT)与设备端优化,量化后模型大小可控制在 200MB 以内。

3. 架构与训练要点

基于 Gemma 3 的编码器风格骨干网络:EmbeddingGemma 使用 Gemma 3 作为主干网络,并将注意力机制改为双向,相比传统的解码器结构在大规模检索/嵌入任务中表现更优。

池化与映射结构:模型先输出 token 级别的向量,再通过均值池化聚合成文本向量,最后经由两层全连接层映射到 768 维的嵌入空间。

训练数据与清洗:模型基于大规模多语言语料训练,训练过程中实施了严格的质量过滤与安全过滤。

4. MRL与Prompt 模板

MRL原理和效果

EmbeddingGemma 在训练阶段引入了 MRL机制,使得推理时可选择不同的嵌入维度(768/512/256/128),而性能下降控制在可接受范围内。其原理类似"套娃"结构,高维嵌入向量的前一部分包含"核心信息",后一部分为"补充信息"。因此既可使用完整 768 维向量,也可仅使用前 512/256/128 维,而不导致效果显著下降。

在资源受限环境中,可通过牺牲少量精度换取更低的存储与计算开销。而在分层检索系统中,可先用低维嵌入进行粗筛,再使用高维嵌入对候选结果进行精排。

支持针对特定任务的Prompt

EmbeddingGemma 在训练时使用了一系列 prompt_name(例如 query、document、STS 等)。在实际推理或测试中,建议沿用这些prompt以获得最佳效果。

若使用sentence-transformers库,其内置的 encode_query/encode_document方法会自动添加相应 prompt。若直接使用 Transformers 或其他底层框架,需在输入文本前手动拼接对应的prompt。

支持的 Prompt 模板(根据不同任务使用可提升效果):

bash 复制代码
query: "task: search result | query: "

document: "title: none | text: "

BitextMining: "task: search result | query: "

Clustering: "task: clustering | query: "

Classification: "task: classification | query: "

InstructionRetrieval: "task: code retrieval | query: "

MultilabelClassification: "task: classification | query: "

PairClassification: "task: sentence similarity | query: "

Reranking: "task: search result | query: "

Retrieval-query: "task: search result | query: "

Retrieval-document: "title: none | text: "

STS: "task: sentence similarity | query: "

Summarization: "task: summarization | query: "

5. 实战:一个简单的RAG系统示例

1 环境依赖

bash 复制代码
pip install sentence-transformers==2.2.2 transformers huggingface-hub
pip install faiss-cpu  # 若使用 GPU 可安装 faiss-gpu

2加载生成语言模型

javascript 复制代码
# 加载 Gemma 3 生成模型
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="google/gemma-3-4b-it",
    device_map="auto",
    dtype="auto"
)

3 加载嵌入模型

javascript 复制代码
import torch
from sentence_transformers import SentenceTransformer

# 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"运行设备: {model.device}")
print(model)
print("模型总参数量:", sum([p.numel() for _, p in model.named_parameters()]))

4 使用 Prompt 模板

在 RAG 系统中,为查询和文档使用不同的 prompt 模板可提升效果:

  • 查询编码:使用prompt_name="Retrieval-query"
javascript 复制代码
query_embedding = model.encode(
    "How do I use prompts with this model?",
    prompt_name="Retrieval-query"
)
  • 文档编码:使用prompt_name="Retrieval-document"

为进一步优化,可加入标题信息:

  • 带标题:
javascript 复制代码
doc_embedding = model.encode(
    "文档正文内容...",
    prompt_name="Retrieval-document",
    prompt="title: 在 RAG 中使用 Prompt | text: "
)
  • 无标题
javascript 复制代码
doc_embedding = model.encode(
    "文档正文内容...",
    prompt_name="Retrieval-document",
    prompt="title: none | text: "
)

5 构建示例知识库

javascript 复制代码
# 公司知识库示例
corp_knowledge_base = [
    {
        "category": "人力资源与请假政策",
        "documents": [
            {
                "title": "非计划缺勤流程",
                "content": "如因生病或紧急情况无法工作,请于东京时间上午9:30前通过邮件通知直接上级和人力资源部门,邮件主题请注明"病假 - [你的姓名]"。若连续缺勤超过两天,返岗时需提供医生证明(診断書)。"
            },
            {
                "title": "年假政策",
                "content": "全职员工首年可享受10天带薪年假。年假自入职满六个月后生效,并随服务年限增加。例如,服务满三年的员工每年可享受14天年假。详细 accrual 表请参阅附件《年假累积表》。"
            },
        ]
    },
    {
        "category": "IT 与安全",
        "documents": [
            {
                "title": "账户密码管理",
                "content": "如忘记密码或账户被锁定,请使用自助重置门户:https://reset.ourcompany。系统将提示您回答预设的安全问题。出于安全考虑,IT帮助台无法通过电话或邮件重置密码。若未设置安全问题,请携带员工ID卡至涩谷办公室12楼IT支持台处理。"
            },
            {
                "title": "软件采购流程",
                "content": "所有新软件申请须通过'IT服务台'门户中的'软件申请'类别提交,需包含业务理由。所有软件许可须经部门负责人批准后方可采购。请注意,标准办公软件已预批准,无需走此流程。"
            },
        ]
    },
    {
        "category": "财务与报销",
        "documents": [
            {
                "title": "费用报销政策",
                "content": "为确保及时处理,当月所有费用报销须在次月第5个工作日前提交审批。例如,7月产生的所有费用须在8月第5个工作日前提交。逾期提交的费用可能延至下一付款周期处理。"
            },
            {
                "title": "差旅费用指南",
                "content": "差旅费用原则上按最合理、最经济路线的实际成本报销。使用新干线或飞机前请提前提交差旅费用申请。仅当公共交通不可用或运输重型设备时允许乘坐出租车。必须保留收据。"
            },
        ]
    },
    {
        "category": "办公与设施",
        "documents": [
            {
                "title": "会议室预订指南",
                "content": "涩谷办公室所有会议室可通过日历应用预订。创建新会议邀请,添加参会人后,使用'会议室查找'功能选择可用房间。请确保选择正确的楼层。10人以上会议请预订14楼的'樱'或'富士'房间。"
            },
            {
                "title": "邮件与快递政策",
                "content": "公司邮件服务仅用于业务相关信函。出于安全与责任考虑,请员工避免将私人包裹或邮件寄至涩谷办公室地址。前台无法代收或保管私人快递。"
            },
        ]
    },
]

6 设置辅助函数并检索

javascript 复制代码
question = "如何重置密码?"  # 可替换为其他问题
similarity_threshold = 0.4  # 相似度阈值,低于此值认为不匹配语义搜索核心逻辑:
javascript 复制代码
# --- 语义搜索相关辅助函数 ---

def _calculate_best_match(similarities):
    print("相似度分数:", similarities)
    if similarities is None or similarities.nelement() == 0:
        return None, 0.0

    # 找出最高分及其索引
    best_index = similarities.argmax().item()
    best_score = similarities[0, best_index].item()

    return best_index, best_score

def find_best_category(model, query, candidates):
    """
    从候选类别中找出与查询最相关的类别。

    参数:
        model: SentenceTransformer 模型
        query: 用户查询字符串
        candidates: 类别名称列表

    返回:
        最佳类别索引和相似度分数
    """
    if not candidates:
        return None, 0.0

    # 对查询和候选类别进行编码(使用分类任务prompt)
    query_embedding = model.encode(query, prompt_name="Classification")
    candidate_embeddings = model.encode(candidates, prompt_name="Classification")

    print("候选类别:", candidates)
    return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))

def find_best_doc(model, query, candidates):
    """
    从候选文档中找出与查询最相关的文档。

    参数:
        model: SentenceTransformer 模型
        query: 用户查询字符串
        candidates: 文档列表,每个文档应包含 'title' 和 'content'

    返回:
        最佳文档索引和相似度分数
    """
    if not candidates:
        return None, 0.0

    # 对查询进行编码(使用检索任务prompt)
    query_embedding = model.encode(query, prompt_name="Retrieval-query")

    # 构建文档文本并编码
    doc_texts = [
        f"title: {doc.get('title', 'none')} | text: {doc.get('content', '')}"
        for doc in candidates
    ]
    candidate_embeddings = model.encode(doc_texts)

    print("候选文档:", [doc['title'] for doc in candidates])

    # 计算余弦相似度
    return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))

# --- 主搜索逻辑 ---

best_document = None  # 初始化最佳文档

# 1. 寻找最相关类别
print("步骤1: 寻找最相关类别...")
categories = [item["category"] for item in corp_knowledge_base]
best_category_index, category_score = find_best_category(model, question, categories)

# 检查类别匹配分数是否超过阈值
if category_score < similarity_threshold:
    print(f" `-> 🤷 未找到相关类别。最高分仅为 {category_score:.2f}。")
else:
    best_category = corp_knowledge_base[best_category_index]
    print(f" `-> ✅ 找到类别: '{best_category['category']}' (分数: {category_score:.2f})")

    # 2. 仅当找到合适类别后,在该类别中寻找最相关文档
    print("\n步骤2: 在该类别中寻找最相关文档...")
    best_document_index, document_score = find_best_doc(
        model, question, best_category["documents"]
    )

    # 检查文档匹配分数是否超过阈值
    if document_score < similarity_threshold:
        print(f" `-> 🤷 未找到相关文档。最高分仅为 {document_score:.2f}。")
    else:
        best_document = best_category["documents"][best_document_index]
        print(f" `-> ✅ 找到文档: '{best_document['title']}' (分数: {document_score:.2f})")

7 基于检索结果生成答案

python 复制代码
qa_prompt_template = """请仅根据以下CONTEXT回答问题。如果CONTEXT中未包含答案,请回答"我不知道"。

---
CONTEXT:
{context}
---
QUESTION:
{question}
"""

if best_document and "content" in best_document:
    # 如果找到有效文档,生成答案
    context = best_document["content"]
    prompt = qa_prompt_template.format(context=context, question=question)

    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        },
    ]

    print("问题🙋‍♂️: " + question)
    # 调用生成模型获取答案
    answer = pipeline(messages, max_new_tokens=256, disable_compile=True)[0]["generated_text"][1]["content"]
    print("使用文档: " + best_document["title"])
    print("回答🤖: " + answer)
else:
    # 未找到相关文档时的回复
    print("问题🙋‍♂️: " + question)
    print("回答🤖: 抱歉,未找到相关文档来回答该问题。")

6. 一些工程经验

统一使用 Prompt:模型训练时使用了特定 prompt 名称和内容,因此生产环境中应保持使用这些prompt以保证效果一致性。

维度选择权衡:MRL 支持降维,在对吞吐量或存储敏感的场景中,可优先尝试512或256维,而非直接降至 128 维。

量化与精度平衡:该模型主打轻量级使用,量化可显著减少内存占用,但应根据精度需求和应用场景谨慎选择量化策略(如QAT 或 PTQ)。

相关推荐
安思派Anspire5 小时前
Google 新 LLM 仅需 0.5GB 内存即可运行——如何在本地对其进行微调
aigc·openai·agent
奇舞精选5 小时前
别让 AI 代码变成技术负债:Vibe Coding 提效实践
openai
mortimer6 小时前
一次 ModelScope 替代 Hugging Face 的模型下载实战指南
人工智能·llm
六月的可乐7 小时前
Vue3项目中集成AI对话功能的实战经验分享
前端·人工智能·openai
Baihai_IDP7 小时前
2025 年大语言模型架构演进:DeepSeek V3、OLMo 2、Gemma 3 与 Mistral 3.1 核心技术剖析
人工智能·llm·aigc
程序员爱钓鱼8 小时前
Go语言实战案例 — 工具开发篇:Go 实现二维码生成器
后端·google·go
新智元19 小时前
世界首富换人!81 岁硅谷狂人 4000 亿身价碾压马斯克,33 岁华裔才女逆袭
人工智能·openai
机器之心19 小时前
交互扩展时代来临:创智复旦字节重磅发布AgentGym-RL,昇腾加持,开创智能体训练新范式
人工智能·openai
AI大模型19 小时前
无所不能的Embedding(06) - 跨入Transformer时代~模型详解&代码实现
程序员·llm·agent