简易版的EINO基于redis库的向量搜索项目v2

介绍

项目结构

```

eino_redis_rag/

├── main.go # 程序入口,HTTP 服务启动

├── config.yaml # 系统配置文件

├── config/

│ └── config.go # 配置结构体定义与加载

├── rag/

│ ├── rag.go # RAG 引擎核心

│ ├── splitter.go # 文档切分器

│ └── xlsx_parser.go # XLSX 解析器

├── service/

│ ├── indexer_service.go # 索引服务

│ ├── retriever_service.go # 检索服务

│ └── watcher_service.go # 文件监听服务

├── test_txt/ # 测试文档目录

├── go.mod # Go 模块定义

├── go.sum # Go 模块校验和

└──

```


详细执行逻辑

一、系统启动流程

```

main()

├─ 1. initLogger() # 初始化 logrus JSON 格式日志

├─ 2. config.LoadConfig() # 从 config.yaml 加载配置

│ ├─ 解析 server/redis/embedding/chat_model/index/retriever 配置

│ ├─ 解析 watcher 文件监听配置

│ └─ 解析 prompt 提示词配置

├─ 3. rag.NewRAGEngine() # 创建 RAG 引擎实例

│ ├─ 创建 Redis 客户端

│ ├─ 创建 Embedding 模型(BGE-M3 via Ollama)

│ └─ 创建 Chat 模型(Qwen3-Coder-Next)

├─ 4. ragEngine.InitLoader() # 初始化文件加载器

│ ├─ 注册 .pdf → pdfParser

│ ├─ 注册 .xlsx/.xls → XlsxParserForWholeFile

│ └─ 注册 .md/.txt → TextParser

├─ 5. ragEngine.InitSplitter() # 初始化文档分割器

│ ├─ header 模式: Markdown HeaderSplitter

│ ├─ recursive 模式: 递归字符分割器

│ └─ fixed 模式: 固定大小分割器

├─ 6. ragEngine.InitIndexer() # 初始化索引器

├─ 7. ragEngine.InitRetriever() # 初始化检索器

├─ 8. ragEngine.InitVectorIndex() # 创建 Redis 向量索引

│ ├─ 检查索引是否存在

│ ├─ FT.CREATE 创建 HNSW/FLAT 索引

│ └─ Schema: content(TEXT) + filename(TAG) + filetype(TAG) + source_id(TAG) + vector_content(VECTOR)

├─ 9. 创建服务实例

│ ├─ IndexerService

│ ├─ RetrieverService

│ └─ WatcherService

├─ 10. 启动 HTTP 服务器 (端口 8080)

├─ 11. 可选: 启动定时索引任务

└─ 12. 可选: 启动文件监听服务

```

二、文档索引流程 (IndexFolder)

```

POST /api/v1/index {"folder_path": "./test_txt"}

├─ 1. 递归遍历文件夹,筛选支持的文件类型

│ (.md, .txt, .pdf, .xlsx)

├─ 2. 对每个文件:

│ │

│ ├─ 2.1 计算文件 MD5

│ │ └─ 与 Redis 中存储的 MD5 比较

│ │ ├─ 相同 → 跳过(内容未变化)

│ │ └─ 不同 → 继续处理

│ │

│ ├─ 2.2 如果之前已索引过 → 删除旧文档块

│ │

│ ├─ 2.3 Loader.Load(file) # 根据扩展名选择解析器

│ │ ├─ .pdf → pdfParser (pdfcpu 提取 Unicode 文本)

│ │ ├─ .xlsx → XlsxParserForWholeFile (整文件解析)

│ │ └─ .md/.txt → TextParser

│ │

│ ├─ 2.4 Splitter.Transform(docs) # 文档切分

│ │ ├─ header 模式: 按 # 标题切分

│ │ ├─ recursive 模式: 按 \n\n → \n → 空格 → 字符递归切分

│ │ └─ fixed 模式: 按固定字符数切分

│ │

│ ├─ 2.5 为每个 chunk 生成唯一 ID

│ │ 格式: {key_prefix}{文件路径}:{块索引}

│ │ 示例: OuterCyrex:test_txt/mysql-1.md:0

│ │

│ ├─ 2.6 添加元数据

│ │ ├─ filename: 文件名

│ │ ├─ filetype: 文件扩展名

│ │ ├─ source_id: 原始文件路径

│ │ ├─ file_size: 文件大小

│ │ └─ indexed_at: 索引时间

│ │

│ ├─ 2.7 Indexer.Store(chunks) # 向量化 + 存储

│ │ ├─ 调用 BGE-M3 将内容转为 1024 维向量

│ │ └─ 存储为 Redis Hash

│ │ Key: OuterCyrex:test_txt/mysql-1.md:0

│ │ Fields: content, vector_content, filename, filetype, source_id, ...

│ │

│ └─ 2.8 更新 MD5 到 Redis

│ Key: OuterCyrex_metadata:test_txt/mysql-1.md

└─ 3. 返回 {"code": 200, "message": "Indexing successful"}

```

三、检索流程 (Retrieve)

```

POST /api/v1/retrieve {"query": "MySQL 是什么", "top_k": 5, "hybrid": true}

├─ 1. 将查询文本通过 Embedding 模型转为向量

├─ 2. 如果 hybrid=true (混合检索):

│ │

│ ├─ 2.1 向量检索

│ │ FT.SEARCH OuterIndex "*" VECTOR vector_content <params> RETURN <top_k>

│ │

│ ├─ 2.2 全文检索

│ │ FT.SEARCH OuterIndex "<query>" RETURN <top_k>

│ │

│ └─ 2.3 RRF 融合

│ score = 0.6 * 1/(60 + vector_rank) + 0.4 * 1/(60 + text_rank)

│ 按 score 排序,去重,返回 Top-K

├─ 3. 如果 hybrid=false (纯向量检索):

│ └─ FT.SEARCH 向量相似度搜索

└─ 4. 返回检索结果

{"code": 200, "docs": [{id, content, metaData}, ...]}

```

四、生成回答流程 (Generate)

```

POST /api/v1/generate {"query": "MySQL 是什么", "top_k": 5}

├─ 1. 调用 Retrieve 检索相关文档

├─ 2. 将检索到的文档内容拼接

├─ 3. 构建提示词

│ System: 角色定义 + 规则 + 检索文档

│ User: 用户问题

├─ 4. 调用 ChatModel.Stream() 流式生成

└─ 5. 读取流式输出,拼接为完整答案返回

{"code": 200, "content": "MySQL 是一种关系型数据库..."}

```

五、SSE 流式输出流程 (Generate Stream)

```

POST /api/v1/generate/stream {"query": "MySQL 是什么"}

├─ 1. 设置响应头

│ Content-Type: text/event-stream

│ Cache-Control: no-cache

│ Connection: keep-alive

│ X-Accel-Buffering: no

├─ 2. 发送开始事件

│ data: {"event":"start","data":"","timestamp":1234567890}

├─ 3. 流式发送内容片段

│ data: {"event":"chunk","data":"MySQL ","timestamp":1234567891}

│ data: {"event":"chunk","data":"是一种 ","timestamp":1234567892}

│ ...

└─ 4. 发送结束事件

data: {"event":"done","data":"","timestamp":1234567900}

```

六、文件监听流程 (Watcher)

```

启动时自动启动(配置 watcher.enable: true)

├─ 1. 使用 fsnotify 监听指定目录

├─ 2. 监听事件:

│ ├─ CREATE → 延迟 3 秒后索引新文件

│ ├─ WRITE → 延迟 3 秒后重新索引文件

│ ├─ RENAME → 延迟 3 秒后检查并处理

│ └─ REMOVE → 删除对应文档

└─ 3. 防抖机制: 同一文件多次变化只触发一次索引

```

七、Redis 数据结构

```

=== 文档数据 ===

Key: OuterCyrex:test_txt/mysql-1.md:0

Type: Hash

Fields:

content → "文档文本内容..."

vector_content → [1024维FLOAT32向量]

filename → "mysql-1.md"

filetype → ".md"

source_id → "test_txt/mysql-1.md"

file_size → "1024"

indexed_at → "2026-05-12T15:00:00+08:00"

=== MD5 元数据 ===

Key: OuterCyrex_metadata:test_txt/mysql-1.md

Type: String

Value: d41d8cd98f00b204e9800998ecf8427e

=== 向量索引 ===

Index Name: OuterIndex

Type: HASH

Prefix: OuterCyrex:

Schema:

  • content: TEXT

  • filename: TAG

  • filetype: TAG

  • source_id: TAG

  • vector_content: VECTOR (HNSW, FLOAT32, 1024 DIM, COSINE)

```


API 接口清单

| 方法 | 路径 | 说明 |

|------|------|------|

| GET | `/health` | 健康检查 |

| POST | `/api/v1/index` | 索引文件夹 |

| POST | `/api/v1/retrieve` | 检索文档(支持混合检索) |

| POST | `/api/v1/generate` | 生成回答(JSON) |

| POST | `/api/v1/generate/stream` | 生成回答(SSE流式) |

| POST | `/api/v1/retrieve-and-generate` | 检索并生成 |

| POST | `/api/v1/document/delete` | 删除文档 |


编译验证

```bash

cd eino_redis_rag

go build -o eino_redis_rag.exe .

```

✅ 编译成功,无错误无警告。


后续优化建议

  1. **P3 - 低优先级**
  • 添加 Prometheus 指标监控

  • 实现检索结果重排序(Cross-Encoder)

  • 添加分布式追踪支持

  • 实现向量索引增量更新

  1. **测试**
  • 编写单元测试

  • 编写集成测试

  • 压力测试

mian.go

Go 复制代码
见版本1

config.yaml

Go 复制代码
# RAG 系统配置文件

# Redis 数据库配置
redis:
  host: "100.102.39.213"
  port: 6379
  db: 0
  password: "redis2026"
  protocol: 2
  unstable_resp3: true

# 向量嵌入模型配置
embedding:
  api_key: "dsdsdsds"
  base_url: "http://100.102.39.213:11434/v1"
  model: "ollama.rnd.huawei.com/library/bge-m3:latest"
  timeout_seconds: 30
  dimension: 1024

# 聊天模型配置
chat_model:
  api_key: "dsdsdsds"
  base_url: "http://100.102.39.213:8000/v1"
  model: "Qwen/Qwen3-Coder-Next"
  timeout_seconds: 30
  # 生成参数配置
  temperature: 0.7
  max_tokens: 2048

# Redis 索引配置
index:
  name: "OuterIndex"
  key_prefix: "OuterCyrex:"
  vector_field: "vector_content"
  # 向量索引算法:FLAT 或 HNSW
  algorithm: "HNSW"
  # HNSW 参数(仅当 algorithm=HNSW 时生效)
  hnsw_m: 16
  hnsw_ef_construction: 200
  hnsw_ef_runtime: 10

# 文档分割配置
splitter:
  # 分割策略:header(按标题), recursive(递归分割), fixed(固定大小)
  strategy: "recursive"
  # 每块最大字符数
  chunk_size: 1000
  # 重叠字符数
  chunk_overlap: 200
  # Markdown 标题分割配置(仅当 strategy=header 时生效)
  markdown_headers:
    - "#"
    - "##"
    - "###"
  # 是否保留标题行
  trim_headers: false

# 文件加载配置
file_loader:
  # 支持的文件扩展名
  supported_extensions:
    - ".md"
    - ".txt"
    - ".pdf"
    - ".xlsx"
  # 是否使用文件名作为文档 ID
  use_name_as_id: true

# 检索配置
retriever:
  # 返回最相似的文档数量
  top_k: 5
  # 距离阈值,nil 表示不限制
  distance_threshold: null
  # 使用的 Dialect 版本
  dialect: 2
  # 要返回的字段
  return_fields:
    - "vector_content"
    - "content"
  # 是否启用混合检索(向量 + 全文)
  enable_hybrid: true
  # 向量检索权重(0.0-1.0,1.0表示纯向量,0.5表示均衡)
  vector_weight: 0.7
  # 混合检索时向量TopK
  hybrid_vector_topk: 20
  # 混合检索时全文TopK
  hybrid_text_topk: 20
  # 融合策略:rrf(倒数排名融合), weighted(加权融合)
  fusion_strategy: "rrf"
  # RRF 偏移量
  rrf_offset: 60

# Prompt 配置
prompt:
  # 系统提示词模板,支持 {documents} 和 {query} 占位符
  system_template: |
    # Role: Student Learning Assistant

    # Language: Chinese

    # Critical Rules:
    1. You MUST ONLY use the information provided in the documents below to answer questions.
    2. If the documents do not contain information relevant to the question, you MUST state "根据现有资料,我无法找到相关信息来回答这个问题。"
    3. Do NOT make up information or use your own knowledge to answer.
    4. Do NOT mention people, facts, or data that are not in the provided documents.

    - When providing assistance:
      • Be clear and concise
      • Only reference information from the provided documents
      • If multiple documents are provided, clearly indicate which source each piece of information comes from

    here's documents searched for you:
    ==== doc start ====
      {documents}
    ==== doc end ====

# 服务配置
server:
  # 服务端口
  port: 8080
  # 是否启用定时任务
  enable_cron: true
  # 定时任务执行时间(cron 表达式)
  cron_schedule: "0 2 * * *"  # 每天凌晨 2 点执行
  # 读取超时时间(秒)
  read_timeout_seconds: 30
  # 写入超时时间(秒)
  write_timeout_seconds: 30

# 文件监听配置
watcher:
  # 是否启用文件监听
  enable: true
  # 监听的目录路径
  watch_directory: "./test_txt"
  # 防抖时间(秒),避免短时间内多次触发
  debounce_seconds: 2

# 日志配置
log:
  # 日志级别:debug, info, warn, error
  level: "info"
  # 是否使用 JSON 格式输出
  json_format: true
  # 日志文件路径
  file_path: "app.log"

eino_redis_rag

config\config.go
Go 复制代码
package config

import (
	"fmt"
	"os"
	"sync"

	"gopkg.in/yaml.v3"
)

// Config 是 RAG 系统的配置结构体
// 包含所有运行时需要的配置参数
type Config struct {
	// Redis 数据库配置
	Redis struct {
		Host          string `yaml:"host"`           // Redis 服务器地址
		Port          int    `yaml:"port"`           // Redis 服务器端口
		DB            int    `yaml:"db"`             // Redis 数据库编号
		Password      string `yaml:"password"`       // Redis 密码
		Protocol      int    `yaml:"protocol"`       // Redis 协议版本 (2 或 3)
		UnstableResp3 bool   `yaml:"unstable_resp3"` // 是否使用不稳定的 RESP3 协议
	} `yaml:"redis"` // Redis 配置块

	// 向量嵌入模型配置
	Embedding struct {
		APIKey         string `yaml:"api_key"`         // API 密钥
		BaseURL        string `yaml:"base_url"`        // API 基础 URL
		Model          string `yaml:"model"`           // 模型名称
		TimeoutSeconds int    `yaml:"timeout_seconds"` // 请求超时时间(秒)
		Dimension      int    `yaml:"dimension"`       // 向量维度
	} `yaml:"embedding"` // 嵌入模型配置块

	// 聊天模型配置
	ChatModel struct {
		APIKey         string  `yaml:"api_key"`         // API 密钥
		BaseURL        string  `yaml:"base_url"`        // API 基础 URL
		Model          string  `yaml:"model"`           // 模型名称
		TimeoutSeconds int     `yaml:"timeout_seconds"` // 请求超时时间(秒)
		Temperature    float64 `yaml:"temperature"`     // 生成温度
		MaxTokens      int     `yaml:"max_tokens"`      // 最大生成 token 数
	} `yaml:"chat_model"` // 聊天模型配置块

	// Redis 索引配置
	Index struct {
		Name               string `yaml:"name"`                // Redis 向量索引名称
		KeyPrefix          string `yaml:"key_prefix"`          // Redis 文档键前缀
		VectorField        string `yaml:"vector_field"`        // 向量字段名称
		Algorithm          string `yaml:"algorithm"`           // 向量索引算法:FLAT 或 HNSW
		HNSWMM             int    `yaml:"hnsw_m"`              // HNSW M 参数
		HNSWEfConstruction int    `yaml:"hnsw_ef_construction"` // HNSW ef_construction
		HNSWEfRuntime      int    `yaml:"hnsw_ef_runtime"`     // HNSW ef_runtime
	} `yaml:"index"` // 索引配置块

	// 文档分割配置
	Splitter struct {
		Strategy        string   `yaml:"strategy"`         // 分割策略:header, recursive, fixed
		ChunkSize       int      `yaml:"chunk_size"`       // 每块最大字符数
		ChunkOverlap    int      `yaml:"chunk_overlap"`    // 重叠字符数
		MarkdownHeaders []string `yaml:"markdown_headers"` // Markdown 标题分割配置
		TrimHeaders     bool     `yaml:"trim_headers"`     // 是否移除标题行
	} `yaml:"splitter"` // 分割器配置块

	// 文件加载配置
	FileLoader struct {
		SupportedExtensions []string `yaml:"supported_extensions"` // 支持的文件扩展名
		UseNameAsID         bool     `yaml:"use_name_as_id"`       // 是否使用文件名作为文档 ID
	} `yaml:"file_loader"` // 文件加载器配置块

	// 检索配置
	Retriever struct {
		TopK              int      `yaml:"top_k"`               // 返回最相似的文档数量
		DistanceThreshold *float64 `yaml:"distance_threshold"`  // 距离阈值(可选)
		Dialect           int      `yaml:"dialect"`             // Redis Dialect 版本
		ReturnFields      []string `yaml:"return_fields"`       // 要返回的字段
		EnableHybrid      bool     `yaml:"enable_hybrid"`       // 是否启用混合检索
		VectorWeight      float64  `yaml:"vector_weight"`       // 向量检索权重
		HybridVectorTopK  int      `yaml:"hybrid_vector_topk"`  // 混合检索向量TopK
		HybridTextTopK    int      `yaml:"hybrid_text_topk"`    // 混合检索全文TopK
		FusionStrategy    string   `yaml:"fusion_strategy"`     // 融合策略:rrf, weighted
		RRFOffset         int      `yaml:"rrf_offset"`          // RRF 偏移量
	} `yaml:"retriever"` // 检索器配置块

	// Prompt 配置
	Prompt struct {
		SystemTemplate string `yaml:"system_template"` // 系统提示词模板
	} `yaml:"prompt"` // Prompt 配置块

	// 服务配置
	Server struct {
		Port             int    `yaml:"port"`                // 服务端口
		EnableCron       bool   `yaml:"enable_cron"`         // 是否启用定时任务
		CronSchedule     string `yaml:"cron_schedule"`       // 定时任务执行时间(cron 表达式)
		ReadTimeout      int    `yaml:"read_timeout_seconds"`  // 读取超时时间(秒)
		WriteTimeout     int    `yaml:"write_timeout_seconds"` // 写入超时时间(秒)
	} `yaml:"server"` // 服务配置块

	// 文件监听配置
	Watcher struct {
		Enable          bool   `yaml:"enable"`           // 是否启用文件监听
		WatchDirectory  string `yaml:"watch_directory"`  // 监听的目录路径
		DebounceSeconds int    `yaml:"debounce_seconds"` // 防抖时间(秒)
	} `yaml:"watcher"` // 文件监听配置块

	// 日志配置
	Log struct {
		Level      string `yaml:"level"`       // 日志级别
		JSONFormat bool   `yaml:"json_format"` // 是否使用 JSON 格式
		FilePath   string `yaml:"file_path"`   // 日志文件路径
	} `yaml:"log"` // 日志配置块
}

// 全局配置实例
var (
	config *Config    // 配置实例
	once   sync.Once  // 用于确保只初始化一次
	configMu sync.RWMutex // 配置读写锁,支持热更新
)

// LoadConfig 加载配置文件
// 从 config.yaml 文件中读取配置并解析为 Config 结构体
// 使用单例模式确保只加载一次配置
//
// 返回:
//   - *Config: 配置实例
//   - error: 错误信息
func LoadConfig() (*Config, error) {
	var loadErr error

	once.Do(func() {
		// 读取配置文件
		data, err := os.ReadFile("config.yaml")
		if err != nil {
			loadErr = fmt.Errorf("读取配置文件失败: %w", err)
			return
		}

		// 解析 YAML
		cfg := new(Config)
		err = yaml.Unmarshal(data, cfg)
		if err != nil {
			loadErr = fmt.Errorf("解析配置文件失败: %w", err)
			return
		}

		// 应用默认值
		applyDefaults(cfg)

		configMu.Lock()
		config = cfg
		configMu.Unlock()
	})

	if loadErr != nil {
		return nil, loadErr
	}

	if config == nil {
		return nil, os.ErrNotExist
	}

	return config, nil
}

// applyDefaults 应用配置默认值
func applyDefaults(cfg *Config) {
	// 聊天模型默认值
	if cfg.ChatModel.Temperature <= 0 {
		cfg.ChatModel.Temperature = 0.7
	}
	if cfg.ChatModel.MaxTokens <= 0 {
		cfg.ChatModel.MaxTokens = 2048
	}

	// 索引算法默认值
	if cfg.Index.Algorithm == "" {
		cfg.Index.Algorithm = "HNSW"
	}
	if cfg.Index.HNSWMM <= 0 {
		cfg.Index.HNSWMM = 16
	}
	if cfg.Index.HNSWEfConstruction <= 0 {
		cfg.Index.HNSWEfConstruction = 200
	}
	if cfg.Index.HNSWEfRuntime <= 0 {
		cfg.Index.HNSWEfRuntime = 10
	}

	// 分割器默认值
	if cfg.Splitter.Strategy == "" {
		cfg.Splitter.Strategy = "recursive"
	}
	if cfg.Splitter.ChunkSize <= 0 {
		cfg.Splitter.ChunkSize = 1000
	}
	if cfg.Splitter.ChunkOverlap < 0 {
		cfg.Splitter.ChunkOverlap = 0
	}
	if cfg.Splitter.ChunkOverlap >= cfg.Splitter.ChunkSize {
		cfg.Splitter.ChunkOverlap = cfg.Splitter.ChunkSize / 5
	}

	// 检索默认值
	if cfg.Retriever.TopK <= 0 {
		cfg.Retriever.TopK = 5
	}
	if cfg.Retriever.VectorWeight <= 0 || cfg.Retriever.VectorWeight > 1 {
		cfg.Retriever.VectorWeight = 0.7
	}
	if cfg.Retriever.HybridVectorTopK <= 0 {
		cfg.Retriever.HybridVectorTopK = 20
	}
	if cfg.Retriever.HybridTextTopK <= 0 {
		cfg.Retriever.HybridTextTopK = 20
	}
	if cfg.Retriever.FusionStrategy == "" {
		cfg.Retriever.FusionStrategy = "rrf"
	}
	if cfg.Retriever.RRFOffset <= 0 {
		cfg.Retriever.RRFOffset = 60
	}

	// Prompt 默认值
	if cfg.Prompt.SystemTemplate == "" {
		cfg.Prompt.SystemTemplate = defaultSystemPrompt
	}

	// 服务默认值
	if cfg.Server.Port <= 0 {
		cfg.Server.Port = 8080
	}
	if cfg.Server.ReadTimeout <= 0 {
		cfg.Server.ReadTimeout = 30
	}
	if cfg.Server.WriteTimeout <= 0 {
		cfg.Server.WriteTimeout = 30
	}

	// 日志默认值
	if cfg.Log.Level == "" {
		cfg.Log.Level = "info"
	}
	if cfg.Log.FilePath == "" {
		cfg.Log.FilePath = "app.log"
	}
}

// defaultSystemPrompt 默认系统提示词
const defaultSystemPrompt = `
# Role: Student Learning Assistant

# Language: Chinese

# Critical Rules:
1. You MUST ONLY use the information provided in the documents below to answer questions.
2. If the documents do not contain information relevant to the question, you MUST state "根据现有资料,我无法找到相关信息来回答这个问题。"
3. Do NOT make up information or use your own knowledge to answer.
4. Do NOT mention people, facts, or data that are not in the provided documents.

- When providing assistance:
  • Be clear and concise
  • Only reference information from the provided documents
  • If multiple documents are provided, clearly indicate which source each piece of information comes from

here's documents searched for you:
==== doc start ====
  {documents}
==== doc end ====
`

// GetConfig 获取全局配置实例(线程安全)
func GetConfig() *Config {
	configMu.RLock()
	defer configMu.RUnlock()
	return config
}

// ReloadConfig 重新加载配置文件(支持热更新)
func ReloadConfig() (*Config, error) {
	// 读取配置文件
	data, err := os.ReadFile("config.yaml")
	if err != nil {
		return nil, fmt.Errorf("读取配置文件失败: %w", err)
	}

	// 解析 YAML
	cfg := new(Config)
	err = yaml.Unmarshal(data, cfg)
	if err != nil {
		return nil, fmt.Errorf("解析配置文件失败: %w", err)
	}

	// 应用默认值
	applyDefaults(cfg)

	configMu.Lock()
	config = cfg
	configMu.Unlock()

	return config, nil
}

// ResetConfigForTest 重置配置(用于测试)
func ResetConfigForTest() {
	configMu.Lock()
	config = nil
	once = sync.Once{}
	configMu.Unlock()
}

rag

rag.go

Go 复制代码
package rag

import (
	"context"
	"fmt"
	"strings"
	"time"

	"eino_redis_rag/config"

	"github.com/cloudwego/eino-ext/components/document/loader/file"
	pdfParser "github.com/cloudwego/eino-ext/components/document/parser/pdf"
	embedding2 "github.com/cloudwego/eino-ext/components/embedding/openai"
	redisInd "github.com/cloudwego/eino-ext/components/indexer/redis"
	redisRet "github.com/cloudwego/eino-ext/components/retriever/redis"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/components/document/parser"
	"github.com/cloudwego/eino/schema"
	"github.com/redis/go-redis/v9"
	"github.com/sirupsen/logrus"
)

// RAGEngine 是 RAG(Retrieval-Augmented Generation)引擎的核心结构体
// 负责管理所有 RAG 组件:嵌入模型、文件加载器、文档分割器、检索器、索引器和聊天模型
type RAGEngine struct {
	indexName       string          // Redis 向量索引名称
	prefix          string          // Redis 文档键前缀
	dimension       int             // 向量维度
	Redis           *redis.Client   // Redis 客户端(导出字段,供外部访问)
	Embedder        *embedding2.Embedder    // 嵌入模型,用于将文本转换为向量
	Loader          *file.FileLoader        // 文件加载器,用于加载文档
	Splitter        document.Transformer    // 文档分割器,用于将文档分割成块
	Retriever       *redisRet.Retriever     // 检索器,用于从 Redis 检索相关文档
	RetrieverConfig *redisRet.RetrieverConfig // 检索器配置(导出字段,支持动态修改 TopK)
	Indexer         *redisInd.Indexer       // 索引器,用于将文档存储到 Redis
	ChatModel       *openai.ChatModel       // 聊天模型,用于生成回答
	systemPrompt    string              // 系统提示词模板
}

// NewRAGEngine 初始化 RAG 引擎
// 创建 Redis 客户端、嵌入模型和聊天模型,并返回 RAGEngine 实例
func NewRAGEngine(ctx context.Context, cfg *config.Config) (*RAGEngine, error) {
	// 创建 Redis 客户端
	redisClient := redis.NewClient(&redis.Options{
		Addr:          fmt.Sprintf("%s:%d", cfg.Redis.Host, cfg.Redis.Port),
		Password:      cfg.Redis.Password,
		DB:            cfg.Redis.DB,
		Protocol:      cfg.Redis.Protocol,
		UnstableResp3: cfg.Redis.UnstableResp3,
	})

	// 测试 Redis 连接
	if err := redisClient.Ping(ctx).Err(); err != nil {
		return nil, fmt.Errorf("Redis 连接失败: %w", err)
	}
	logrus.Info("Redis 连接成功")

	// 创建嵌入模型
	timeout := time.Duration(cfg.Embedding.TimeoutSeconds) * time.Second
	embedder, err := embedding2.NewEmbedder(ctx, &embedding2.EmbeddingConfig{
		APIKey:  cfg.Embedding.APIKey,
		BaseURL: cfg.Embedding.BaseURL,
		Model:   cfg.Embedding.Model,
		Timeout: timeout,
	})
	if err != nil {
		return nil, fmt.Errorf("创建嵌入模型失败: %w", err)
	}
	logrus.Infof("嵌入模型初始化成功: %s, 维度: %d", cfg.Embedding.Model, cfg.Embedding.Dimension)

	// 创建聊天模型
	chatTimeout := time.Duration(cfg.ChatModel.TimeoutSeconds) * time.Second
	temp := float32(cfg.ChatModel.Temperature)
	maxTok := cfg.ChatModel.MaxTokens
	chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
		APIKey:      cfg.ChatModel.APIKey,
		BaseURL:     cfg.ChatModel.BaseURL,
		Model:       cfg.ChatModel.Model,
		Timeout:     chatTimeout,
		Temperature: &temp,
		MaxTokens:   &maxTok,
	})
	if err != nil {
		return nil, fmt.Errorf("创建聊天模型失败: %w", err)
	}
	logrus.Infof("聊天模型初始化成功: %s, temperature: %.2f, max_tokens: %d",
		cfg.ChatModel.Model, cfg.ChatModel.Temperature, cfg.ChatModel.MaxTokens)

	// 返回 RAGEngine 实例
	return &RAGEngine{
		indexName:    cfg.Index.Name,
		prefix:       cfg.Index.KeyPrefix,
		dimension:    cfg.Embedding.Dimension,
		Redis:        redisClient,
		Embedder:     embedder,
		ChatModel:    chatModel,
		systemPrompt: cfg.Prompt.SystemTemplate,
	}, nil
}

// InitLoader 初始化文件加载器
func (r *RAGEngine) InitLoader(ctx context.Context, cfg *config.Config) error {
	// 注册 PDF 解析器
	pdfP, err := pdfParser.NewPDFParser(ctx, &pdfParser.Config{
		ToPages: false,
	})
	if err != nil {
		return fmt.Errorf("创建 PDF 解析器失败: %w", err)
	}

	// 注册 XLSX 解析器
	xlsxP := NewXlsxParserForWholeFile()

	// 构建 ExtParser
	extParser, err := parser.NewExtParser(ctx, &parser.ExtParserConfig{
		Parsers: map[string]parser.Parser{
			".pdf":   pdfP,
			".xlsx":  xlsxP,
			".xls":   xlsxP,
			".md":    parser.TextParser{},
			".txt":   parser.TextParser{},
		},
		FallbackParser: parser.TextParser{},
	})
	if err != nil {
		return fmt.Errorf("创建 ExtParser 失败: %w", err)
	}

	l, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{
		UseNameAsID: cfg.FileLoader.UseNameAsID,
		Parser:      extParser,
	})
	if err != nil {
		return err
	}
	r.Loader = l
	logrus.Info("文件加载器初始化成功")
	return nil
}

// InitSplitter 初始化文档分割器
func (r *RAGEngine) InitSplitter(ctx context.Context, cfg *config.Config) error {
	splitter, err := NewSplitter(ctx, cfg)
	if err != nil {
		return err
	}
	r.Splitter = splitter
	return nil
}

// InitIndexer 初始化索引器
func (r *RAGEngine) InitIndexer(ctx context.Context) error {
	i, err := redisInd.NewIndexer(ctx, &redisInd.IndexerConfig{
		Client:           r.Redis,
		KeyPrefix:        r.prefix,
		DocumentToHashes: nil,
		BatchSize:        10,
		Embedding:        r.Embedder,
	})
	if err != nil {
		return err
	}
	r.Indexer = i
	logrus.Info("索引器初始化成功")
	return nil
}

// InitRetriever 初始化检索器
func (r *RAGEngine) InitRetriever(ctx context.Context, cfg *config.Config) error {
	retrieverConfig := &redisRet.RetrieverConfig{
		Client:            r.Redis,
		Index:             r.indexName,
		VectorField:       cfg.Index.VectorField,
		DistanceThreshold: cfg.Retriever.DistanceThreshold,
		Dialect:           cfg.Retriever.Dialect,
		ReturnFields:      cfg.Retriever.ReturnFields,
		DocumentConverter: nil,
		TopK:              cfg.Retriever.TopK,
		Embedding:         r.Embedder,
	}

	re, err := redisRet.NewRetriever(ctx, retrieverConfig)
	if err != nil {
		return err
	}
	r.Retriever = re
	r.RetrieverConfig = retrieverConfig
	logrus.Infof("检索器初始化成功: index=%s, topK=%d", r.indexName, cfg.Retriever.TopK)
	return nil
}

// InitVectorIndex 初始化向量索引(FT.CREATE)
// 支持 FLAT 和 HNSW 两种算法
// 添加 content_tag 字段支持混合检索
func (r *RAGEngine) InitVectorIndex(ctx context.Context, cfg *config.Config) error {
	// 检查索引是否已存在
	if _, err := r.Redis.Do(ctx, "FT.INFO", r.indexName).Result(); err == nil {
		logrus.Infof("索引 '%s' 已存在,跳过创建", r.indexName)
		return nil
	}

	// 构建索引创建命令
	algorithm := strings.ToUpper(cfg.Index.Algorithm)
	if algorithm == "" {
		algorithm = "HNSW"
	}

	logrus.Infof("创建向量索引: algorithm=%s, dimension=%d", algorithm, r.dimension)

	var createIndexArgs []interface{}

	if algorithm == "HNSW" {
		createIndexArgs = []interface{}{
			"FT.CREATE", r.indexName,
			"ON", "HASH",
			"PREFIX", "1", r.prefix,
			"SCHEMA",
			// 内容字段 - TEXT 支持全文搜索
			"content", "TEXT", "WEIGHT", "1.0",
			// 文件名字段 - TAG 支持过滤
			"filename", "TAG", "CASESENSITIVE",
			// 文件类型字段 - TAG 支持过滤
			"filetype", "TAG", "CASESENSITIVE",
			// 原始ID字段 - TAG 支持过滤
			"source_id", "TAG", "CASESENSITIVE",
			// 向量字段
			"vector_content", "VECTOR", "HNSW",
			"14",
			"TYPE", "FLOAT32",
			"DIM", r.dimension,
			"DISTANCE_METRIC", "COSINE",
			"M", cfg.Index.HNSWMM,
			"EF_CONSTRUCTION", cfg.Index.HNSWEfConstruction,
			"EF_RUNTIME", cfg.Index.HNSWEfRuntime,
		}
	} else {
		// FLAT 算法
		createIndexArgs = []interface{}{
			"FT.CREATE", r.indexName,
			"ON", "HASH",
			"PREFIX", "1", r.prefix,
			"SCHEMA",
			"content", "TEXT", "WEIGHT", "1.0",
			"filename", "TAG", "CASESENSITIVE",
			"filetype", "TAG", "CASESENSITIVE",
			"source_id", "TAG", "CASESENSITIVE",
			"vector_content", "VECTOR", "FLAT",
			"6",
			"TYPE", "FLOAT32",
			"DIM", r.dimension,
			"DISTANCE_METRIC", "COSINE",
		}
	}

	if err := r.Redis.Do(ctx, createIndexArgs...).Err(); err != nil {
		return fmt.Errorf("创建向量索引失败: %w", err)
	}

	// 验证索引创建成功
	if _, err := r.Redis.Do(ctx, "FT.INFO", r.indexName).Result(); err != nil {
		return fmt.Errorf("验证索引失败: %w", err)
	}

	logrus.Infof("向量索引 '%s' 创建成功 (algorithm=%s)", r.indexName, algorithm)
	return nil
}

// DropVectorIndex 删除向量索引(用于调试/重置)
func (r *RAGEngine) DropVectorIndex(ctx context.Context) error {
	// 先删除索引中的所有文档
	_, err := r.Redis.Do(ctx, "FT.INFO", r.indexName).Result()
	if err == nil {
		// 索引存在,先删除
		if err := r.Redis.Do(ctx, "FT.DROPINDEX", r.indexName, "DD").Err(); err != nil {
			logrus.Warnf("删除索引失败: %v", err)
		}
		logrus.Info("已删除旧索引及其数据")
	} else {
		logrus.Debug("索引不存在,无需删除")
	}
	return nil
}

// Generate 根据查询生成回答
func (r *RAGEngine) Generate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
	docs, err := r.Retriever.Retrieve(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("检索文档失败: %w", err)
	}
	return r.GenerateWithDocs(ctx, query, docs)
}

// GenerateWithDocs 使用指定的文档生成回答
func (r *RAGEngine) GenerateWithDocs(ctx context.Context, query string, docs []*schema.Document) (*schema.StreamReader[*schema.Message], error) {
	documentsStr := documentsToString(docs)

	// 使用可配置的系统提示词
	systemTpl := r.systemPrompt
	if systemTpl == "" {
		systemTpl = defaultSystemPrompt
	}

	// 替换 {documents} 占位符
	systemContent := strings.ReplaceAll(systemTpl, "{documents}", documentsStr)

	tpl := prompt.FromMessages(schema.FString, []schema.MessagesTemplate{
		schema.SystemMessage(systemContent),
		schema.UserMessage("{content}"),
	}...)

	messages, err := tpl.Format(ctx, map[string]any{
		"content": query,
	})
	if err != nil {
		return nil, fmt.Errorf("格式化提示词失败: %w", err)
	}

	return r.ChatModel.Stream(ctx, messages)
}

// RetrieveWithFilter 带过滤条件的检索
// 支持按文件名、文件类型等条件过滤
func (r *RAGEngine) RetrieveWithFilter(ctx context.Context, query string, filter map[string]string, topK int) ([]*schema.Document, error) {
	// 构建过滤条件
	var filterParts []string
	for k, v := range filter {
		filterParts = append(filterParts, fmt.Sprintf("@%s:{%s}", k, v))
	}

	// 构建查询
	var searchQuery string
	if len(filterParts) > 0 {
		searchQuery = strings.Join(filterParts, " ") + " (" + query + ")"
	} else {
		searchQuery = query
	}

	// 动态修改 TopK
	originalTopK := r.RetrieverConfig.TopK
	r.RetrieverConfig.TopK = topK
	defer func() {
		r.RetrieverConfig.TopK = originalTopK
	}()

	return r.Retriever.Retrieve(ctx, searchQuery)
}

// documentsToString 将文档列表转换为可读的字符串格式
// 包含文件名、来源等元数据信息
func documentsToString(docs []*schema.Document) string {
	if len(docs) == 0 {
		return "(无相关文档)"
	}

	var sb strings.Builder
	for i, doc := range docs {
		sb.WriteString(fmt.Sprintf("\n=== 来源 %d ===\n", i+1))

		// 添加文件名信息
		if filename, ok := doc.MetaData["filename"]; ok {
			sb.WriteString(fmt.Sprintf("文件: %v\n", filename))
		}
		if sourceID, ok := doc.MetaData["source_id"]; ok {
			sb.WriteString(fmt.Sprintf("来源ID: %v\n", sourceID))
		}
		if chunkIdx, ok := doc.MetaData["chunk_index"]; ok {
			sb.WriteString(fmt.Sprintf("片段: %v\n", chunkIdx))
		}

		sb.WriteString("\n内容:\n")
		sb.WriteString(doc.Content)
		sb.WriteString("\n")
	}

	return sb.String()
}

// defaultSystemPrompt 默认系统提示词
const defaultSystemPrompt = `
# Role: Student Learning Assistant

# Language: Chinese

# Critical Rules:
1. You MUST ONLY use the information provided in the documents below to answer questions.
2. If the documents do not contain information relevant to the question, you MUST state "根据现有资料,我无法找到相关信息来回答这个问题。"
3. Do NOT make up information or use your own knowledge to answer.
4. Do NOT mention people, facts, or data that are not in the provided documents.

- When providing assistance:
  • Be clear and concise
  • Only reference information from the provided documents
  • If multiple documents are provided, clearly indicate which source each piece of information comes from

here's documents searched for you:
==== doc start ====
  {documents}
==== doc end ====
`

splitter.go

Go 复制代码
package rag

import (
	"context"
	"strings"
	"unicode/utf8"

	"eino_redis_rag/config"

	"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/schema"
	"github.com/sirupsen/logrus"
)

// NewSplitter 初始化文档分割器
// 根据配置中的策略创建对应的分割器
// 支持的策略:
//   - header: 按 Markdown 标题层级分割
//   - recursive: 递归字符分割器(推荐,适用于所有文件类型)
//   - fixed: 固定大小分割器
//
// 参数:
//   - ctx: 上下文
//   - cfg: 配置实例
//
// 返回:
//   - document.Transformer: 文档转换器接口
//   - error: 错误信息
func NewSplitter(ctx context.Context, cfg *config.Config) (document.Transformer, error) {
	strategy := cfg.Splitter.Strategy
	if strategy == "" {
		strategy = "recursive"
	}

	logrus.Infof("使用文档分割策略: %s, chunk_size=%d, chunk_overlap=%d",
		strategy, cfg.Splitter.ChunkSize, cfg.Splitter.ChunkOverlap)

	switch strategy {
	case "header":
		return newHeaderSplitter(ctx, cfg)
	case "recursive":
		return newRecursiveCharacterSplitter(ctx, cfg)
	case "fixed":
		return newFixedSplitter(ctx, cfg)
	default:
		logrus.Warnf("未知分割策略 '%s',使用默认的 recursive 策略", strategy)
		return newRecursiveCharacterSplitter(ctx, cfg)
	}
}

// newHeaderSplitter 创建 Markdown 标题分割器
func newHeaderSplitter(ctx context.Context, cfg *config.Config) (document.Transformer, error) {
	headers := make(map[string]string)
	for _, h := range cfg.Splitter.MarkdownHeaders {
		headers[h] = "title"
	}

	t, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
		Headers:     headers,
		TrimHeaders: cfg.Splitter.TrimHeaders,
	})
	if err != nil {
		return nil, err
	}
	return t, nil
}

// ==================== 递归字符分割器 ====================

// recursiveCharacterSplitter 递归字符分割器
// 按不同的分隔符递归分割文本,直到块大小满足要求
// 分隔符优先级: \n\n > \n > 空格 > 字符
type recursiveCharacterSplitter struct {
	chunkSize    int
	chunkOverlap int
	separators   []string
}

func newRecursiveCharacterSplitter(ctx context.Context, cfg *config.Config) (document.Transformer, error) {
	return &recursiveCharacterSplitter{
		chunkSize:    cfg.Splitter.ChunkSize,
		chunkOverlap: cfg.Splitter.ChunkOverlap,
		separators:   []string{"\n\n", "\n", " ", ""},
	}, nil
}

func (s *recursiveCharacterSplitter) Transform(ctx context.Context, docs []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
	var result []*schema.Document

	for _, doc := range docs {
		chunks := s.splitText(doc.Content)
		for i, chunk := range chunks {
			if chunk == "" {
				continue
			}

			meta := make(map[string]any)
			if doc.MetaData != nil {
				for k, v := range doc.MetaData {
					meta[k] = v
				}
			}
			meta["chunk_index"] = i
			meta["total_chunks"] = len(chunks)
			meta["chunk_size"] = utf8.RuneCountInString(chunk)

			chunkDoc := &schema.Document{
				ID:       doc.ID + ":chunk:" + string(rune('0'+i)),
				Content:  chunk,
				MetaData: meta,
			}
			result = append(result, chunkDoc)
		}
	}

	logrus.Debugf("递归分割完成: 输入 %d 个文档, 输出 %d 个块", len(docs), len(result))
	return result, nil
}

func (s *recursiveCharacterSplitter) splitText(text string) []string {
	maxLength := s.chunkSize
	if utf8.RuneCountInString(text) <= maxLength {
		return []string{text}
	}

	// 尝试每个分隔符
	for _, sep := range s.separators {
		parts := s.splitWithSeparator(text, sep)
		
		// 检查是否所有部分都满足大小
		allGood := true
		for _, part := range parts {
			if utf8.RuneCountInString(part) > maxLength {
				allGood = false
				break
			}
		}

		if allGood {
			return s.mergeChunks(parts, maxLength, s.chunkOverlap)
		}

		// 如果是最后一个分隔符(空字符串),直接强制分割
		if sep == "" {
			return s.forceSplit(text, maxLength, s.chunkOverlap)
		}
	}

	return s.forceSplit(text, maxLength, s.chunkOverlap)
}

func (s *recursiveCharacterSplitter) splitWithSeparator(text, sep string) []string {
	if sep == "" {
		// 空分隔符:按字符分割
		runes := []rune(text)
		var parts []string
		for _, r := range runes {
			parts = append(parts, string(r))
		}
		return parts
	}
	return strings.Split(text, sep)
}

func (s *recursiveCharacterSplitter) mergeChunks(parts []string, maxLength, minOverlap int) []string {
	if len(parts) == 0 {
		return []string{}
	}

	var chunks []string
	currentChunk := ""

	for i, part := range parts {
		testChunk := part
		if currentChunk != "" {
			testChunk = currentChunk + "\n" + part
		}

		if utf8.RuneCountInString(testChunk) <= maxLength {
			currentChunk = testChunk
		} else {
			if currentChunk != "" {
				chunks = append(chunks, currentChunk)
			}

			// 计算重叠
			overlap := ""
			if minOverlap > 0 {
				runes := []rune(currentChunk)
				if len(runes) > minOverlap {
					overlap = string(runes[len(runes)-minOverlap:])
				}
			}

			currentChunk = overlap
			if currentChunk != "" {
				currentChunk = currentChunk + "\n" + part
			} else {
				currentChunk = part
			}

			// 如果仍然超过限制,强制截断
			if utf8.RuneCountInString(currentChunk) > maxLength {
				runes := []rune(currentChunk)
				currentChunk = string(runes[:maxLength])
			}
		}

		if i == len(parts)-1 && currentChunk != "" {
			chunks = append(chunks, currentChunk)
		}
	}

	if currentChunk != "" && len(chunks) == 0 {
		chunks = append(chunks, currentChunk)
	}

	return chunks
}

func (s *recursiveCharacterSplitter) forceSplit(text string, maxLength, minOverlap int) []string {
	runes := []rune(text)
	var chunks []string

	step := maxLength - minOverlap
	if step <= 0 {
		step = 1
	}

	for i := 0; i < len(runes); i += step {
		end := i + maxLength
		if end >= len(runes) {
			chunks = append(chunks, string(runes[i:]))
			break
		}
		chunks = append(chunks, string(runes[i:end]))
	}

	return chunks
}

// ==================== 固定大小分割器 ====================

type fixedSplitter struct {
	chunkSize    int
	chunkOverlap int
}

func newFixedSplitter(ctx context.Context, cfg *config.Config) (document.Transformer, error) {
	return &fixedSplitter{
		chunkSize:    cfg.Splitter.ChunkSize,
		chunkOverlap: cfg.Splitter.ChunkOverlap,
	}, nil
}

func (s *fixedSplitter) Transform(ctx context.Context, docs []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
	var result []*schema.Document

	for _, doc := range docs {
		chunks := s.splitFixed(doc.Content)
		for i, chunk := range chunks {
			if chunk == "" {
				continue
			}

			meta := make(map[string]any)
			if doc.MetaData != nil {
				for k, v := range doc.MetaData {
					meta[k] = v
				}
			}
			meta["chunk_index"] = i
			meta["total_chunks"] = len(chunks)

			result = append(result, &schema.Document{
				ID:       doc.ID + ":chunk:" + string(rune('0'+i)),
				Content:  chunk,
				MetaData: meta,
			})
		}
	}

	return result, nil
}

func (s *fixedSplitter) splitFixed(text string) []string {
	if text == "" {
		return []string{}
	}

	runes := []rune(text)
	step := s.chunkSize - s.chunkOverlap
	if step <= 0 {
		step = 1
	}

	var chunks []string
	for i := 0; i < len(runes); i += step {
		end := i + s.chunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunks = append(chunks, string(runes[i:end]))
		if end >= len(runes) {
			break
		}
	}

	return chunks
}

xlsx_parser.go

Go 复制代码
package rag

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"strings"

	"github.com/cloudwego/eino/components/document/parser"
	"github.com/cloudwego/eino/schema"
	"github.com/xuri/excelize/v2"
)

// readAllFromReader 从 io.Reader 中读取所有内容到 bytes.Reader
// 因为 excelize.OpenReader 需要 io.ReaderAt 接口,而 eino 的 parser.Parser 只传递 io.Reader
// bytes.Reader 实现了 io.ReaderAt 接口,所以可以用它来包装读取的数据
func readAllFromReader(reader io.Reader) (*bytes.Reader, error) {
	buf, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("读取数据失败: %w", err)
	}
	return bytes.NewReader(buf), nil
}

const (
	xlsxMetaDataRow = "_row"
	xlsxMetaDataExt = "_ext"
	xlsxMetaDataSheet = "_sheet"
)

// XlsxParserConfig 用于配置 XlsxParser
type XlsxParserConfig struct {
	// SheetName 设置要解析的工作表名称,默认为第一个工作表
	SheetName string
	// NoHeader 设置为 true 时表示第一行不是表头
	NoHeader bool
	// IDPrefix 自定义文档 ID 前缀
	IDPrefix string
}

// XlsxParser 自定义解析器,用于解析 XLSX 文件内容
// 支持有表头和无表头的 XLSX 文件
// 支持从多个工作表中选择特定表格
// 支持自定义文档 ID 前缀
type XlsxParser struct {
	config *XlsxParserConfig
}

// NewXlsxParser 创建一个新的 XlsxParser
func NewXlsxParser(ctx context.Context, config *XlsxParserConfig) (parser.Parser, error) {
	if config == nil {
		config = &XlsxParserConfig{}
	}
	return &XlsxParser{config: config}, nil
}

// generateID 根据配置生成文档 ID
func (p *XlsxParser) generateID(sheetName string, rowIdx int) string {
	if p.config.IDPrefix != "" {
		return fmt.Sprintf("%s%s:%d", p.config.IDPrefix, sheetName, rowIdx)
	}
	return fmt.Sprintf("%s:%d", sheetName, rowIdx)
}

// buildRowContent 将行数据构建为可读的文本内容
func (p *XlsxParser) buildRowContent(row []string, headers []string) string {
	if !p.config.NoHeader && len(headers) > 0 {
		// 有表头时,使用 "表头: 值" 的格式
		var parts []string
		for j, header := range headers {
			if j < len(row) {
				value := strings.TrimSpace(row[j])
				if value != "" {
					parts = append(parts, fmt.Sprintf("%s: %s", header, value))
				}
			}
		}
		return strings.Join(parts, "\n")
	}

	// 无表头时,直接用制表符连接
	contentParts := make([]string, len(row))
	for j, cell := range row {
		contentParts[j] = strings.TrimSpace(cell)
	}
	return strings.Join(contentParts, "\t")
}

// Parse 解析 XLSX 内容
// 将所有工作表的数据解析为文档列表
func (p *XlsxParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
	// 先读取所有内容到内存,bytes.Reader 实现了 io.ReaderAt 接口
	dataReader, err := readAllFromReader(reader)
	if err != nil {
		return nil, fmt.Errorf("读取 Excel 数据失败: %w", err)
	}

	// 使用 OpenReader 打开 Excel 文件(bytes.Reader 实现了 io.ReaderAt 接口)
	xlFile, err := excelize.OpenReader(dataReader)
	if err != nil {
		return nil, fmt.Errorf("打开 Excel 文件失败: %w", err)
	}
	defer xlFile.Close()

	// 获取所有工作表
	sheets := xlFile.GetSheetList()
	if len(sheets) == 0 {
		return nil, nil
	}

	var ret []*schema.Document

	// 确定要解析的工作表
	var targetSheets []string
	if p.config.SheetName != "" {
		targetSheets = []string{p.config.SheetName}
	} else {
		targetSheets = sheets
	}

	for _, sheetName := range targetSheets {
		// 获取该工作表的所有行(表头 + 数据行)
		rows, err := xlFile.GetRows(sheetName)
		if err != nil {
			// 如果某个工作表读取失败,记录错误但继续处理其他工作表
			continue
		}
		if len(rows) == 0 {
			continue
		}

		// 处理表头
		startIdx := 0
		var headers []string
		if !p.config.NoHeader && len(rows) > 0 {
			headers = rows[0]
			startIdx = 1
		}

		// 处理数据行
		for i := startIdx; i < len(rows); i++ {
			row := rows[i]
			if len(row) == 0 {
				continue
			}

			// 检查行是否全为空
			allEmpty := true
			for _, cell := range row {
				if strings.TrimSpace(cell) != "" {
					allEmpty = false
					break
				}
			}
			if allEmpty {
				continue
			}

			// 构建内容
			content := p.buildRowContent(row, headers)
			if content == "" {
				continue
			}

			// 构建元数据
			// 注意:Redis 无法序列化嵌套的 map[string]any,所以元数据只能包含简单类型(string, int, float64 等)
			meta := make(map[string]any)
			meta[xlsxMetaDataSheet] = sheetName
			meta[xlsxMetaDataRow] = i - startIdx + 1

			// 如果有表头,将表头信息存储为字符串(而不是嵌套map)
			if !p.config.NoHeader && len(headers) > 0 {
				meta["_headers"] = strings.Join(headers, ", ")
				// 将每个表头-值对存储为单独的元数据字段
				for j, header := range headers {
					if j < len(row) {
						// 使用 _col_ 前缀避免冲突
						meta["_col_"+header] = strings.TrimSpace(row[j])
					}
				}
			}

			// 创建文档
			doc := &schema.Document{
				ID:       p.generateID(sheetName, i),
				Content:  content,
				MetaData: meta,
			}

			ret = append(ret, doc)
		}
	}

	return ret, nil
}

// ParseAllSheetsToString 将整个 XLSX 文件的所有工作表解析为单个字符串内容
// 这种方式适合将整个 Excel 文件作为单个文档处理
func ParseAllSheetsToString(ctx context.Context, reader io.Reader) (string, error) {
	// 先读取所有内容到内存,bytes.Reader 实现了 io.ReaderAt 接口
	dataReader, err := readAllFromReader(reader)
	if err != nil {
		return "", fmt.Errorf("读取 Excel 数据失败: %w", err)
	}

	// 使用 OpenReader 打开 Excel 文件(bytes.Reader 实现了 io.ReaderAt 接口)
	xlFile, err := excelize.OpenReader(dataReader)
	if err != nil {
		return "", fmt.Errorf("打开 Excel 文件失败: %w", err)
	}
	defer xlFile.Close()

	sheets := xlFile.GetSheetList()
	if len(sheets) == 0 {
		return "", nil
	}

	var result strings.Builder

	for _, sheetName := range sheets {
		rows, err := xlFile.GetRows(sheetName)
		if err != nil {
			continue
		}
		if len(rows) == 0 {
			continue
		}

		// 写入工作表标题
		result.WriteString(fmt.Sprintf("【工作表: %s】\n", sheetName))

		for i, row := range rows {
			if len(row) == 0 {
				continue
			}

			// 检查行是否全为空
			allEmpty := true
			for _, cell := range row {
				if strings.TrimSpace(cell) != "" {
					allEmpty = false
					break
				}
			}
			if allEmpty {
				continue
			}

			if i == 0 {
				// 第一行作为表头,用制表符连接
				contentParts := make([]string, len(row))
				for j, cell := range row {
					contentParts[j] = strings.TrimSpace(cell)
				}
				result.WriteString(strings.Join(contentParts, "\t"))
				result.WriteString("\n")
				// 添加分隔线
				result.WriteString(strings.Repeat("-", len(row)*10) + "\n")
			} else {
				// 数据行,用制表符连接
				contentParts := make([]string, len(row))
				for j, cell := range row {
					contentParts[j] = strings.TrimSpace(cell)
				}
				result.WriteString(strings.Join(contentParts, "\t"))
				result.WriteString("\n")
			}
		}

		result.WriteString("\n")
	}

	return result.String(), nil
}

// XlsxParserForWholeFile 将整个 XLSX 文件解析为单个文档的解析器
// 这种方式更适合 RAG 场景,因为整个表格的上下文被保留
type XlsxParserForWholeFile struct{}

// NewXlsxParserForWholeFile 创建一个新的整文件解析器
func NewXlsxParserForWholeFile() parser.Parser {
	return &XlsxParserForWholeFile{}
}

// Parse 实现 parser.Parser 接口
func (p *XlsxParserForWholeFile) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
	content, err := ParseAllSheetsToString(ctx, reader)
	if err != nil {
		return nil, err
	}
	if content == "" {
		return nil, nil
	}

	// 注意:Redis 无法序列化嵌套的 map[string]any,所以元数据只能包含简单类型
	meta := make(map[string]any)

	doc := &schema.Document{
		ID:       "1",
		Content:  content,
		MetaData: meta,
	}

	return []*schema.Document{doc}, nil
}

service

indexer_service.go

Go 复制代码
package service

import (
	"context"
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"eino_redis_rag/config"
	"eino_redis_rag/rag"

	"github.com/cloudwego/eino/components/document"
	"github.com/redis/go-redis/v9"
	"github.com/sirupsen/logrus"
)

// IndexerService 索引服务
// 负责处理文档索引相关的操作
type IndexerService struct {
	ragEngine *rag.RAGEngine // RAG 引擎实例
	config    *config.Config // 配置实例
}

// NewIndexerService 创建索引服务
// 参数:
//   - ragEngine: RAG 引擎实例
//   - config: 配置实例
//
// 返回:
//   - *IndexerService: 索引服务实例
func NewIndexerService(ragEngine *rag.RAGEngine, config *config.Config) *IndexerService {
	return &IndexerService{
		ragEngine: ragEngine,
		config:    config,
	}
}

// IndexFolder 遍历指定文件夹下的所有文件,将可支持读取的文件读取并写入向量数据库中
// 参数:
//   - ctx: 上下文
//   - folderPath: 文件夹路径
//
// 返回:
//   - error: 错误信息
func (s *IndexerService) IndexFolder(ctx context.Context, folderPath string) error {
	// 检查文件夹是否存在
	if _, err := os.Stat(folderPath); os.IsNotExist(err) {
		return fmt.Errorf("文件夹不存在: %s", folderPath)
	}

	// 获取文件夹下所有支持的文件
	files, err := s.getFilesInFolder(folderPath)
	if err != nil {
		return err
	}

	// 如果没有文件,直接返回
	if len(files) == 0 {
		logrus.Info("文件夹中没有文件")
		return nil
	}

	logrus.Infof("找到 %d 个文件,开始索引", len(files))

	// 遍历文件并索引
	for _, file := range files {
		if err := s.indexFile(ctx, file); err != nil {
			logrus.Errorf("索引文件失败: %s, error: %v", file, err)
			continue // 继续处理其他文件
		}
		logrus.Infof("索引文件成功: %s", file)
	}

	return nil
}

// getFilesInFolder 获取文件夹下的所有文件
// 参数:
//   - folderPath: 文件夹路径
//
// 返回:
//   - []string: 文件路径列表
//   - error: 错误信息
func (s *IndexerService) getFilesInFolder(folderPath string) ([]string, error) {
	var files []string

	// 递归遍历文件夹
	err := filepath.Walk(folderPath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		// 只处理文件,跳过目录
		if info.IsDir() {
			return nil
		}

		// 检查文件扩展名是否在支持的列表中
		ext := strings.ToLower(filepath.Ext(path))
		for _, supportedExt := range s.config.FileLoader.SupportedExtensions {
			if ext == supportedExt {
				files = append(files, path)
				break
			}
		}

		return nil
	})

	return files, err
}

// calculateFileMD5 计算文件的 MD5 值
// 参数:
//   - filePath: 文件路径
//
// 返回:
//   - string: MD5 值的十六进制表示
//   - error: 错误信息
func (s *IndexerService) calculateFileMD5(filePath string) (string, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", err
	}
	hash := md5.Sum(data)
	return hex.EncodeToString(hash[:]), nil
}

// normalizePath 规范化文件路径,将反斜杠转换为正斜杠
// 确保在不同系统上路径一致
// 参数:
//   - path: 文件路径
//
// 返回:
//   - string: 规范化后的路径
func normalizePath(path string) string {
	return strings.ReplaceAll(path, "\\", "/")
}

// getDocumentKey 获取文档在 Redis 中的键名
// 参数:
//   - filePath: 文件路径
//   - chunkIndex: 文档块索引
//
// 返回:
//   - string: Redis 键名
func (s *IndexerService) getDocumentKey(filePath string, chunkIndex int) string {
	// 使用文件路径和块索引生成唯一键名
	// 注意:Indexer 会自动添加 KeyPrefix,所以这里只需要提供文件路径和块索引
	// 格式:文件路径:块索引
	// 例如:test_txt/mysql-1.md:0
	// 使用规范化路径,确保在不同系统上键名一致
	normalizedPath := normalizePath(filePath)
	return fmt.Sprintf("%s:%d", normalizedPath, chunkIndex)
}

// getMetadataKey 获取文件元数据在 Redis 中的键名
// 参数:
//   - filePath: 文件路径
//
// 返回:
//   - string: Redis 键名
func (s *IndexerService) getMetadataKey(filePath string) string {
	// 使用文件路径生成元数据键名
	// 格式:keyPrefix_metadata:文件路径
	// 例如:OuterCyrex:metadata:test_txt/mysql-1.md
	// 使用规范化路径,确保在不同系统上键名一致
	normalizedPath := normalizePath(filePath)
	return fmt.Sprintf("%s_metadata:%s", s.config.Index.KeyPrefix, normalizedPath)
}

// isFileAlreadyIndexed 检查文件是否已经索引过
// 通过检查 Redis 中是否存在该文件的元数据键(存储 MD5 值)来判断
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
//
// 返回:
//   - bool: 是否已索引
func (s *IndexerService) isFileAlreadyIndexed(ctx context.Context, filePath string) (bool, error) {
	// 通过检查元数据键是否存在来判断文件是否已索引
	metadataKey := s.getMetadataKey(filePath)
	_, err := s.ragEngine.Redis.Exists(ctx, metadataKey).Result()
	if err != nil {
		return false, err
	}
	
	// 如果键存在,说明文件已索引
	// Exists 返回 1 表示键存在,返回 0 表示键不存在
	result, err := s.ragEngine.Redis.Exists(ctx, metadataKey).Result()
	if err != nil {
		return false, err
	}
	
	return result > 0, nil
}

// DeleteDocumentByKey 直接通过 Redis key 删除文档
// 参数:
//   - ctx: 上下文
//   - key: Redis key
//
// 返回:
//   - error: 错误信息
func (s *IndexerService) DeleteDocumentByKey(ctx context.Context, key string) error {
	// 同时尝试正斜杠和反斜杠两种格式
	normalizedKey := normalizePath(key)
	
	deletedKeys := []string{}
	
	// 检查原始 key 是否存在
	exists, _ := s.ragEngine.Redis.Exists(ctx, key).Result()
	if exists > 0 {
		_, err := s.ragEngine.Redis.Del(ctx, key).Result()
		if err != nil {
			logrus.Warnf("删除键失败: %s, error: %v", key, err)
		} else {
			deletedKeys = append(deletedKeys, key)
		}
	}
	
	// 如果规范化后的 key 不同,也尝试删除
	if normalizedKey != key {
		exists, _ = s.ragEngine.Redis.Exists(ctx, normalizedKey).Result()
		if exists > 0 {
			_, err := s.ragEngine.Redis.Del(ctx, normalizedKey).Result()
			if err != nil {
				logrus.Warnf("删除键失败: %s, error: %v", normalizedKey, err)
			} else {
				deletedKeys = append(deletedKeys, normalizedKey)
			}
		}
	}
	
	if len(deletedKeys) > 0 {
		logrus.Infof("成功删除键: %v", deletedKeys)
	} else {
		logrus.Warnf("未找到键: %s 和 %s", key, normalizedKey)
	}
	
	return nil
}

// deleteOldDocuments 删除文件中所有已索引的文档块
// 在文件内容变更需要重新索引时,先删除旧文档,避免残留过时数据
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
//
// 返回:
//   - error: 错误信息
func (s *IndexerService) deleteOldDocuments(ctx context.Context, filePath string) error {
	// 同时尝试正斜杠和反斜杠两种路径格式,确保兼容历史数据
	normalizedPath := normalizePath(filePath)
	originalPath := filePath
	
	keyPatterns := []string{
		s.config.Index.KeyPrefix + normalizedPath + ":*",
	}
	
	// 如果原始路径与规范化路径不同,也添加原始路径的模式
	if originalPath != normalizedPath {
		keyPatterns = append(keyPatterns, s.config.Index.KeyPrefix + originalPath + ":*")
	}
	
	logrus.Infof("删除旧文档: %s, patterns: %v", filePath, keyPatterns)
	
	deletedCount := 0
	
	// 对每种路径模式进行 SCAN 扫描
scan_loop:
	for _, keyPattern := range keyPatterns {
		cursor := uint64(0)
		for {
			result, err := s.ragEngine.Redis.Do(ctx, "SCAN", cursor, "MATCH", keyPattern, "COUNT", 100).Result()
			if err != nil {
				logrus.Warnf("SCAN 命令执行失败: %v", err)
				break
			}
			
			// SCAN 返回一个包含两个元素的 slice:[cursor, [keys...]]
			resultSlice, ok := result.([]interface{})
			if !ok || len(resultSlice) < 2 {
				logrus.Debugf("SCAN 结果无法解析: ok=%v, len=%d", ok, func() int { if r, ok := result.([]interface{}); ok { return len(r) }; return 0 }())
				break
			}
			
			// 更新 cursor - 需要处理多种类型:uint64, float64, string, int64
			var newCursor uint64
			switch v := resultSlice[0].(type) {
			case uint64:
				newCursor = v
			case float64:
				newCursor = uint64(v)
			case string:
				// Redis RESP3 协议可能返回字符串类型的 cursor
				fmt.Sscanf(v, "%d", &newCursor)
			case int64:
				newCursor = uint64(v)
			case int:
				newCursor = uint64(v)
			default:
				logrus.Warnf("无法解析 cursor: type=%T, value=%v", resultSlice[0], resultSlice[0])
				break scan_loop
			}
			cursor = newCursor
			
			// 获取键列表 - 需要处理多种类型
			var keys []interface{}
			switch k := resultSlice[1].(type) {
			case []interface{}:
				keys = k
			case []string:
				for _, s := range k {
					keys = append(keys, s)
				}
			case string:
				// 单个键的情况
				keys = []interface{}{k}
			case nil:
				keys = []interface{}{}
			default:
				logrus.Warnf("无法解析 keys: type=%T, value=%v", resultSlice[1], resultSlice[1])
				break
			}
			
			logrus.Infof("SCAN pattern=%s 找到 %d 个键: %v", keyPattern, len(keys), func() interface{} { return resultSlice[1] }())
			
			// 删除所有匹配的键
			if len(keys) > 0 {
				keysStr := make([]string, len(keys))
				for i, k := range keys {
					keysStr[i] = fmt.Sprintf("%v", k)
				}
				
				_, err := s.ragEngine.Redis.Del(ctx, keysStr...).Result()
				if err != nil {
					logrus.Warnf("删除键失败: %v, error: %v", keysStr, err)
				} else {
					deletedCount += len(keysStr)
					logrus.Infof("成功删除 %d 个键: %v", len(keysStr), keysStr)
				}
			}
			
			// 如果 cursor 为 0,说明扫描完成
			if cursor == 0 {
				break
			}
		}
	}
	
	logrus.Infof("删除旧文档完成: %s, 删除了 %d 个文档块", filePath, deletedCount)
	return nil
}

// indexFile 索引单个文件
// 1. 计算文件 MD5 值
// 2. 从 Redis 中获取已存储的 MD5 值
// 3. 如果 MD5 值相同,则跳过索引
// 4. 如果 MD5 值不同或不存在,则删除旧文档后重新索引
// 5. 存储新的 MD5 值到 Redis
//
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
//
// 返回:
//   - error: 错误信息
func (s *IndexerService) indexFile(ctx context.Context, filePath string) error {
	// 步骤1: 计算当前文件的 MD5 值
	fileMD5, err := s.calculateFileMD5(filePath)
	if err != nil {
		return fmt.Errorf("计算文件 MD5 失败: %w", err)
	}

	// 步骤2: 获取元数据键(用于存储/读取 MD5 值)
	metadataKey := s.getMetadataKey(filePath)

	// 步骤3: 从 Redis 中获取已存储的 MD5 值
	storedMD5, err := s.ragEngine.Redis.Get(ctx, metadataKey).Result()
	if err != nil {
		if err == redis.Nil {
			// 键不存在,说明文件未索引过,需要执行完整索引
			logrus.Infof("文件未索引过,开始索引: %s", filePath)
		} else {
			return fmt.Errorf("获取文件 MD5 失败: %w", err)
		}
	}

	// 步骤4: 如果 MD5 值相同,说明文件内容未变化,跳过索引
	if storedMD5 != "" && storedMD5 == fileMD5 {
		logrus.Infof("文件未更改,跳过索引: %s (MD5: %s)", filePath, fileMD5)
		return nil
	}

	// 步骤5: 文件内容已变化(或首次索引),需要先删除旧文档
	// 如果之前已索引过,删除旧文档以避免残留过时数据
	if storedMD5 != "" {
		logrus.Infof("文件内容已变化,删除旧文档并重新索引: %s (旧MD5: %s, 新MD5: %s)", filePath, storedMD5, fileMD5)
		if err := s.deleteOldDocuments(ctx, filePath); err != nil {
			return fmt.Errorf("删除旧文档失败: %w", err)
		}
	} else {
		logrus.Infof("首次索引文件: %s", filePath)
	}

	// 步骤6: 加载文档
	docs, err := s.ragEngine.Loader.Load(ctx, document.Source{
		URI: filePath,
	})
	if err != nil {
		return fmt.Errorf("加载文档失败: %w", err)
	}

	if len(docs) == 0 {
		return fmt.Errorf("没有加载到文档内容")
	}

	// 步骤7: 分割文档为小块
	docs, err = s.ragEngine.Splitter.Transform(ctx, docs)
	if err != nil {
		return fmt.Errorf("分割文档失败: %w", err)
	}

	if len(docs) == 0 {
		return fmt.Errorf("文档分割后没有产生任何块")
	}

	// 步骤8: 为每个文档块生成唯一 ID 并添加丰富元数据
	// 元数据包括:文件名、文件类型、源路径、文件大小、索引时间等
	baseName := filepath.Base(filePath)
	ext := strings.ToLower(filepath.Ext(filePath))
	fileInfo, err := os.Stat(filePath)
	if err != nil {
		return fmt.Errorf("获取文件信息失败: %w", err)
	}
	normalizedPath := normalizePath(filePath)
	indexTime := time.Now().Format(time.RFC3339)

	for i, d := range docs {
		d.ID = s.getDocumentKey(filePath, i)
		
		// 确保 MetaData 存在
		if d.MetaData == nil {
			d.MetaData = make(map[string]any)
		}
		
		// 添加丰富元数据
		d.MetaData["filename"] = baseName
		d.MetaData["filetype"] = strings.TrimPrefix(ext, ".")
		d.MetaData["source_id"] = normalizedPath
		d.MetaData["source_path"] = normalizedPath
		d.MetaData["file_size"] = fileInfo.Size()
		d.MetaData["indexed_at"] = indexTime
		d.MetaData["modified_at"] = fileInfo.ModTime().Format(time.RFC3339)
		d.MetaData["md5"] = fileMD5
	}

	// 步骤9: 存储文档到 Redis
	// Indexer.Store 会自动将文档转换为向量并存储到 Redis
	_, err = s.ragEngine.Indexer.Store(ctx, docs)
	if err != nil {
		return fmt.Errorf("存储文档失败: %w", err)
	}

	// 步骤10: 存储新的 MD5 值到 Redis,用于下次对比
	_, err = s.ragEngine.Redis.Set(ctx, metadataKey, fileMD5, 0).Result()
	if err != nil {
		// 如果存储 MD5 失败,尝试回滚已存储的文档
		logrus.Errorf("存储文件 MD5 失败,尝试回滚: %v", err)
		s.deleteOldDocuments(ctx, filePath)
		return fmt.Errorf("存储文件 MD5 失败: %w", err)
	}

	logrus.Infof("文件索引成功: %s (MD5: %s, 文档块数: %d)", filePath, fileMD5, len(docs))
	return nil
}


// AutoIndex 定时自动索引
// 参数:
//   - ctx: 上下文
//   - folderPath: 文件夹路径
//   - interval: 间隔时间
func (s *IndexerService) AutoIndex(ctx context.Context, folderPath string, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return // 上下文取消,退出
		case <-ticker.C:
			logrus.Info("开始自动索引...")
			if err := s.IndexFolder(ctx, folderPath); err != nil {
				logrus.Errorf("自动索引失败: %v", err)
			}
		}
	}
}

retriever_service.go

Go 复制代码
package service

import (
	"context"
	"fmt"
	"sync"

	"eino_redis_rag/config"
	"eino_redis_rag/rag"

	"github.com/cloudwego/eino/schema"
	"github.com/sirupsen/logrus"
)

// RetrieverService 检索服务
// 负责处理文档检索和回答生成相关的操作
type RetrieverService struct {
	ragEngine *rag.RAGEngine // RAG 引擎实例
	config    *config.Config // 配置实例
	mu        sync.RWMutex   // 互斥锁,保护检索器配置的并发访问
}

// NewRetrieverService 创建检索服务
func NewRetrieverService(ragEngine *rag.RAGEngine, config *config.Config) *RetrieverService {
	return &RetrieverService{
		ragEngine: ragEngine,
		config:    config,
	}
}

// RetrieveOptions 检索选项
type RetrieveOptions struct {
	// TopK 返回的最相似文档数量,0 表示使用配置默认值
	TopK int
	// DistanceThreshold 距离阈值,nil 表示使用配置默认值
	DistanceThreshold *float64
	// Filter 过滤条件,支持按元数据字段过滤
	Filter map[string]string
	// UseHybrid 是否使用混合检索(向量+全文)
	UseHybrid bool
	// RRF 融合权重 [向量权重, 全文权重]
	RRFWeights [2]float64
}

// Retrieve 检索相关文档
func (s *RetrieverService) Retrieve(ctx context.Context, query string) ([]*schema.Document, error) {
	return s.RetrieveWithOptions(ctx, query, &RetrieveOptions{})
}

// RetrieveWithOptions 使用自定义选项检索相关文档
// 使用互斥锁保护配置的并发修改
func (s *RetrieverService) RetrieveWithOptions(ctx context.Context, query string, opts *RetrieveOptions) ([]*schema.Document, error) {
	if opts == nil {
		opts = &RetrieveOptions{}
	}

	// 如果启用混合检索
	if opts.UseHybrid {
		return s.hybridRetrieve(ctx, query, opts)
	}

	// 使用写锁保护配置修改
	s.mu.Lock()
	originalTopK := s.ragEngine.RetrieverConfig.TopK
	originalThreshold := s.ragEngine.RetrieverConfig.DistanceThreshold

	if opts.TopK > 0 {
		s.ragEngine.RetrieverConfig.TopK = opts.TopK
	}
	if opts.DistanceThreshold != nil {
		s.ragEngine.RetrieverConfig.DistanceThreshold = opts.DistanceThreshold
	}
	s.mu.Unlock()

	// 构建带过滤的查询
	var searchQuery string
	if len(opts.Filter) > 0 {
		searchQuery = s.buildFilterQuery(query, opts.Filter)
	} else {
		searchQuery = query
	}

	// 执行检索
	docs, err := s.ragEngine.Retriever.Retrieve(ctx, searchQuery)

	// 恢复原始配置
	s.mu.Lock()
	s.ragEngine.RetrieverConfig.TopK = originalTopK
	s.ragEngine.RetrieverConfig.DistanceThreshold = originalThreshold
	s.mu.Unlock()

	if err != nil {
		return nil, fmt.Errorf("检索失败: %w", err)
	}

	logrus.Debugf("检索到 %d 个文档", len(docs))
	return docs, nil
}

// hybridRetrieve 混合检索:结合向量检索和全文检索,使用 RRF 融合
func (s *RetrieverService) hybridRetrieve(ctx context.Context, query string, opts *RetrieveOptions) ([]*schema.Document, error) {
	vectorWeight := opts.RRFWeights[0]
	textWeight := opts.RRFWeights[1]
	if vectorWeight == 0 {
		vectorWeight = 0.6
	}
	if textWeight == 0 {
		textWeight = 0.4
	}

	topK := opts.TopK
	if topK == 0 {
		topK = s.config.Retriever.TopK
	}

	// 1. 向量检索
	s.mu.Lock()
	originalTopK := s.ragEngine.RetrieverConfig.TopK
	s.ragEngine.RetrieverConfig.TopK = topK * 2 // 取更多结果用于融合
	s.mu.Unlock()

	vectorDocs, err := s.ragEngine.Retriever.Retrieve(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// 恢复 TopK
	s.mu.Lock()
	s.ragEngine.RetrieverConfig.TopK = originalTopK
	s.mu.Unlock()

	// 2. 全文检索
	textDocs, err := s.fullTextSearch(ctx, query, topK*2)
	if err != nil {
		logrus.Warnf("全文检索失败,仅使用向量检索结果: %v", err)
		textDocs = []*schema.Document{}
	}

	// 3. RRF 融合
	mergedDocs := s.rrfFusion(vectorDocs, textDocs, vectorWeight, textWeight, topK)

	logrus.Debugf("混合检索: 向量=%d, 全文=%d, 融合后=%d", len(vectorDocs), len(textDocs), len(mergedDocs))
	return mergedDocs, nil
}

// fullTextSearch 全文检索
func (s *RetrieverService) fullTextSearch(ctx context.Context, query string, limit int) ([]*schema.Document, error) {
	// 构建全文搜索查询
	searchQuery := fmt.Sprintf("@content:(%s)", escapeQuery(query))

	result, err := s.ragEngine.Redis.Do(ctx, "FT.SEARCH", s.config.Index.Name, searchQuery, "LIMIT", 0, limit).Result()
	if err != nil {
		return nil, fmt.Errorf("全文搜索失败: %w", err)
	}

	// 解析搜索结果
	docs, err := parseSearchResult(result)
	if err != nil {
		return nil, fmt.Errorf("解析搜索结果失败: %w", err)
	}

	return docs, nil
}

// escapeQuery 转义搜索查询中的特殊字符
func escapeQuery(query string) string {
	// Redis Search 特殊字符转义
	result := query
	for _, ch := range []string{"(", ")", "[", "]", "{", "}", "^", "~", "*", "?", ":", "\"", "\\", "+", "-"} {
		result = replace(result, ch, "\\"+ch)
	}
	return result
}

// replace 简单的字符串替换
func replace(s, old, new string) string {
	result := ""
	for _, r := range s {
		if string(r) == old {
			result += new
		} else {
			result += string(r)
		}
	}
	return result
}

// parseSearchResult 解析 FT.SEARCH 结果
func parseSearchResult(result interface{}) ([]*schema.Document, error) {
	resultSlice, ok := result.([]interface{})
	if !ok || len(resultSlice) == 0 {
		return []*schema.Document{}, nil
	}

	// 第一个元素是匹配总数
	total := parseInt(resultSlice[0])
	if total == 0 {
		return []*schema.Document{}, nil
	}

	var docs []*schema.Document
	// 每两个元素一组:[key, [field1, value1, field2, value2, ...]]
	for i := 1; i < len(resultSlice)-1; i += 2 {
		key := fmt.Sprintf("%v", resultSlice[i])
		fields, ok := resultSlice[i+1].([]interface{})
		if !ok {
			continue
		}

		doc := &schema.Document{
			ID:       key,
			Content:  "",
			MetaData: make(map[string]any),
		}

		// 解析字段
		for j := 0; j < len(fields)-1; j += 2 {
			fieldName := fmt.Sprintf("%v", fields[j])
			fieldValue := fmt.Sprintf("%v", fields[j+1])
			if fieldName == "content" {
				doc.Content = fieldValue
			} else {
				doc.MetaData[fieldName] = fieldValue
			}
		}

		docs = append(docs, doc)
	}

	return docs, nil
}

// parseInt 解析整数
func parseInt(v interface{}) int {
	switch val := v.(type) {
	case int:
		return val
	case int64:
		return int(val)
	case float64:
		return int(val)
	default:
		return 0
	}
}

// buildFilterQuery 构建带过滤条件的查询
func (s *RetrieverService) buildFilterQuery(query string, filter map[string]string) string {
	var filterParts []string
	for k, v := range filter {
		filterParts = append(filterParts, fmt.Sprintf("@%s:{%s}", k, v))
	}
	if len(filterParts) == 0 {
		return query
	}
	return fmt.Sprintf("%s (%s)", join(filterParts, " "), query)
}

// join 字符串连接
func join(ss []string, sep string) string {
	result := ""
	for i, s := range ss {
		if i > 0 {
			result += sep
		}
		result += s
	}
	return result
}

// rrfFusion RRF (Reciprocal Rank Fusion) 融合算法
// 将向量检索和全文检索的结果按排名融合
func (s *RetrieverService) rrfFusion(vectorDocs, textDocs []*schema.Document, vectorWeight, textWeight float64, topK int) []*schema.Document {
	// 构建排名映射: docID -> [vectorRank, textRank]
	type rankInfo struct {
		vectorRank int
		textRank   int
		doc        *schema.Document
	}
	rankMap := make(map[string]*rankInfo)

	// 记录向量检索排名
	for i, doc := range vectorDocs {
		info, ok := rankMap[doc.ID]
		if !ok {
			info = &rankInfo{doc: doc}
			rankMap[doc.ID] = info
		}
		info.vectorRank = i + 1 // 1-based rank
	}

	// 记录全文检索排名
	for i, doc := range textDocs {
		info, ok := rankMap[doc.ID]
		if !ok {
			info = &rankInfo{doc: doc}
			rankMap[doc.ID] = info
		}
		info.textRank = i + 1
	}

	// 计算 RRF 分数
	type scoredDoc struct {
		doc   *schema.Document
		score float64
	}
	var scored []*scoredDoc

	k := 60 // RRF 常数,通常设为 60
	for _, info := range rankMap {
		var score float64
		if info.vectorRank > 0 {
			score += vectorWeight * 1.0 / float64(k+info.vectorRank)
		}
		if info.textRank > 0 {
			score += textWeight * 1.0 / float64(k+info.textRank)
		}
		scored = append(scored, &scoredDoc{doc: info.doc, score: score})
	}

	// 按分数降序排序
	for i := 0; i < len(scored); i++ {
		for j := i + 1; j < len(scored); j++ {
			if scored[i].score < scored[j].score {
				scored[i], scored[j] = scored[j], scored[i]
			}
		}
	}

	// 取 TopK
	result := make([]*schema.Document, 0, topK)
	for i := 0; i < len(scored) && i < topK; i++ {
		result = append(result, scored[i].doc)
	}

	return result
}

// Generate 根据查询生成回答
func (s *RetrieverService) Generate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
	return s.GenerateWithOptions(ctx, query, &GenerateOptions{})
}

// GenerateOptions 生成选项
type GenerateOptions struct {
	// TopK 返回的最相似文档数量
	TopK int
	// DistanceThreshold 距离阈值
	DistanceThreshold *float64
	// Filter 过滤条件
	Filter map[string]string
	// UseHybrid 是否使用混合检索
	UseHybrid bool
}

// GenerateWithOptions 使用自定义选项生成回答
func (s *RetrieverService) GenerateWithOptions(ctx context.Context, query string, opts *GenerateOptions) (*schema.StreamReader[*schema.Message], error) {
	if opts == nil {
		opts = &GenerateOptions{}
	}

	// 检索相关文档
	docs, err := s.RetrieveWithOptions(ctx, query, &RetrieveOptions{
		TopK:              opts.TopK,
		DistanceThreshold: opts.DistanceThreshold,
		Filter:            opts.Filter,
		UseHybrid:         opts.UseHybrid,
	})
	if err != nil {
		return nil, err
	}

	// 如果指定了距离阈值,过滤掉不相关的文档
	if opts.DistanceThreshold != nil {
		filtered := make([]*schema.Document, 0)
		for _, doc := range docs {
			if dist, ok := doc.MetaData["distance"]; ok {
				if distFloat, ok := dist.(float64); ok && distFloat <= *opts.DistanceThreshold {
					filtered = append(filtered, doc)
				}
			} else {
				filtered = append(filtered, doc)
			}
		}
		docs = filtered
	}

	logrus.Debugf("使用 %d 个文档生成回答", len(docs))
	return s.ragEngine.GenerateWithDocs(ctx, query, docs)
}

// RetrieveAndGenerate 先检索再生成回答
func (s *RetrieverService) RetrieveAndGenerate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
	return s.ragEngine.Generate(ctx, query)
}

watcher_service.go

Go 复制代码
package service

import (
	"context"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"eino_redis_rag/config"

	"github.com/fsnotify/fsnotify"
	"github.com/sirupsen/logrus"
)

// WatcherService 文件监听服务
// 使用 fsnotify 监听目录下的文件变化,自动触发索引或删除操作
type WatcherService struct {
	indexerService *IndexerService // 索引服务实例
	config         *config.Config  // 配置实例
	watcher        *fsnotify.Watcher // fsnotify 监听器
	mu             sync.Mutex      // 互斥锁,防止并发索引
	debounceTimer  *time.Timer     // 防抖定时器
	isWatching     bool            // 是否正在监听
}

// NewWatcherService 创建文件监听服务
// 参数:
//   - indexerService: 索引服务实例
//   - config: 配置实例
//
// 返回:
//   - *WatcherService: 文件监听服务实例
func NewWatcherService(indexerService *IndexerService, config *config.Config) *WatcherService {
	return &WatcherService{
		indexerService: indexerService,
		config:         config,
	}
}

// Start 启动文件监听服务
// 创建 fsnotify 监听器,注册目录监听,进入事件处理循环
// 参数:
//   - ctx: 上下文,用于控制监听服务的生命周期
//
// 返回:
//   - error: 错误信息
func (s *WatcherService) Start(ctx context.Context) error {
	// 检查配置
	if !s.config.Watcher.Enable {
		logrus.Info("文件监听服务未启用")
		return nil
	}

	watchDir := s.config.Watcher.WatchDirectory
	if watchDir == "" {
		logrus.Warn("文件监听目录未配置,跳过启动监听服务")
		return nil
	}

	// 获取防抖时间
	debounceSeconds := s.config.Watcher.DebounceSeconds
	if debounceSeconds <= 0 {
		debounceSeconds = 2 // 默认 2 秒
	}
	logrus.Infof("文件监听服务启动,监听目录: %s, 防抖时间: %d秒", watchDir, debounceSeconds)

	// 创建 fsnotify 监听器
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		logrus.Errorf("创建文件监听器失败: %v", err)
		return err
	}
	defer watcher.Close()

	s.watcher = watcher
	s.isWatching = true

	// 注册目录监听(启用递归监听)
	err = watcher.Add(watchDir)
	if err != nil {
		logrus.Errorf("注册目录监听失败: %v", err)
		return err
	}
	logrus.Infof("开始监听目录: %s", watchDir)

	// 事件处理循环
	for {
		select {
		case <-ctx.Done():
			logrus.Info("文件监听服务停止")
			s.isWatching = false
			return nil
		case event, ok := <-watcher.Events:
			if !ok {
				return nil
			}
			s.handleEvent(ctx, event, debounceSeconds)
		case err, ok := <-watcher.Errors:
			if !ok {
				return nil
			}
			logrus.Errorf("文件监听错误: %v", err)
		}
	}
}

// handleEvent 处理文件事件
// 根据事件类型执行相应的操作,并使用防抖机制避免重复触发
// 参数:
//   - ctx: 上下文
//   - event: fsnotify 事件
//   - debounceSeconds: 防抖时间(秒)
func (s *WatcherService) handleEvent(ctx context.Context, event fsnotify.Event, debounceSeconds int) {
	// 过滤临时文件(以 ~$ 开头的文件是 Office 临时文件)
	if strings.HasPrefix(filepath.Base(event.Name), "~$") {
		return
	}

	// 检查文件扩展名是否支持
	ext := strings.ToLower(filepath.Ext(event.Name))
	if !s.isSupportedExtension(ext) {
		return
	}

	logrus.Infof("检测到文件事件: %s, 文件: %s", event.Op.String(), event.Name)

	// 防抖处理:等待 debounceSeconds 秒后再执行操作
	s.mu.Lock()
	if s.debounceTimer != nil {
		s.debounceTimer.Stop()
	}

	s.debounceTimer = time.AfterFunc(time.Duration(debounceSeconds)*time.Second, func() {
		s.processFileEvent(ctx, event)
	})
	s.mu.Unlock()
}

// processFileEvent 处理文件事件的核心逻辑
// 根据事件类型执行索引或删除操作
// 参数:
//   - ctx: 上下文
//   - event: fsnotify 事件
func (s *WatcherService) processFileEvent(ctx context.Context, event fsnotify.Event) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// 判断事件类型
	if event.Op&fsnotify.Remove == fsnotify.Remove || event.Op&fsnotify.Rename == fsnotify.Rename {
		// 文件删除或重命名:删除 Redis 中的文档
		s.handleFileRemove(ctx, event.Name)
	} else if event.Op&fsnotify.Write == fsnotify.Write {
		// 文件修改:重新索引文件
		s.handleFileWrite(ctx, event.Name)
	} else if event.Op&fsnotify.Create == fsnotify.Create {
		// 文件创建:索引新文件
		s.handleFileCreate(ctx, event.Name)
	}
}

// handleFileRemove 处理文件删除/重命名事件
// 删除 Redis 中该文件对应的所有文档块和元数据
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
func (s *WatcherService) handleFileRemove(ctx context.Context, filePath string) {
	logrus.Infof("处理文件删除事件: %s", filePath)

	// 删除旧文档
	if err := s.indexerService.deleteOldDocuments(ctx, filePath); err != nil {
		logrus.Errorf("删除文档失败: %s, error: %v", filePath, err)
	}

	// 删除元数据
	metadataKey := s.indexerService.getMetadataKey(filePath)
	if _, err := s.indexerService.ragEngine.Redis.Del(ctx, metadataKey).Result(); err != nil {
		logrus.Errorf("删除元数据失败: %s, error: %v", metadataKey, err)
	}

	logrus.Infof("文件删除处理完成: %s", filePath)
}

// handleFileWrite 处理文件修改事件
// 重新索引文件(利用 MD5 去重机制,内容未变则跳过)
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
func (s *WatcherService) handleFileWrite(ctx context.Context, filePath string) {
	logrus.Infof("处理文件修改事件: %s", filePath)

	if err := s.indexerService.indexFile(ctx, filePath); err != nil {
		logrus.Errorf("索引文件失败: %s, error: %v", filePath, err)
	} else {
		logrus.Infof("文件修改处理完成: %s", filePath)
	}
}

// handleFileCreate 处理文件创建事件
// 索引新文件
// 参数:
//   - ctx: 上下文
//   - filePath: 文件路径
func (s *WatcherService) handleFileCreate(ctx context.Context, filePath string) {
	logrus.Infof("处理文件创建事件: %s", filePath)

	if err := s.indexerService.indexFile(ctx, filePath); err != nil {
		logrus.Errorf("索引新文件失败: %s, error: %v", filePath, err)
	} else {
		logrus.Infof("新文件索引完成: %s", filePath)
	}
}

// isSupportedExtension 检查文件扩展名是否在支持的列表中
// 参数:
//   - ext: 文件扩展名(如 .md, .pdf)
//
// 返回:
//   - bool: 是否支持
func (s *WatcherService) isSupportedExtension(ext string) bool {
	for _, supportedExt := range s.config.FileLoader.SupportedExtensions {
		if ext == supportedExt {
			return true
		}
	}
	return false
}

// IsWatching 返回当前是否正在监听
// 返回:
//   - bool: 是否正在监听
func (s *WatcherService) IsWatching() bool {
	return s.isWatching
}
相关推荐
iuvtsrt1 小时前
存储过程如何处理海量数据的批处理_循环提交与分段LIMIT结合
jvm·数据库·python
yexuhgu1 小时前
SQL如何检查字符串是否存在:INSTR与LOCATE函数使用
jvm·数据库·python
2301_783848651 小时前
SQL如何用SQL子查询实现关联报表生成_嵌套逻辑关联多表
jvm·数据库·python
techdashen2 小时前
dial9:给 Tokio 装上“飞行记录仪“
java·数据库·redis
2501_901006472 小时前
Golang怎么用gRPC Gateway_Golang gRPC Gateway教程【经典】
jvm·数据库·python
2501_901200532 小时前
golang如何实现错误预算Error Budget计算_golang错误预算Error Budget计算实现实战
jvm·数据库·python
2401_867623982 小时前
如何解决OUI图形界面无法调用_xhost与DISPLAY变量设置
jvm·数据库·python
czlczl200209252 小时前
Mysql读写分离的过期读问题
数据库·mysql
2401_824697663 小时前
CSS如何实现元素反转特效_使用transform-scaleX(-1)操作
jvm·数据库·python