这周在做 Agent 项目的 RAG 模块,有了很多感悟,有感而发写了这篇博客
一、什么是RAG
RAG(Retrieval-Augmented Generation,检索增强生成)是一种将外部知识库检索与大语言模型(LLM)文本生成能力相结合的技术。简单来说,它的核心思想是:在AI生成回答前,先从外部数据源中检索与问题相关的上下文,再将这些参考资料作为输入传递给模型进行回答。
这就像是一场"开卷考试":遇到不会的问题时,AI会先"翻书"查阅相关资料,然后再根据查到的内容给出准确的答案。
二、为什么需要 RAG?
大语言模型虽然具备强大的文本生成能力,但在实际应用中存在几个明显的痛点,而 RAG 恰好能解决这些问题:
- 知识时效性不足:大模型的训练数据有截止日期,无法获取最新信息(如实时新闻、最新产品文档)。
- 事实准确性偏差(幻觉):模型有时会一本正经地胡说八道,编造不符合事实的内容。
- 缺乏私有知识:通用模型无法了解企业内部的机密文档或特定业务逻辑。
- 数据安全与成本:企业无需将敏感的私有数据上传用于重新训练模型,降低了数据泄露风险和高昂的持续训练成本。
三、RAG 的核心工作流程
一个典型的 RAG 系统架构通常分为三个关键阶段:
1. 数据处理阶段(Data Pipeline)
这是构建知识库的过程。系统会将原始的结构化或非结构化数据(如 PDF、Word、网页等)进行加载和清洗;接着通过"文本分块(Chunking)"将长文档拆分为短文本块;最后使用嵌入模型(Embedding Model)将这些文本块转换为高维向量,并存储在向量数据库中。
2. 检索阶段(Retrieval)
当用户提出问题时,系统会先将用户的查询转化为向量,然后在向量数据库中进行相似度搜索,找出与问题最相关的几个文本片段。为了提升准确性,现代 RAG 还会结合关键词过滤、查询重写等混合检索策略。
3. 生成阶段(Generation)
系统将检索到的相关文本片段作为上下文(Context),连同用户的原始问题一起拼接成提示词(Prompt),最终交给大语言模型(LLM)。模型基于这些提供的参考资料生成准确、连贯且可溯源的最终回答。
四、RAG 的演进:从标准到智能体驱动
随着技术的发展,RAG 也在不断进化:
- Naive RAG(基础 RAG):采用最基础的"切块 → 搜索 → 生成"固定流程,问一句查一下。
- Agentic RAG(智能体 RAG):赋予了 AI 自主决策的能力。它不再死板地执行固定流程,而是由 AI Agent 自己判断"是否需要查资料"、"应该去哪个知识库或网页查"以及"查到的结果是否足够"。如果信息不够,它还会主动更换数据源继续检索,表现得更加灵活和聪明。
五、典型应用场景
- 企业知识库问答:从内部规章制度、FAQ 中精准检索答案,避免 AI 瞎编乱造。
- 智能客服与虚拟助手:结合最新的业务文档和历史工单,提供准确且个性化的客户服务响应。
- 实时新闻分析与内容推荐:结合互联网最新数据进行评论生成或个性化内容推荐。
企业知识库问答
在企业知识库问答场景中,为了确保AI能够基于企业内部数据给出准确、合规且可追溯的回答,其检索流程通常比通用问答更为严谨。一个典型的企业级 RAG(检索增强生成)系统流程主要包含以下关键步骤:
1. 文档解析与智能分块
企业知识通常以多种格式存在(如PDF、Word、Markdown等)。系统首先需要读取文件内容并保留关键的元数据(如文档名称、页码、章节层级等),这为后续的答案溯源打下基础。随后,系统会将长文档切分为较小的文本块(Chunking)。合理的切分策略至关重要:
按结构切分:根据标题、段落、表格或代码块进行分块,避免破坏上下文逻辑。
参数设置:通常会设定合适的分块大小(如300-800字符)和重叠区域(如50-100字符),以保证语义的连贯性。
2. 向量化与索引存储
文本被切分后,系统会使用嵌入模型(Embedding Model)将每个文本片段转换成高维向量(一串数字)。这些向量连同原文内容、来源信息以及权限元数据等,会被统一存入向量数据库中,以便计算机能够计算"哪段文字和用户的问题最相似"。
3. 混合检索与重排序
当员工提出问题时,系统会先将问题也转化为向量,然后进入核心的检索环节。为了兼顾语义理解和精确匹配,企业级系统常采用以下机制:
混合检索:结合向量检索(擅长理解自然语言语义)和关键词检索(如BM25,擅长匹配精确的编号、代码函数名或合同条款),综合评估后召回初步候选结果。
权限过滤:在检索阶段同步进行权限校验,确保用户只能获取其有权限访问的内容,防止企业机密泄露。
重排序(Rerank):由于初步召回的结果可能包含大量无关信息,系统会使用重排序模型对候选片段进行更精细的相关性判断,筛选出最相关的几个片段(例如从50个中选出最匹配的5个),从而显著提升回答的准确性。
4. 上下文注入与生成溯源
经过严格筛选的高质量文本片段会被注入到提示词(Prompt)中,连同用户的原始问题一起交由大语言模型(LLM)处理。最终生成的回答不仅要求准确,还必须附带段落溯源信息。这意味着系统需要明确指出答案来源于哪份文档的第几页或哪个章节,并提供原文片段供用户核实。这种"给证据"而非仅仅"给答案"的机制,是企业知识库建立信任度、满足合规审计要求的核心所在。
六、文件结构
RAG是一个整的流程,有关代码尽量放在一起,如下:
internal/rag/
├── service.go # 入库服务(IngestionService)
├── retriever.go # 检索服务(RAGRetriever)
├── md_parser.go # Markdown AST 解析器
├── parent_transformer.go # ParentBlock 构建器
├── child_transformer.go # ChildChunk 分割器
├── semantic_transformer.go # 语义增强器
├── milvus_indexer.go # Milvus 读写操作
└── embedder_factory.go # Embedder 工厂
七、RAG模块详解
RAG分为两个阶段:
- 离线阶段:将文档存入向量数据库
- 在线阶段:从向量数据库中检索有关内容
数据入库流程
原始 Markdown 文本 │ ▼ ① MarkdownParser.Parse() AST 解析,按标题分章节 │ ▼ ② ParentTransformer.Transform() 章节 → ParentBlock(≤1000 tokens) │ ▼ ③ ChildTransformer.Transform() ParentBlock → ChildChunk(≤400 tokens) │ ▼ ④ SemanticTransformer.Transform() 为 ChildChunk 注入章节路径上下文 │ ▼ ⑤ WrapDocuments() 添加 source_id 元数据 │ ▼ ⑥ Embedder.EmbedStrings() 批量向量化(每批 256 条,带重试) │ ▼ ⑦ GenerateSparseVector() 生成 BM25 稀疏向量 │ ▼ ⑧ MilvusWriter.StoreWithSparse() 写入 Milvus(dense + sparse) │ ▼ ⑨ ParentBlockRepo.BatchCreate() 写入 MySQL(用于 Parent Recovery) │ ▼ ⑩ 更新 source 状态为 "ready"
入库相关代码
md_parser.go (Markdown AST 解析器)
Go
package rag
import (
"context"
"io"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/text"
)
// ASTDocument 解析后的文档结构
type ASTDocument struct {
Sections []*Section
}
// Section 文档章节
type Section struct {
Heading string
Level int // 1=H1, 2=H2, ...
ChapterPath string // 完整路径,如 "第一章 > 1.1 背景"
Content []ContentBlock
}
// ContentBlock 内容块
type ContentBlock struct {
Type string // paragraph/code/table/image/mermaid/quote
Content string
Language string // 代码块语言
}
// MarkdownParser Markdown AST 解析器,实现 eino parser.Parser 接口
type MarkdownParser struct {
md goldmark.Markdown
}
// NewMarkdownParser 创建 Markdown 解析器
func NewMarkdownParser() *MarkdownParser {
md := goldmark.New(
goldmark.WithExtensions(
extension.Table, // 启用表格扩展
),
)
return &MarkdownParser{
md: md,
}
}
// Parse 实现 eino parser.Parser 接口。
// 将 Markdown io.Reader 解析为 []*schema.Document。
// 每个章节生成一个包含结构信息的文档。
func (p *MarkdownParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
content, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
doc := p.md.Parser().Parse(text.NewReader(content))
astDoc := p.buildASTDocument(doc, content)
var docs []*schema.Document
for i, section := range astDoc.Sections {
d := &schema.Document{
ID: "",
Content: section.buildContent(),
MetaData: map[string]any{
"heading": section.Heading,
"level": section.Level,
"chapter_path": section.ChapterPath,
"section_idx": i,
},
}
docs = append(docs, d)
}
return docs, nil
}
// 解析 markdown 语法树 doc,生成自定义的 ASTDocument 结构
func (p *MarkdownParser) buildASTDocument(doc ast.Node, source []byte) *ASTDocument {
result := &ASTDocument{} // 结果容器
var currentSection *Section // 当前正在处理的章节
var headingStack []string // 记录标题层级路径
// 遍历 doc 的所有子节点,广度优先
for n := doc.FirstChild(); n != nil; n = n.NextSibling() {
switch n.Kind() {
case ast.KindHeading: // 标题节点
heading := n.(*ast.Heading) // 类型断言为具体 Heading 结构
headingText := string(n.Text(source)) // 提取标题文本
level := heading.Level // 标题等级(1-6)
// 若新标题等级 ≤ 当前栈长度,弹出栈顶元素直到长度匹配
if level <= len(headingStack) {
headingStack = headingStack[:level-1]
}
// 将当前标题文本加入路径栈
headingStack = append(headingStack, headingText)
// 创建新章节
currentSection = &Section{
Heading: headingText,
Level: level,
ChapterPath: joinPath(headingStack), // 用 "." 拼接路径栈
}
result.Sections = append(result.Sections, currentSection)
default: // 非标题节点(段落、列表等)
if currentSection == nil { // 如果没有标题开头的章节
currentSection = &Section{ // 创建一个无标题章节
Heading: "",
Level: 0,
ChapterPath: "",
}
result.Sections = append(result.Sections, currentSection)
}
block := p.extractContentBlock(n, source) // 提取内容块
if block != nil {
currentSection.Content = append(currentSection.Content, *block)
}
}
}
return result
}
// 将 Markdown AST 节点转换为自定义的 ContentBlock
func (p *MarkdownParser) extractContentBlock(n ast.Node, source []byte) *ContentBlock {
switch n.Kind() {
case ast.KindParagraph: // 段落
return &ContentBlock{Type: "paragraph", Content: string(n.Text(source))}
case ast.KindFencedCodeBlock: // 代码块
lang := ""
codeBlock := n.(*ast.FencedCodeBlock) // 类型断言
if codeBlock.Info != nil { // 代码块头部的语言标识
lang = string(codeBlock.Info.Text(source))
}
var code []byte
lines := codeBlock.Lines() // 获取代码行集合
for i := 0; i < lines.Len(); i++ { // 遍历所有行
seg := lines.At(i) // 每行的文本片段
code = append(code, seg.Value(source)...)
code = append(code, '\n') // 补回换行符
}
if lang == "mermaid" { // 特殊处理 mermaid 图表
return &ContentBlock{Type: "mermaid", Content: string(code)}
}
return &ContentBlock{Type: "code", Content: string(code), Language: lang}
case ast.KindBlockquote: // 引用块
return &ContentBlock{Type: "quote", Content: string(n.Text(source))}
case ast.KindList: // 列表
return &ContentBlock{Type: "paragraph", Content: string(n.Text(source))}
default: // 其他类型(如图片、表格等)
tableContent := p.extractTable(n, source) // 尝试解析表格
if tableContent != "" {
return &ContentBlock{Type: "table", Content: tableContent}
}
nodeText := string(n.Text(source))
if strings.TrimSpace(nodeText) == "" { // 忽略空节点
return nil
}
return &ContentBlock{Type: "paragraph", Content: nodeText}
}
}
// extractTable 提取表格内容为 Markdown 格式
func (p *MarkdownParser) extractTable(n ast.Node, source []byte) string {
// 检查是否是表格节点
if n.Kind().String() != "Table" {
return ""
}
var rows []string
var headerRow string
isHeader := true
// 遍历表格的子节点(TableHeader 和 TableRow)
for row := n.FirstChild(); row != nil; row = row.NextSibling() {
var cells []string
// 遍历行的子节点(TableCell)
for cell := row.FirstChild(); cell != nil; cell = cell.NextSibling() {
cellText := strings.TrimSpace(string(cell.Text(source)))
cells = append(cells, cellText)
}
if len(cells) > 0 {
rowStr := "| " + strings.Join(cells, " | ") + " |"
if isHeader {
headerRow = rowStr
// 首行作为表头,其后添加分隔行如 | --- | --- |
separator := "| " + strings.Repeat("--- | ", len(cells))
rows = append(rows, headerRow)
rows = append(rows, separator)
isHeader = false
} else {
rows = append(rows, rowStr)
}
}
}
if len(rows) > 0 {
return strings.Join(rows, "\n")
}
return ""
}
// buildContent 将章节的 ContentBlock 列表合并为纯文本
// 代码块和 mermaid 会重新添加 ``` 标记,保持类型信息
// 表格保持 Markdown 格式
func (s *Section) buildContent() string {
var parts []string
for _, block := range s.Content {
switch block.Type {
case "code":
// 重新添加代码块标记
parts = append(parts, "```"+block.Language+"\n"+block.Content+"```")
case "mermaid":
// 重新添加 mermaid 标记
parts = append(parts, "```mermaid\n"+block.Content+"```")
case "table":
// 表格已经是 Markdown 格式,直接保留
parts = append(parts, block.Content)
default:
parts = append(parts, block.Content)
}
}
return strings.Join(parts, "\n\n")
}
// joinPath 将标题栈连接为章节路径
func joinPath(stack []string) string {
return strings.Join(stack, " > ")
}
parent_transformer.go (ParentBlock 构建器)
Go
package rag
import (
"context"
"fmt"
"strings"
"YoudaoNoteLm/internal/model/entity"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
)
// ParentTransformer 将 eino 文档转换为 ParentBlock
// 实现 eino document.Transformer 接口
type ParentTransformer struct {
maxTokens int // 默认 1000
}
// NewParentTransformer 创建 ParentBlock 构建器
func NewParentTransformer(maxTokens int) *ParentTransformer {
if maxTokens <= 0 {
maxTokens = 1000
}
return &ParentTransformer{maxTokens: maxTokens}
}
// Transform 实现 eino document.Transformer 接口
// 输入: eino 文档列表(每个文档代表一个章节)
// 输出: 转换后的 eino 文档列表(每个文档代表一个 ParentBlock)
func (t *ParentTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
var result []*schema.Document
blockIndex := 0 // 父块全局索引
for _, doc := range src {
// 从元数据提取字段,类型断言 + 默认值
heading, _ := doc.MetaData["heading"].(string)
level, _ := doc.MetaData["level"].(int)
chapterPath, _ := doc.MetaData["chapter_path"].(string)
// 按 maxTokens 切分正文
chunks := t.splitByTokens(doc.Content, t.maxTokens)
for _, chunk := range chunks {
newDoc := &schema.Document{
Content: chunk,
MetaData: map[string]any{
"heading": heading,
"level": level,
"chapter_path": chapterPath,
"parent_index": blockIndex, // 记录原章节内顺序
"block_type": "parent",
},
}
result = append(result, newDoc)
blockIndex++
}
}
return result, nil
}
// 按 token 上限切分文本,尽量保持段落完整
func (t *ParentTransformer) splitByTokens(content string, maxTokens int) []string {
paragraphs := strings.Split(content, "\n\n") // 按空行分成段落
var chunks []string
var current []string // 当前块包含的段落
currentTokens := 0
for _, p := range paragraphs {
tokens := estimateTokens(p) // 估算当前段落 token 数
// 如果加上当前段落会超限,且当前块非空 → 结束当前块
if currentTokens+tokens > maxTokens && len(current) > 0 {
chunks = append(chunks, strings.Join(current, "\n\n"))
current = []string{p} // 新块从当前段落开始
currentTokens = tokens
} else {
current = append(current, p)
currentTokens += tokens
}
}
// 最后一块
if len(current) > 0 {
chunks = append(chunks, strings.Join(current, "\n\n"))
}
return chunks
}
// 估算 token 数:中文每字1 token,英文每单词1 token
func estimateTokens(text string) int {
chars := 0 // 非ASCII字符数(中文等)
words := 0 // 英文单词数
inWord := false // 是否处于单词中
for _, r := range text {
if r > 127 { // 非ASCII(中文)
chars++
inWord = false
} else if r == ' ' || r == '\n' || r == '\t' { // 分隔符
if inWord {
words++
}
inWord = false
} else { // 英文/数字/符号
inWord = true
}
}
if inWord { // 末尾单词
words++
}
return chars + words
}
// ToParentBlocks 将 eino 文档列表转换为 ParentBlock 实体列表
// 供 IngestionService 写入 MySQL
func ToParentBlocks(docs []*schema.Document, sourceID uint) []entity.ParentBlock {
var blocks []entity.ParentBlock
for _, doc := range docs {
level, _ := doc.MetaData["level"].(int)
chapterPath, _ := doc.MetaData["chapter_path"].(string)
parentIndex, _ := doc.MetaData["parent_index"].(int)
heading, _ := doc.MetaData["heading"].(string)
blocks = append(blocks, entity.ParentBlock{
SourceID: sourceID,
Heading: heading,
Level: level,
ChapterPath: chapterPath,
Content: doc.Content,
ChunkIndex: parentIndex,
Metadata: fmt.Sprintf(`{"chapter_path":"%s","level":%d}`, chapterPath, level),
})
}
return blocks
}
child_transformer.go (ChildChunk 分割器)
Go
package rag
import (
"context"
"strings"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
)
// ChildChunk 子块分割(内部使用)
type ChildChunk struct {
ParentBlockIndex int
Content string
ChunkType string // paragraph/code/table/image/mermaid/quote
ChapterPath string
}
// ChildTransformer 将 ParentBlock 文档分割为 ChildChunk 文档
// 实现 eino document.Transformer 接口
type ChildTransformer struct {
maxTokens int // 默认 400
}
// NewChildTransformer 创建 ChildChunk 分割器
func NewChildTransformer(maxTokens int) *ChildTransformer {
if maxTokens <= 0 {
maxTokens = 400
}
return &ChildTransformer{maxTokens: maxTokens}
}
// Transform 实现 eino document.Transformer 接口
func (t *ChildTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
var result []*schema.Document
for _, parentDoc := range src {
parentIndex, _ := parentDoc.MetaData["parent_index"].(int)
chapterPath, _ := parentDoc.MetaData["chapter_path"].(string)
heading, _ := parentDoc.MetaData["heading"].(string)
chunks := t.splitContent(parentDoc.Content)
for i, chunk := range chunks {
childDoc := &schema.Document{
Content: chunk.Content,
MetaData: map[string]any{
"parent_index": parentIndex,
"chunk_type": chunk.ChunkType,
"chapter_path": chapterPath,
"heading": heading,
"child_index": i,
"block_type": "child",
},
}
result = append(result, childDoc)
}
}
return result, nil
}
// splitContent 分割内容,保持代码块等特殊块的完整性
func (t *ChildTransformer) splitContent(content string) []ChildChunk {
// 先识别并提取特殊块(代码块、表格、mermaid)
specialBlocks, remainingContent := extractSpecialBlocks(content)
var chunks []ChildChunk
// 特殊块作为独立 chunk
for _, block := range specialBlocks {
chunks = append(chunks, ChildChunk{
Content: block.Content,
ChunkType: block.BlockType,
})
}
// 剩余内容按段落分割
if strings.TrimSpace(remainingContent) != "" {
paragraphs := splitIntoParagraphs(remainingContent)
textChunks := t.mergeParagraphs(paragraphs, t.maxTokens)
for _, tc := range textChunks {
chunks = append(chunks, ChildChunk{
Content: tc,
ChunkType: "paragraph",
})
}
}
return chunks
}
// specialBlock 特殊块
type specialBlock struct {
Content string
BlockType string
}
// extractSpecialBlocks 从内容中提取特殊块(代码块、表格、mermaid)
func extractSpecialBlocks(content string) ([]specialBlock, string) {
var blocks []specialBlock
var remaining strings.Builder
lines := strings.Split(content, "\n")
inCodeBlock := false
codeBlockContent := strings.Builder{}
codeBlockLang := ""
for i := 0; i < len(lines); i++ {
line := lines[i]
trimmed := strings.TrimSpace(line)
// 检测代码块开始/结束
if strings.HasPrefix(trimmed, "```") {
if !inCodeBlock {
// 代码块开始
inCodeBlock = true
codeBlockContent.Reset()
codeBlockLang = strings.TrimPrefix(trimmed, "```")
codeBlockLang = strings.TrimSpace(codeBlockLang)
continue
} else {
// 代码块结束
inCodeBlock = false
blockType := "code"
if strings.EqualFold(codeBlockLang, "mermaid") {
blockType = "mermaid"
}
blocks = append(blocks, specialBlock{
Content: "```" + codeBlockLang + "\n" + codeBlockContent.String() + "```",
BlockType: blockType,
})
continue
}
}
if inCodeBlock {
codeBlockContent.WriteString(line)
codeBlockContent.WriteString("\n")
continue
}
// 检测表格(以 | 开头,包含 ---)
if strings.HasPrefix(trimmed, "|") && strings.Contains(trimmed, "---") {
// 收集整个表格
tableContent := strings.Builder{}
tableContent.WriteString(line)
tableContent.WriteString("\n")
for j := i + 1; j < len(lines); j++ {
nextLine := strings.TrimSpace(lines[j])
if !strings.HasPrefix(nextLine, "|") {
break
}
tableContent.WriteString(lines[j])
tableContent.WriteString("\n")
i = j
}
blocks = append(blocks, specialBlock{
Content: tableContent.String(),
BlockType: "table",
})
continue
}
// 普通行
remaining.WriteString(line)
remaining.WriteString("\n")
}
// 如果代码块没有闭合,作为代码块处理
if inCodeBlock {
blocks = append(blocks, specialBlock{
Content: "```" + codeBlockLang + "\n" + codeBlockContent.String(),
BlockType: "code",
})
}
return blocks, remaining.String()
}
// splitIntoParagraphs 按空行分割为段落
func splitIntoParagraphs(content string) []string {
// 按两个换行符分割(即空行)
paragraphs := strings.Split(content, "\n\n")
var result []string
for _, p := range paragraphs {
trimmed := strings.TrimSpace(p)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
// mergeParagraphs 将段落合并为不超过 maxTokens 的 chunk
func (t *ChildTransformer) mergeParagraphs(paragraphs []string, maxTokens int) []string {
if len(paragraphs) == 0 {
return nil
}
var chunks []string
var current []string
currentTokens := 0
for _, p := range paragraphs {
tokens := estimateTokens(p)
if currentTokens+tokens > maxTokens && len(current) > 0 {
chunks = append(chunks, strings.Join(current, "\n\n"))
current = []string{p}
currentTokens = tokens
} else {
current = append(current, p)
currentTokens += tokens
}
}
if len(current) > 0 {
chunks = append(chunks, strings.Join(current, "\n\n"))
}
return chunks
}
semantic_transformer.go (语义增强器)
Go
package rag
import (
"context"
"fmt"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
)
// SemanticTransformer 语义增强器
// 在 Embedding 前为 ChildChunk 注入结构化上下文
// 实现 eino document.Transformer 接口
type SemanticTransformer struct{}
// NewSemanticTransformer 创建语义增强器
func NewSemanticTransformer() *SemanticTransformer {
return &SemanticTransformer{}
}
// Transform 实现 eino document.Transformer 接口
// 为每个 ChildChunk 文档的内容注入章节路径和结构信息
func (t *SemanticTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
result := make([]*schema.Document, len(src))
for i, doc := range src {
chapterPath, _ := doc.MetaData["chapter_path"].(string)
chunkType, _ := doc.MetaData["chunk_type"].(string)
enhanced := t.enhance(doc.Content, chapterPath, chunkType)
newDoc := &schema.Document{
ID: doc.ID,
Content: enhanced,
MetaData: doc.MetaData,
}
result[i] = newDoc
}
return result, nil
}
// enhance 注入结构化上下文
func (t *SemanticTransformer) enhance(content, chapterPath, chunkType string) string {
if chapterPath == "" {
return content
}
switch chunkType {
case "code":
return fmt.Sprintf("标题路径:%s\n代码内容:\n%s", chapterPath, content)
case "table":
return fmt.Sprintf("标题路径:%s\n表格数据:\n%s", chapterPath, content)
case "mermaid":
return fmt.Sprintf("标题路径:%s\n流程图:\n%s", chapterPath, content)
default:
return fmt.Sprintf("标题路径:%s\n正文:%s", chapterPath, content)
}
}
检索流程
用户查询 "机器学习的优缺点"
│
▼
① Embedder.EmbedStrings() 查询向量化
│
├──并行──────────────────────┐
▼ ▼
② SemanticSearch() ③ KeywordSearch()
(Dense COSINE) (Sparse IP)
topK=20 topK=20
│ │
└───────────┬───────────────┘
▼
④ fuse() --- RRF 融合
1/(60+rank+1) 求和
│
▼
⑤ 取 topK 个候选
│
▼
⑥ parentRecovery()
填充 ParentBlock 完整内容
填充 Source 名称
│
▼
返回 \[\]*RetrieveResult
检索相关代码
retriever.go (检索服务(RAGRetriever))
Go
package rag
import (
"context"
"fmt"
"sort"
"sync"
"github.com/cloudwego/eino/components/embedding"
"go.uber.org/zap"
"YoudaoNoteLm/internal/model/entity"
"YoudaoNoteLm/internal/repository"
"YoudaoNoteLm/pkg/logger"
)
// RAGRetriever RAG 检索接口
type RAGRetriever interface {
Retrieve(ctx context.Context, req *RetrieveRequest) ([]*RetrieveResult, error)
}
// RetrieveRequest 检索请求
type RetrieveRequest struct {
Query string // 改写后的查询文本
UserID uint // 用户 ID(定位 Milvus collection)
SourceIDs []uint // 限定的资料来源范围
TopK int // 最终返回数量,默认 5
QueryVector []float32 // 预计算的查询向量(可选)
}
// RetrieveResult 检索结果
type RetrieveResult struct {
Content string // chunk 内容
SourceID uint // 资料来源 ID
SourceName string // 资料来源名称
ParentBlockID int64 // 父块 ID
ParentContent string // 父块完整内容
Heading string // 父块标题
ChapterPath string // 章节路径
Score float32 // 最终相关度分数
ChunkType string // chunk 类型
Metadata string // 元数据 JSON
}
const (
defaultTopK = 8
semanticCandidateK = 20
keywordCandidateK = 20
rrfK = 60
)
// RetrieverEmbedderProvider 根据 userID 获取 Embedder
type RetrieverEmbedderProvider func(ctx context.Context, userID uint) (embedding.Embedder, error)
type ragRetriever struct {
milvusSearcher MilvusSearcher
parentBlockRepo repository.ParentBlockRepository
sourceRepo repository.SourceRepository
embedderProvider RetrieverEmbedderProvider
topK int
}
// NewRAGRetriever 创建 RAGRetriever
func NewRAGRetriever(
milvusSearcher MilvusSearcher,
parentBlockRepo repository.ParentBlockRepository,
sourceRepo repository.SourceRepository,
embedderProvider RetrieverEmbedderProvider,
topK int,
) RAGRetriever {
if topK <= 0 {
topK = defaultTopK
}
return &ragRetriever{
milvusSearcher: milvusSearcher,
parentBlockRepo: parentBlockRepo,
sourceRepo: sourceRepo,
embedderProvider: embedderProvider,
topK: topK,
}
}
// Retrieve 执行 RAG 检索:语义 + 关键词双路召回 -> RRF 融合 -> Rerank -> Parent Recovery
func (r *ragRetriever) Retrieve(ctx context.Context, req *RetrieveRequest) ([]*RetrieveResult, error) {
topK := r.topK
if req.TopK > 0 {
topK = req.TopK
}
// 1. 获取查询向量
queryVector := req.QueryVector
if len(queryVector) == 0 {
embedder, err := r.embedderProvider(ctx, req.UserID)
if err != nil {
return nil, fmt.Errorf("获取 embedder 失败: %w", err)
}
vectors, err := embedder.EmbedStrings(ctx, []string{req.Query})
if err != nil {
return nil, fmt.Errorf("查询向量化失败: %w", err)
}
if len(vectors) > 0 {
queryVector = make([]float32, len(vectors[0]))
for i, v := range vectors[0] {
queryVector[i] = float32(v)
}
}
}
// 2. 并行语义检索 + 关键词检索
var (
semanticResults []MilvusSearchResult
keywordResults []MilvusSearchResult
semanticErr error
keywordErr error
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
semanticResults, semanticErr = r.milvusSearcher.SemanticSearch(ctx, req.UserID, queryVector, req.SourceIDs, semanticCandidateK)
}()
go func() {
defer wg.Done()
keywordResults, keywordErr = r.milvusSearcher.KeywordSearch(ctx, req.UserID, req.Query, req.SourceIDs, keywordCandidateK)
}()
wg.Wait()
if semanticErr != nil {
logger.Warn("语义检索失败,降级为仅关键词检索", zap.Error(semanticErr))
}
if keywordErr != nil {
logger.Warn("关键词检索失败,降级为仅语义检索", zap.Error(keywordErr))
}
// 3. RRF 融合
fused := r.fuse(semanticResults, keywordResults)
if len(fused) == 0 {
return nil, nil
}
// 4. 候选
candidateK := topK * 4
if candidateK > len(fused) {
candidateK = len(fused)
}
candidates := fused[:candidateK]
// 5. TopK
if len(candidates) > topK {
candidates = candidates[:topK]
}
// 6. Parent Recovery
results, err := r.parentRecovery(ctx, candidates)
if err != nil {
logger.Warn("Parent Recovery 失败,返回原始结果", zap.Error(err))
return candidates, nil
}
return results, nil
}
// fuse 使用 RRF (Reciprocal Rank Fusion) 融合语义检索和关键词检索结果
func (r *ragRetriever) fuse(semanticResults, keywordResults []MilvusSearchResult) []*RetrieveResult {
type resultKey struct {
sourceID int64
parentBlockID int64
}
scoreMap := make(map[resultKey]*RetrieveResult)
rankMap := make(map[resultKey][]float64)
for rank, item := range semanticResults {
key := resultKey{sourceID: item.SourceID, parentBlockID: item.ParentBlockID}
if _, exists := scoreMap[key]; !exists {
scoreMap[key] = &RetrieveResult{
Content: item.Content,
SourceID: uint(item.SourceID),
ParentBlockID: item.ParentBlockID,
Score: item.Score,
ChunkType: item.ChunkType,
Metadata: item.Metadata,
}
}
rankMap[key] = append(rankMap[key], 1.0/float64(rrfK+rank+1))
}
for rank, item := range keywordResults {
key := resultKey{sourceID: item.SourceID, parentBlockID: item.ParentBlockID}
if _, exists := scoreMap[key]; !exists {
scoreMap[key] = &RetrieveResult{
Content: item.Content,
SourceID: uint(item.SourceID),
ParentBlockID: item.ParentBlockID,
Score: item.Score,
ChunkType: item.ChunkType,
Metadata: item.Metadata,
}
}
rankMap[key] = append(rankMap[key], 1.0/float64(rrfK+rank+1))
}
for key, scores := range rankMap {
var total float64
for _, s := range scores {
total += s
}
scoreMap[key].Score = float32(total)
}
results := make([]*RetrieveResult, 0, len(scoreMap))
for _, res := range scoreMap {
results = append(results, res)
}
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
return results
}
// parentRecovery 为候选结果填充 ParentBlock 的完整内容、标题、章节路径以及资料来源名称
func (r *ragRetriever) parentRecovery(ctx context.Context, candidates []*RetrieveResult) ([]*RetrieveResult, error) {
seen := make(map[uint]bool)
var parentIDs []uint
for _, c := range candidates {
pid := uint(c.ParentBlockID)
if !seen[pid] {
seen[pid] = true
parentIDs = append(parentIDs, pid)
}
}
if len(parentIDs) == 0 {
return candidates, nil
}
parentBlocks, err := r.parentBlockRepo.FindByIDs(parentIDs)
if err != nil {
return nil, fmt.Errorf("查询 ParentBlock 失败: %w", err)
}
parentMap := make(map[uint]*entity.ParentBlock)
for _, pb := range parentBlocks {
parentMap[pb.ID] = pb
}
sourceSeen := make(map[uint]bool)
var sourceIDs []uint
for _, c := range candidates {
if !sourceSeen[c.SourceID] {
sourceSeen[c.SourceID] = true
sourceIDs = append(sourceIDs, c.SourceID)
}
}
sourceNames := make(map[uint]string)
for _, sid := range sourceIDs {
source, err := r.sourceRepo.FindByID(sid)
if err == nil && source != nil {
sourceNames[sid] = source.Name
}
}
for _, c := range candidates {
pid := uint(c.ParentBlockID)
if pb, ok := parentMap[pid]; ok {
c.ParentContent = pb.Content
c.Heading = pb.Heading
c.ChapterPath = pb.ChapterPath
} else {
logger.Warn("[DEBUG] parentRecovery miss",
zap.Int64("parent_block_id", c.ParentBlockID),
zap.Uint("source_id", c.SourceID),
)
}
if name, ok := sourceNames[c.SourceID]; ok {
c.SourceName = name
}
}
return candidates, nil
}
通用代码
embedder_factory.go (Embedder 工厂)
Go
package rag
import (
"context"
"fmt"
"YoudaoNoteLm/internal/model/entity"
einoArk "github.com/cloudwego/eino-ext/components/embedding/ark"
einoOpenai "github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/embedding"
)
// EmbeddingProvider 定义支持的 embedding 提供商类型
type EmbeddingProvider string
const (
ProviderArk EmbeddingProvider = "ark" // 火山引擎(豆包)
ProviderOpenAI EmbeddingProvider = "openai" // OpenAI
)
// EmbeddingConfig embedding 模型配置
type EmbeddingConfig struct {
Provider EmbeddingProvider `json:"provider"` // 提供商
APIKey string `json:"api_key"` // API Key
Model string `json:"model"` // 模型名称或接入点 ID
BaseURL string `json:"base_url,omitempty"` // 自定义 API 地址(可选)
Dimensions *int `json:"dimensions,omitempty"` // 向量维度(可选,部分模型支持)
// Ark 特有配置
ArkAPIType string `json:"ark_api_type,omitempty"` // Ark API 类型: "text_api" 或 "multi_modal_api"
}
// NewEmbedder 根据配置创建 eino Embedder
// 支持所有 eino-ext 集成的 embedding 模型
func NewEmbedder(ctx context.Context, cfg *EmbeddingConfig) (embedding.Embedder, error) {
if cfg == nil {
return nil, fmt.Errorf("embedding 配置不能为空")
}
switch cfg.Provider {
case ProviderArk:
return createArkEmbedder(ctx, cfg)
case ProviderOpenAI:
return createOpenAIEmbedder(ctx, cfg)
default:
return nil, fmt.Errorf("不支持的 embedding 提供商: %s", cfg.Provider)
}
}
// createArkEmbedder 创建火山引擎 Ark Embedder
func createArkEmbedder(ctx context.Context, cfg *EmbeddingConfig) (embedding.Embedder, error) {
if cfg.Model == "" {
return nil, fmt.Errorf("Ark embedding 模型名称或接入点 ID 未配置")
}
conf := &einoArk.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
}
if cfg.BaseURL != "" {
conf.BaseURL = cfg.BaseURL
}
if cfg.ArkAPIType != "" {
apiType := einoArk.APIType(cfg.ArkAPIType)
conf.APIType = &apiType
}
if cfg.Dimensions != nil {
conf.Dimensions = cfg.Dimensions
}
embedder, err := einoArk.NewEmbedder(ctx, conf)
if err != nil {
return nil, fmt.Errorf("创建 Ark Embedder 失败: %w", err)
}
return embedder, nil
}
// createOpenAIEmbedder 创建 OpenAI Embedder
func createOpenAIEmbedder(ctx context.Context, cfg *EmbeddingConfig) (embedding.Embedder, error) {
if cfg.Model == "" {
return nil, fmt.Errorf("OpenAI embedding 模型名称未配置")
}
conf := &einoOpenai.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
}
if cfg.BaseURL != "" {
conf.BaseURL = cfg.BaseURL
}
if cfg.Dimensions != nil {
conf.Dimensions = cfg.Dimensions
}
embedder, err := einoOpenai.NewEmbedder(ctx, conf)
if err != nil {
return nil, fmt.Errorf("创建 OpenAI Embedder 失败: %w", err)
}
return embedder, nil
}
// NewEmbedderFromConfig 从 entity.UserConfig 创建 eino Embedder
// 这是一个便捷函数,自动将 UserConfig 转换为 EmbeddingConfig
func NewEmbedderFromConfig(ctx context.Context, cfg *entity.UserConfig) (embedding.Embedder, error) {
if cfg == nil {
return nil, fmt.Errorf("embedding 配置不能为空")
}
embeddingCfg := &EmbeddingConfig{
Provider: EmbeddingProvider(cfg.Provider),
APIKey: cfg.APIKey,
Model: cfg.Model,
BaseURL: cfg.APIURL,
Dimensions: cfg.Dimensions,
}
return NewEmbedder(ctx, embeddingCfg)
}
milvus_indexer.go (Milvus 读写操作)
Go
package rag
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"YoudaoNoteLm/pkg/logger"
"github.com/cloudwego/eino/schema"
milvusclient "github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"go.uber.org/zap"
)
const (
VectorDim = 2048 // 默认维度,实际由 embedder 决定
)
// UserCollectionName 返回用户专属的 Milvus Collection 名称
func UserCollectionName(userID uint) string {
return fmt.Sprintf("user_%d_chunks", userID)
}
// MilvusIndexerConfig Milvus 连接配置
type MilvusIndexerConfig struct {
Address string
}
// MilvusWriter 封装 Milvus 写入操作
type MilvusWriter struct {
client milvusclient.Client
}
var newMilvusClient = milvusclient.NewClient
// NewMilvusWriter 创建 Milvus 写入器
func NewMilvusWriter(ctx context.Context, cfg MilvusIndexerConfig) (*MilvusWriter, error) {
start := time.Now()
logger.Info("Milvus connection started", zap.String("address", cfg.Address))
cli, err := newMilvusClient(ctx, milvusclient.Config{
Address: cfg.Address,
})
if err != nil {
logger.Error("Milvus connection failed",
zap.String("address", cfg.Address),
zap.Duration("elapsed", time.Since(start)),
zap.Error(err),
)
return nil, fmt.Errorf("创建 Milvus 客户端失败: %w", err)
}
logger.Info("Milvus connection succeeded",
zap.String("address", cfg.Address),
zap.Duration("elapsed", time.Since(start)),
)
return &MilvusWriter{client: cli}, nil
}
// EnsureCollection 确保用户专属 Collection 存在,不存在则创建
func (w *MilvusWriter) EnsureCollection(ctx context.Context, userID uint) error {
collName := UserCollectionName(userID)
has, err := w.client.HasCollection(ctx, collName)
if err != nil {
return fmt.Errorf("检查 Collection 失败: %w", err)
}
if has {
return nil
}
schema := &entity.Schema{
CollectionName: collName,
AutoID: true,
Fields: []*entity.Field{
{
Name: "id",
DataType: entity.FieldTypeInt64,
AutoID: true,
PrimaryKey: true,
},
{
Name: "content",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "32768",
},
},
{
Name: "vector",
DataType: entity.FieldTypeFloatVector,
TypeParams: map[string]string{
"dim": fmt.Sprintf("%d", VectorDim),
},
},
{
Name: "parent_block_id",
DataType: entity.FieldTypeInt64,
},
{
Name: "source_id",
DataType: entity.FieldTypeInt64,
},
{
Name: "chunk_type",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "32",
},
},
{
Name: "metadata",
DataType: entity.FieldTypeVarChar,
TypeParams: map[string]string{
"max_length": "2048",
},
},
{
Name: "sparse_vector",
DataType: entity.FieldTypeSparseVector,
},
},
}
if err := w.client.CreateCollection(ctx, schema, 2); err != nil {
return fmt.Errorf("创建 Collection 失败: %w", err)
}
// 创建 HNSW 索引
idxParam, err := entity.NewIndexHNSW(entity.COSINE, 16, 200)
if err != nil {
return fmt.Errorf("创建 HNSW 索引参数失败: %w", err)
}
if err := w.client.CreateIndex(ctx, collName, "vector", idxParam, false); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
// 创建 BM25 索引(全文检索)
// 注意: SDK v2.4.2 没有 NewIndexBM25Sparse,使用 NewIndexSparseInverted + IP 代替
bm25IdxParam, err := entity.NewIndexSparseInverted(entity.IP, 0.0)
if err != nil {
return fmt.Errorf("创建 BM25 索引参数失败: %w", err)
}
if err := w.client.CreateIndex(ctx, collName, "sparse_vector", bm25IdxParam, false); err != nil {
return fmt.Errorf("创建 BM25 索引失败: %w", err)
}
// 加载 Collection
if err := w.client.LoadCollection(ctx, collName, false); err != nil {
return fmt.Errorf("加载 Collection 失败: %w", err)
}
return nil
}
// Store 将文档和对应的向量写入用户专属 Collection
// Deprecated: 使用 StoreWithSparse 代替,新 collection 包含 sparse_vector 字段
func (w *MilvusWriter) Store(ctx context.Context, userID uint, docs []*schema.Document, vectors [][]float32) error {
if len(docs) == 0 || len(docs) != len(vectors) {
return fmt.Errorf("文档和向量数量不匹配: docs=%d, vectors=%d", len(docs), len(vectors))
}
n := len(docs)
contents := make([]string, n)
vectorsData := make([][]float32, n)
parentBlockIDs := make([]int64, n)
sourceIDs := make([]int64, n)
chunkTypes := make([]string, n)
metadatas := make([]string, n)
const maxContentLen = 32768
for i, doc := range docs {
c := doc.Content
if len(c) > maxContentLen {
c = c[:maxContentLen]
}
contents[i] = c
vectorsData[i] = vectors[i]
// 优先使用真实的 parent_block_id(MySQL 自增 ID),回退到 parent_index
if pid, ok := doc.MetaData["parent_block_id"].(uint); ok {
parentBlockIDs[i] = int64(pid)
} else if pid, ok := doc.MetaData["parent_index"].(int); ok {
parentBlockIDs[i] = int64(pid)
}
if sid, ok := doc.MetaData["source_id"].(uint); ok {
sourceIDs[i] = int64(sid)
}
if ct, ok := doc.MetaData["chunk_type"].(string); ok {
chunkTypes[i] = ct
}
metaJSON, err := json.Marshal(doc.MetaData)
if err != nil {
return fmt.Errorf("序列化文档元数据失败: %w", err)
}
metadatas[i] = string(metaJSON)
}
// 按 batch 写入
batchSize := 100
for start := 0; start < n; start += batchSize {
end := start + batchSize
if end > n {
end = n
}
columns := []entity.Column{
entity.NewColumnVarChar("content", contents[start:end]),
entity.NewColumnFloatVector("vector", VectorDim, vectorsData[start:end]),
entity.NewColumnInt64("parent_block_id", parentBlockIDs[start:end]),
entity.NewColumnInt64("source_id", sourceIDs[start:end]),
entity.NewColumnVarChar("chunk_type", chunkTypes[start:end]),
entity.NewColumnVarChar("metadata", metadatas[start:end]),
}
if _, err := w.client.Insert(ctx, UserCollectionName(userID), "", columns...); err != nil {
return fmt.Errorf("写入 Milvus 失败 (batch %d-%d): %w", start, end, err)
}
}
return nil
}
// mapToSparseEmbedding 将 map[int32]float32 转换为 Milvus SDK 的 SparseEmbedding
func mapToSparseEmbedding(m map[int32]float32) (entity.SparseEmbedding, error) {
positions := make([]uint32, 0, len(m))
values := make([]float32, 0, len(m))
for k, v := range m {
positions = append(positions, uint32(k))
values = append(values, v)
}
return entity.NewSliceSparseEmbedding(positions, values)
}
// mapsToSparseEmbeddings 批量转换
func mapsToSparseEmbeddings(maps []map[int32]float32) ([]entity.SparseEmbedding, error) {
result := make([]entity.SparseEmbedding, len(maps))
for i, m := range maps {
se, err := mapToSparseEmbedding(m)
if err != nil {
return nil, fmt.Errorf("转换 sparse embedding #%d 失败: %w", i, err)
}
result[i] = se
}
return result, nil
}
// StoreWithSparse 将文档、dense vector 和 sparse vector 写入用户专属 Collection
func (w *MilvusWriter) StoreWithSparse(ctx context.Context, userID uint, docs []*schema.Document, denseVectors [][]float32, sparseVectors []map[int32]float32) error {
if len(docs) == 0 {
return fmt.Errorf("没有可写入的文档(内容可能为空或解析结果为空)")
}
if len(docs) != len(denseVectors) || len(docs) != len(sparseVectors) {
return fmt.Errorf("文档、dense向量和sparse向量数量不匹配: docs=%d, dense=%d, sparse=%d", len(docs), len(denseVectors), len(sparseVectors))
}
n := len(docs)
contents := make([]string, n)
denseData := make([][]float32, n)
parentBlockIDs := make([]int64, n)
sourceIDs := make([]int64, n)
chunkTypes := make([]string, n)
metadatas := make([]string, n)
const maxContentLen = 32768
for i, doc := range docs {
c := doc.Content
if len(c) > maxContentLen {
c = c[:maxContentLen]
}
contents[i] = c
denseData[i] = denseVectors[i]
// 优先使用真实的 parent_block_id(MySQL 自增 ID),回退到 parent_index
if pid, ok := doc.MetaData["parent_block_id"].(uint); ok {
parentBlockIDs[i] = int64(pid)
} else if pid, ok := doc.MetaData["parent_index"].(int); ok {
parentBlockIDs[i] = int64(pid)
}
if sid, ok := doc.MetaData["source_id"].(uint); ok {
sourceIDs[i] = int64(sid)
}
if ct, ok := doc.MetaData["chunk_type"].(string); ok {
chunkTypes[i] = ct
}
metaJSON, err := json.Marshal(doc.MetaData)
if err != nil {
return fmt.Errorf("序列化文档元数据失败: %w", err)
}
metadatas[i] = string(metaJSON)
}
// 转换 sparse vectors 为 Milvus SDK 格式
allSparseEmbeddings, err := mapsToSparseEmbeddings(sparseVectors)
if err != nil {
return fmt.Errorf("转换 sparse vectors 失败: %w", err)
}
// 按 batch 写入
batchSize := 100
for start := 0; start < n; start += batchSize {
end := start + batchSize
if end > n {
end = n
}
sparseColumn := entity.NewColumnSparseVectors("sparse_vector", allSparseEmbeddings[start:end])
columns := []entity.Column{
entity.NewColumnVarChar("content", contents[start:end]),
entity.NewColumnFloatVector("vector", VectorDim, denseData[start:end]),
entity.NewColumnInt64("parent_block_id", parentBlockIDs[start:end]),
entity.NewColumnInt64("source_id", sourceIDs[start:end]),
entity.NewColumnVarChar("chunk_type", chunkTypes[start:end]),
entity.NewColumnVarChar("metadata", metadatas[start:end]),
sparseColumn,
}
if _, err := w.client.Insert(ctx, UserCollectionName(userID), "", columns...); err != nil {
return fmt.Errorf("写入 Milvus 失败 (batch %d-%d): %w", start, end, err)
}
}
return nil
}
// GenerateSparseVector 将文本转换为 BM25 sparse vector
func GenerateSparseVector(text string) map[int32]float32 {
words := segmentText(text)
freq := make(map[string]int)
for _, w := range words {
freq[w]++
}
sv := make(map[int32]float32)
for word, count := range freq {
idx := hashToIndex(word)
sv[idx] = float32(count) / float32(count+1)
}
return sv
}
// segmentText 简单分词:中文按单字+bigram,英文按空格分词
func segmentText(text string) []string {
var words []string
runes := []rune(text)
i := 0
for i < len(runes) {
r := runes[i]
if r >= 0x4e00 && r <= 0x9fff {
words = append(words, string(r))
if i+1 < len(runes) && runes[i+1] >= 0x4e00 && runes[i+1] <= 0x9fff {
words = append(words, string(runes[i:i+2]))
}
i++
} else if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
j := i
for j < len(runes) && ((runes[j] >= 'a' && runes[j] <= 'z') || (runes[j] >= 'A' && runes[j] <= 'Z') || (runes[j] >= '0' && runes[j] <= '9')) {
j++
}
words = append(words, strings.ToLower(string(runes[i:j])))
i = j
} else {
i++
}
}
return words
}
// hashToIndex 将词 hash 为 int32 索引 (FNV-1a 变体)
func hashToIndex(word string) int32 {
var h uint32 = 2166136261
for _, b := range []byte(word) {
h ^= uint32(b)
h *= 16777619
}
return int32(h % 100000)
}
// DeleteBySourceID 删除指定 source 在用户专属 Collection 中的所有 ChildChunk
func (w *MilvusWriter) DeleteBySourceID(ctx context.Context, userID uint, sourceID uint) error {
expr := fmt.Sprintf(`source_id == %d`, sourceID)
return w.client.Delete(ctx, UserCollectionName(userID), "", expr)
}
// MilvusSearchResult Milvus 检索结果
type MilvusSearchResult struct {
ID int64
Content string
ParentBlockID int64
SourceID int64
ChunkType string
Metadata string
Score float32
}
// MilvusSearcher Milvus 检索接口
type MilvusSearcher interface {
SemanticSearch(ctx context.Context, userID uint, queryVector []float32, sourceIDs []uint, topK int) ([]MilvusSearchResult, error)
KeywordSearch(ctx context.Context, userID uint, queryText string, sourceIDs []uint, topK int) ([]MilvusSearchResult, error)
}
// SemanticSearch 基于 dense vector 的语义检索
func (w *MilvusWriter) SemanticSearch(ctx context.Context, userID uint, queryVector []float32, sourceIDs []uint, topK int) ([]MilvusSearchResult, error) {
collName := UserCollectionName(userID)
filter := ""
if len(sourceIDs) > 0 {
ids := make([]string, len(sourceIDs))
for i, id := range sourceIDs {
ids[i] = fmt.Sprintf("%d", id)
}
filter = fmt.Sprintf("source_id in [%s]", strings.Join(ids, ","))
}
searchParams, err := entity.NewIndexHNSWSearchParam(200)
if err != nil {
return nil, fmt.Errorf("创建 HNSW 搜索参数失败: %w", err)
}
outputFields := []string{"content", "parent_block_id", "source_id", "chunk_type", "metadata"}
results, err := w.client.Search(ctx, collName, []string{}, filter, outputFields,
[]entity.Vector{entity.FloatVector(queryVector)},
"vector", entity.COSINE, topK, searchParams,
)
if err != nil {
return nil, fmt.Errorf("语义检索失败: %w", err)
}
return parseSearchResults(results), nil
}
// KeywordSearch 基于 sparse vector 的关键词检索
func (w *MilvusWriter) KeywordSearch(ctx context.Context, userID uint, queryText string, sourceIDs []uint, topK int) ([]MilvusSearchResult, error) {
collName := UserCollectionName(userID)
filter := ""
if len(sourceIDs) > 0 {
ids := make([]string, len(sourceIDs))
for i, id := range sourceIDs {
ids[i] = fmt.Sprintf("%d", id)
}
filter = fmt.Sprintf("source_id in [%s]", strings.Join(ids, ","))
}
searchParams, err := entity.NewIndexSparseInvertedSearchParam(0.0)
if err != nil {
return nil, fmt.Errorf("创建 sparse 搜索参数失败: %w", err)
}
outputFields := []string{"content", "parent_block_id", "source_id", "chunk_type", "metadata"}
// 将 queryText 转换为 sparse vector
querySparseMap := GenerateSparseVector(queryText)
querySparseVec, err := mapToSparseEmbedding(querySparseMap)
if err != nil {
return nil, fmt.Errorf("转换查询 sparse vector 失败: %w", err)
}
results, err := w.client.Search(ctx, collName, []string{}, filter, outputFields,
[]entity.Vector{querySparseVec},
"sparse_vector", entity.IP, topK, searchParams,
)
if err != nil {
return nil, fmt.Errorf("关键词检索失败: %w", err)
}
return parseSearchResults(results), nil
}
// parseSearchResults 将 Milvus SearchResult 转换为 MilvusSearchResult 切片
func parseSearchResults(results []milvusclient.SearchResult) []MilvusSearchResult {
if len(results) == 0 {
return nil
}
var parsed []MilvusSearchResult
for _, result := range results {
if result.Err != nil {
continue
}
for i := 0; i < result.ResultCount; i++ {
item := MilvusSearchResult{
Score: result.Scores[i],
}
// 从 ID 列获取主键
if result.IDs != nil {
if val, err := result.IDs.Get(i); err == nil {
if id, ok := val.(int64); ok {
item.ID = id
}
}
}
// 从 Fields 获取各字段
if col := result.Fields.GetColumn("content"); col != nil {
if val, err := col.Get(i); err == nil {
if s, ok := val.(string); ok {
item.Content = s
}
}
}
if col := result.Fields.GetColumn("parent_block_id"); col != nil {
if val, err := col.Get(i); err == nil {
if id, ok := val.(int64); ok {
item.ParentBlockID = id
}
}
}
if col := result.Fields.GetColumn("source_id"); col != nil {
if val, err := col.Get(i); err == nil {
if id, ok := val.(int64); ok {
item.SourceID = id
}
}
}
if col := result.Fields.GetColumn("chunk_type"); col != nil {
if val, err := col.Get(i); err == nil {
if s, ok := val.(string); ok {
item.ChunkType = s
}
}
}
if col := result.Fields.GetColumn("metadata"); col != nil {
if val, err := col.Get(i); err == nil {
if s, ok := val.(string); ok {
item.Metadata = s
}
}
}
parsed = append(parsed, item)
}
}
return parsed
}
// Close 关闭 Milvus 客户端
func (w *MilvusWriter) Close() {
w.client.Close()
}
// WrapDocuments 为 ChildChunk Document 添加 source_id 元数据
func WrapDocuments(docs []*schema.Document, sourceID uint) []*schema.Document {
for _, doc := range docs {
doc.MetaData["source_id"] = sourceID
}
return docs
}
service.go # 主入口
Go
package rag
import (
"context"
"fmt"
"strings"
"YoudaoNoteLm/internal/repository"
"YoudaoNoteLm/pkg/logger"
"github.com/cloudwego/eino/components/embedding"
"go.uber.org/zap"
)
// EmbedderProvider 根据 userID 获取对应的 Embedder
type EmbedderProvider func(ctx context.Context, userID uint) (embedding.Embedder, error)
// IngestionService 入库服务接口
type IngestionService interface {
// Ingest 批量入库源内容
Ingest(ctx context.Context, sourceIDs []uint) error
// IngestSingle 单个源入库
IngestSingle(ctx context.Context, sourceID uint) error
// DeleteSource 删除源的向量数据
DeleteSource(ctx context.Context, userID uint, sourceID uint) error
}
type ingestionService struct {
sourceRepo repository.SourceRepository
parentRepo repository.ParentBlockRepository
embedderProvider EmbedderProvider
milvusWriter *MilvusWriter
maxRetries int // 默认 3
}
// NewIngestionService 创建入库服务
func NewIngestionService(
sourceRepo repository.SourceRepository,
parentRepo repository.ParentBlockRepository,
embedderProvider EmbedderProvider,
milvusWriter *MilvusWriter,
) IngestionService {
return &ingestionService{
sourceRepo: sourceRepo,
parentRepo: parentRepo,
embedderProvider: embedderProvider,
milvusWriter: milvusWriter,
maxRetries: 3,
}
}
// IngestSingle 单个源入库
func (s *ingestionService) IngestSingle(ctx context.Context, sourceID uint) error {
// 1. 查询源
source, err := s.sourceRepo.FindByID(sourceID)
if err != nil {
return fmt.Errorf("查询源失败: %w", err)
}
if source.MarkdownContent == "" {
logger.Info("源内容为空,跳过入库", zap.Uint("source_id", sourceID))
return nil
}
logger.Info("开始入库流程",
zap.Uint("source_id", sourceID),
zap.Uint("user_id", source.UserID),
zap.Int("content_len", len(source.MarkdownContent)),
)
// 2. 更新状态为处理中
if err := s.sourceRepo.UpdateStatus(sourceID, "processing", ""); err != nil {
logger.Warn("更新源状态为处理中失败", zap.Uint("source_id", sourceID), zap.Error(err))
}
// 3. AST 解析
p := NewMarkdownParser()
docs, err := p.Parse(ctx, strings.NewReader(source.MarkdownContent))
if err != nil {
s.updateFailedStatus(sourceID, err.Error())
return err
}
// 4. 构建 ParentBlock
parentTransformer := NewParentTransformer(1000)
parentDocs, err := parentTransformer.Transform(ctx, docs)
if err != nil {
s.updateFailedStatus(sourceID, err.Error())
return err
}
// 5. 分割 ChildChunk
childTransformer := NewChildTransformer(400)
childDocs, err := childTransformer.Transform(ctx, parentDocs)
if err != nil {
s.updateFailedStatus(sourceID, err.Error())
return err
}
// 6. 语义增强
enhancer := NewSemanticTransformer()
enhancedDocs, err := enhancer.Transform(ctx, childDocs)
if err != nil {
s.updateFailedStatus(sourceID, err.Error())
return err
}
// 7. 检查是否有可入库的文档
if len(enhancedDocs) == 0 {
s.updateFailedStatus(sourceID, "解析后无有效内容,跳过入库")
return fmt.Errorf("源 %d 解析后无有效内容", sourceID)
}
enhancedDocs = WrapDocuments(enhancedDocs, sourceID)
// 8. 先写入 MySQL ParentBlock,拿到真实的自增 ID
blocks := ToParentBlocks(parentDocs, sourceID)
if err := s.retry(func() error {
return s.parentRepo.BatchCreate(blocks)
}); err != nil {
s.updateFailedStatus(sourceID, "写入 MySQL 失败: "+err.Error())
return err
}
logger.Info("MySQL ParentBlock 写入成功",
zap.Uint("source_id", sourceID),
zap.Int("block_count", len(blocks)),
)
// 8.5 构建 parent_index → MySQL ID 的映射,更新子块 metadata
parentIndexToID := make(map[int]uint, len(blocks))
for _, b := range blocks {
parentIndexToID[b.ChunkIndex] = b.ID
}
for _, doc := range enhancedDocs {
if pidx, ok := doc.MetaData["parent_index"].(int); ok {
if realID, exists := parentIndexToID[pidx]; exists {
doc.MetaData["parent_block_id"] = realID
}
}
}
logger.Info("准备向量化",
zap.Uint("source_id", sourceID),
zap.Int("chunk_count", len(enhancedDocs)),
)
embedder, err := s.embedderProvider(ctx, source.UserID)
if err != nil {
errMsg := "获取 Embedder 失败: " + err.Error()
logger.Error(errMsg, zap.Uint("source_id", sourceID), zap.Uint("user_id", source.UserID))
s.updateFailedStatus(sourceID, errMsg)
return fmt.Errorf("%s", errMsg)
}
// 提取所有文本用于批量 embedding
texts := make([]string, len(enhancedDocs))
for i, doc := range enhancedDocs {
texts[i] = doc.Content
}
// 分批调用 Embedding API(豆包限制每次最多 256 条)
const embedBatchSize = 256
vectors := make([][]float32, len(texts))
logger.Info("调用 Embedding API", zap.Uint("source_id", sourceID), zap.Int("text_count", len(texts)))
for start := 0; start < len(texts); start += embedBatchSize {
end := start + embedBatchSize
if end > len(texts) {
end = len(texts)
}
batchTexts := texts[start:end]
var batchVectors [][]float64
if err := s.retry(func() error {
var err error
batchVectors, err = embedder.EmbedStrings(ctx, batchTexts)
if err != nil {
logger.Warn("Embedding 批次调用失败,重试中",
zap.Uint("source_id", sourceID),
zap.Int("batch_start", start),
zap.Int("batch_size", len(batchTexts)),
zap.Error(err),
)
return err
}
return nil
}); err != nil {
errMsg := "Embedding 失败: " + err.Error()
logger.Error(errMsg, zap.Uint("source_id", sourceID))
s.updateFailedStatus(sourceID, errMsg)
return fmt.Errorf("%s", errMsg)
}
// float64 → float32
for i, v := range batchVectors {
vectors[start+i] = make([]float32, len(v))
for j, f := range v {
vectors[start+i][j] = float32(f)
}
}
logger.Info("Embedding 批次完成",
zap.Uint("source_id", sourceID),
zap.Int("batch_start", start),
zap.Int("batch_end", end),
)
}
logger.Info("Embedding 全部完成", zap.Uint("source_id", sourceID), zap.Int("vector_count", len(vectors)))
// 9. 生成 sparse vector
sparseVectors := make([]map[int32]float32, len(enhancedDocs))
for i, doc := range enhancedDocs {
sparseVectors[i] = GenerateSparseVector(doc.Content)
}
logger.Info("Sparse vector 生成完成", zap.Uint("source_id", sourceID), zap.Int("count", len(sparseVectors)))
// 10. 确保用户的 Milvus Collection 存在
if err := s.milvusWriter.EnsureCollection(ctx, source.UserID); err != nil {
errMsg := "确保 Milvus Collection 失败: " + err.Error()
logger.Error(errMsg, zap.Uint("source_id", sourceID), zap.Uint("user_id", source.UserID))
s.updateFailedStatus(sourceID, errMsg)
return fmt.Errorf("%s", errMsg)
}
// 11. 写入 Milvus
logger.Info("写入 Milvus", zap.Uint("source_id", sourceID), zap.Uint("user_id", source.UserID), zap.Int("doc_count", len(enhancedDocs)))
if err := s.retry(func() error {
return s.milvusWriter.StoreWithSparse(ctx, source.UserID, enhancedDocs, vectors, sparseVectors)
}); err != nil {
errMsg := "写入 Milvus 失败: " + err.Error()
logger.Error(errMsg, zap.Uint("source_id", sourceID))
s.updateFailedStatus(sourceID, errMsg)
return fmt.Errorf("%s", errMsg)
}
logger.Info("Milvus 写入成功", zap.Uint("source_id", sourceID))
// 11. 更新状态为就绪
if err := s.sourceRepo.UpdateStatus(sourceID, "ready", ""); err != nil {
logger.Warn("更新源状态为就绪失败", zap.Uint("source_id", sourceID), zap.Error(err))
}
if err := s.sourceRepo.SetVectorized(sourceID); err != nil {
logger.Warn("标记源已向量化失败", zap.Uint("source_id", sourceID), zap.Error(err))
}
return nil
}
// updateFailedStatus 更新源状态为失败
func (s *ingestionService) updateFailedStatus(sourceID uint, errMsg string) {
if err := s.sourceRepo.UpdateStatus(sourceID, "failed", errMsg); err != nil {
logger.Warn("更新源状态为失败失败", zap.Uint("source_id", sourceID), zap.Error(err))
}
}
// Ingest 批量入库
func (s *ingestionService) Ingest(ctx context.Context, sourceIDs []uint) error {
var lastErr error
for _, sourceID := range sourceIDs {
if err := s.IngestSingle(ctx, sourceID); err != nil {
lastErr = err
logger.Warn("入库失败", zap.Uint("source_id", sourceID), zap.Error(err))
continue
}
}
return lastErr
}
// DeleteSource 删除源的向量数据和父块数据
func (s *ingestionService) DeleteSource(ctx context.Context, userID uint, sourceID uint) error {
logger.Info("删除源数据",
zap.Uint("user_id", userID),
zap.Uint("source_id", sourceID),
)
if err := s.milvusWriter.DeleteBySourceID(ctx, userID, sourceID); err != nil {
logger.Error("删除 Milvus 数据失败",
zap.Uint("user_id", userID),
zap.Uint("source_id", sourceID),
zap.Error(err),
)
return fmt.Errorf("删除向量数据失败: %w", err)
}
// 删除 MySQL 中的 parent_blocks(source 软删除不会触发 CASCADE)
if err := s.parentRepo.DeleteBySourceID(sourceID); err != nil {
logger.Error("删除 parent_blocks 失败",
zap.Uint("source_id", sourceID),
zap.Error(err),
)
return fmt.Errorf("删除 parent_blocks 失败: %w", err)
}
logger.Info("删除源数据成功",
zap.Uint("user_id", userID),
zap.Uint("source_id", sourceID),
)
return nil
}
// retry 重试逻辑
func (s *ingestionService) retry(fn func() error) error {
var err error
for i := 0; i <= s.maxRetries; i++ {
err = fn()
if err == nil {
return nil
}
}
return err
}