eino_redis_rag
config\config.go
复制代码
# RAG 系统配置文件
# Redis 数据库配置
redis:
host: "100.102.39.213"
port: 6379
db: 0
password: "redis2026"
protocol: 2
unstable_resp3: true
# 向量嵌入模型配置
embedding:
api_key: "dsdsdsds"
base_url: "http://100.102.39.213:11434/v1"
model: "ollama.rnd.huawei.com/library/bge-m3:latest"
timeout_seconds: 30
dimension: 1024
# 聊天模型配置
chat_model:
api_key: "dsdsdsds"
base_url: "http://100.102.39.213:8000/v1"
model: "Qwen/Qwen3-Coder-Next"
timeout_seconds: 30
# Redis 索引配置
index:
name: "OuterIndex"
key_prefix: "OuterCyrex:"
vector_field: "vector_content"
# 文件加载配置
file_loader:
# 支持的文件扩展名
supported_extensions:
- ".md"
- ".txt"
- ".pdf"
- ".xlsx"
# 是否使用文件名作为文档 ID
use_name_as_id: true
# 检索配置
retriever:
# 返回最相似的文档数量
top_k: 5
# 距离阈值,nil 表示不限制
distance_threshold: null
# 使用的 Dialect 版本
dialect: 2
# 要返回的字段
return_fields:
- "vector_content"
- "content"
# 服务配置
server:
# 服务端口
port: 8080
# 是否启用定时任务
enable_cron: true
# 定时任务执行时间(cron 表达式)
cron_schedule: "0 0 2 * * ?" # 每天凌晨 2 点执行
# 文件监听配置
watcher:
# 是否启用文件监听
enable: true
# 监听的目录路径
watch_directory: "./test_txt"
# 防抖时间(秒),避免短时间内多次触发
debounce_seconds: 2
rag
rag.go
复制代码
package rag
import (
"context"
"fmt"
"time"
"eino_redis_rag/config"
"github.com/cloudwego/eino-ext/components/document/loader/file"
pdfParser "github.com/cloudwego/eino-ext/components/document/parser/pdf"
embedding2 "github.com/cloudwego/eino-ext/components/embedding/openai"
redisInd "github.com/cloudwego/eino-ext/components/indexer/redis"
redisRet "github.com/cloudwego/eino-ext/components/retriever/redis"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/redis/go-redis/v9"
)
// RAGEngine 是 RAG(Retrieval-Augmented Generation)引擎的核心结构体
// 负责管理所有 RAG 组件:嵌入模型、文件加载器、文档分割器、检索器、索引器和聊天模型
type RAGEngine struct {
indexName string // Redis 向量索引名称
prefix string // Redis 文档键前缀
dimension int // 向量维度
Redis *redis.Client // Redis 客户端(导出字段,供外部访问)
Embedder *embedding2.Embedder // 嵌入模型,用于将文本转换为向量
Loader *file.FileLoader // 文件加载器,用于加载文档
Splitter document.Transformer // 文档分割器,用于将文档分割成块
Retriever *redisRet.Retriever // 检索器,用于从 Redis 检索相关文档
RetrieverConfig *redisRet.RetrieverConfig // 检索器配置(导出字段,支持动态修改 TopK)
Indexer *redisInd.Indexer // 索引器,用于将文档存储到 Redis
ChatModel *openai.ChatModel // 聊天模型,用于生成回答
}
// NewRAGEngine 初始化 RAG 引擎
// 创建 Redis 客户端、嵌入模型和聊天模型,并返回 RAGEngine 实例
func NewRAGEngine(ctx context.Context, config *config.Config) (*RAGEngine, error) {
// 创建 Redis 客户端,用于连接 Redis 数据库
redisClient := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port),
Password: config.Redis.Password,
DB: config.Redis.DB,
Protocol: config.Redis.Protocol,
UnstableResp3: config.Redis.UnstableResp3,
})
// 创建嵌入模型,用于将文本转换为向量
timeout := time.Duration(config.Embedding.TimeoutSeconds) * time.Second
embedder, err := embedding2.NewEmbedder(ctx, &embedding2.EmbeddingConfig{
APIKey: config.Embedding.APIKey,
BaseURL: config.Embedding.BaseURL,
Model: config.Embedding.Model,
Timeout: timeout,
})
if err != nil {
return nil, err
}
// 创建聊天模型,用于生成回答
chatTimeout := time.Duration(config.ChatModel.TimeoutSeconds) * time.Second
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
APIKey: config.ChatModel.APIKey,
BaseURL: config.ChatModel.BaseURL,
Model: config.ChatModel.Model,
Timeout: chatTimeout,
})
if err != nil {
return nil, err
}
// 返回 RAGEngine 实例
return &RAGEngine{
indexName: config.Index.Name,
prefix: config.Index.KeyPrefix,
dimension: config.Embedding.Dimension,
Redis: redisClient,
Embedder: embedder,
ChatModel: chatModel,
}, nil
}
// InitLoader 初始化文件加载器
// 使用配置中的参数创建文件加载器,用于加载文档
// 使用 ExtParser 注册不同文件类型的解析器:PDF、XLSX、文本等
// PDF 使用 eino-ext 的 pdfParser(基于 pdfcpu,正确提取 Unicode 文本避免乱码)
// XLSX 使用自定义的 XlsxParserForWholeFile(将整个 Excel 解析为结构化文本)
func (r *RAGEngine) InitLoader(ctx context.Context, cfg *config.Config) error {
// 注册 PDF 解析器
// 使用 eino-ext 的 pdfParser,基于 pdfcpu 库正确提取 PDF 中的 Unicode 文本
// ToPages=false 表示将整个 PDF 解析为单个文档,保留完整上下文
pdfP, err := pdfParser.NewPDFParser(ctx, &pdfParser.Config{
ToPages: false, // 整个 PDF 解析为单个文档
})
if err != nil {
return fmt.Errorf("创建 PDF 解析器失败: %w", err)
}
// 注册 XLSX 解析器
// 使用自定义的 XlsxParserForWholeFile,将整个 Excel 文件解析为结构化文本
// 这种方式保留了表格的完整上下文,适合 RAG 场景
xlsxP := NewXlsxParserForWholeFile()
// 创建 ExtParser,支持按文件扩展名选择不同解析器
extParser, err := parser.NewExtParser(ctx, &parser.ExtParserConfig{
Parsers: map[string]parser.Parser{
".pdf": pdfP, // PDF 使用 eino-ext pdfParser(避免乱码)
".xlsx": xlsxP, // XLSX 使用自定义解析器(结构化文本输出)
".xls": xlsxP, // XLS 也使用相同解析器
".md": parser.TextParser{}, // Markdown 使用文本解析器
".txt": parser.TextParser{}, // 纯文本使用文本解析器
},
FallbackParser: parser.TextParser{}, // 未知类型回退到文本解析器
})
if err != nil {
return fmt.Errorf("创建 ExtParser 失败: %w", err)
}
l, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{
UseNameAsID: cfg.FileLoader.UseNameAsID, // 是否使用文件名作为文档 ID
Parser: extParser, // 使用自定义的 ExtParser
})
if err != nil {
return err
}
r.Loader = l
return nil
}
// InitSplitter 初始化文档分割器
// 创建 Markdown 分割器,用于将 Markdown 文档分割成块
func (r *RAGEngine) InitSplitter(ctx context.Context) error {
// 这里使用 Markdown 分割器,可以根据需要修改为其他类型
splitter, err := NewSplitter(ctx)
if err != nil {
return err
}
r.Splitter = splitter
return nil
}
// InitIndexer 初始化索引器
// 使用配置中的参数创建索引器,用于将文档存储到 Redis
func (r *RAGEngine) InitIndexer(ctx context.Context) error {
i, err := redisInd.NewIndexer(ctx, &redisInd.IndexerConfig{
Client: r.Redis,
KeyPrefix: r.prefix,
DocumentToHashes: nil, // 使用默认的文档到哈希转换
BatchSize: 10, // 批量处理大小
Embedding: r.Embedder, // 使用嵌入模型将文本转换为向量
})
if err != nil {
return err
}
r.Indexer = i
return nil
}
// InitRetriever 初始化检索器
// 使用配置中的参数创建检索器,用于从 Redis 检索相关文档
func (r *RAGEngine) InitRetriever(ctx context.Context, cfg *config.Config) error {
retrieverConfig := &redisRet.RetrieverConfig{
Client: r.Redis,
Index: r.indexName,
VectorField: cfg.Index.VectorField,
DistanceThreshold: cfg.Retriever.DistanceThreshold,
Dialect: cfg.Retriever.Dialect,
ReturnFields: cfg.Retriever.ReturnFields,
DocumentConverter: nil,
TopK: cfg.Retriever.TopK,
Embedding: r.Embedder,
}
re, err := redisRet.NewRetriever(ctx, retrieverConfig)
if err != nil {
return err
}
r.Retriever = re
// 保存配置引用,支持动态修改 TopK
r.RetrieverConfig = retrieverConfig
return nil
}
// InitVectorIndex 初始化向量索引(FT.CREATE)
// 在 Redis 中创建向量索引,用于高效的向量搜索
func (r *RAGEngine) InitVectorIndex(ctx context.Context) error {
// 检查索引是否已存在
if _, err := r.Redis.Do(ctx, "FT.INFO", r.indexName).Result(); err == nil {
return nil // 索引已存在,直接返回
}
// 创建索引
// FT.CREATE 命令用于创建向量索引
// ON HASH: 指定索引类型为 HASH
// PREFIX: 指定键前缀
// SCHEMA: 定义索引字段
// - content: TEXT 类型,用于全文搜索
// - vector_content: VECTOR 类型,用于向量搜索
// - FLAT: 简单的暴力搜索算法
// - TYPE FLOAT32: 向量数据类型
// - DIM: 向量维度
// - DISTANCE_METRIC COSINE: 使用余弦相似度计算距离
createIndexArgs := []interface{}{
"FT.CREATE", r.indexName,
"ON", "HASH",
"PREFIX", "1", r.prefix,
"SCHEMA",
"content", "TEXT",
"vector_content", "VECTOR", "FLAT",
"6",
"TYPE", "FLOAT32",
"DIM", r.dimension,
"DISTANCE_METRIC", "COSINE",
}
if err := r.Redis.Do(ctx, createIndexArgs...).Err(); err != nil {
return err
}
// 验证索引创建成功
if _, err := r.Redis.Do(ctx, "FT.INFO", r.indexName).Result(); err != nil {
return err
}
return nil
}
// Generate 根据查询生成回答
// 1. 使用检索器检索相关文档
// 2. 构建提示词,将检索到的文档和查询组合
// 3. 调用聊天模型生成回答
func (r *RAGEngine) Generate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
// 检索相关文档
docs, err := r.Retriever.Retrieve(ctx, query)
if err != nil {
return nil, err
}
// 使用 GenerateWithDocs 生成回答
return r.GenerateWithDocs(ctx, query, docs)
}
// GenerateWithDocs 使用指定的文档生成回答
// 这个方法允许调用者控制使用哪些文档,支持更精细的检索控制
// 参数:
// - ctx: 上下文
// - query: 查询字符串
// - docs: 用于生成回答的文档列表
//
// 返回:
// - *schema.StreamReader[*schema.Message]: 流式回答生成器
// - error: 错误信息
func (r *RAGEngine) GenerateWithDocs(ctx context.Context, query string, docs []*schema.Document) (*schema.StreamReader[*schema.Message], error) {
// 将文档列表转换为字符串内容
documentsStr := documentsToString(docs)
// 构建提示词
// systemPrompt: 系统提示词,定义角色和行为
// userMessage: 用户查询
tpl := prompt.FromMessages(schema.FString, []schema.MessagesTemplate{
schema.SystemMessage(systemPrompt),
schema.UserMessage("question: {content}"),
}...)
// 格式化提示词,将检索到的文档内容和查询插入到提示词中
messages, err := tpl.Format(ctx, map[string]any{
"documents": documentsStr,
"content": query,
})
if err != nil {
return nil, err
}
// 生成回答
return r.ChatModel.Stream(ctx, messages)
}
// documentsToString 将文档列表转换为可读的字符串格式
func documentsToString(docs []*schema.Document) string {
if len(docs) == 0 {
return "(无相关文档)"
}
var result string
for i, doc := range docs {
result += fmt.Sprintf("--- 文档 %d ---\n%s\n", i+1, doc.Content)
}
return result
}
// systemPrompt 是发送给大语言模型的系统提示词
// 定义了模型的角色和行为规范
// 关键改进:明确要求只使用提供的文档内容回答问题,避免幻觉
var systemPrompt = `
# Role: Student Learning Assistant
# Language: Chinese
# Critical Rules:
1. You MUST ONLY use the information provided in the documents below to answer questions.
2. If the documents do not contain information relevant to the question, you MUST state "根据现有资料,我无法找到相关信息来回答这个问题。"
3. Do NOT make up information or use your own knowledge to answer.
4. Do NOT mention people, facts, or data that are not in the provided documents.
- When providing assistance:
• Be clear and concise
• Only reference information from the provided documents
• If multiple documents are provided, clearly indicate which source each piece of information comes from
here's documents searched for you:
==== doc start ====
{documents}
==== doc end ====
`
splitter.go
Go
复制代码
package rag
import (
"context"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
"github.com/cloudwego/eino/components/document"
)
// NewSplitter 初始化 Markdown 分割器
// 用于将 Markdown 文档按照标题层级分割成块
// 参数:
// - ctx: 上下文,用于控制请求的生命周期
//
// 返回:
// - document.Transformer: 文档转换器接口
// - error: 错误信息
func NewSplitter(ctx context.Context) (document.Transformer, error) {
// 创建 Markdown 标题分割器
// Headers: 定义标题层级映射
// - "#": "title" - 将 # 开头的行作为标题字段
// TrimHeaders: 是否移除标题行
t, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
Headers: map[string]string{
"#": "title",
},
TrimHeaders: false, // 保留标题行
})
if err != nil {
return nil, err
}
return t, nil
}
xlsx_parser.go
Go
复制代码
package rag
import (
"bytes"
"context"
"fmt"
"io"
"strings"
"github.com/cloudwego/eino/components/document/parser"
"github.com/cloudwego/eino/schema"
"github.com/xuri/excelize/v2"
)
// readAllFromReader 从 io.Reader 中读取所有内容到 bytes.Reader
// 因为 excelize.OpenReader 需要 io.ReaderAt 接口,而 eino 的 parser.Parser 只传递 io.Reader
// bytes.Reader 实现了 io.ReaderAt 接口,所以可以用它来包装读取的数据
func readAllFromReader(reader io.Reader) (*bytes.Reader, error) {
buf, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("读取数据失败: %w", err)
}
return bytes.NewReader(buf), nil
}
const (
xlsxMetaDataRow = "_row"
xlsxMetaDataExt = "_ext"
xlsxMetaDataSheet = "_sheet"
)
// XlsxParserConfig 用于配置 XlsxParser
type XlsxParserConfig struct {
// SheetName 设置要解析的工作表名称,默认为第一个工作表
SheetName string
// NoHeader 设置为 true 时表示第一行不是表头
NoHeader bool
// IDPrefix 自定义文档 ID 前缀
IDPrefix string
}
// XlsxParser 自定义解析器,用于解析 XLSX 文件内容
// 支持有表头和无表头的 XLSX 文件
// 支持从多个工作表中选择特定表格
// 支持自定义文档 ID 前缀
type XlsxParser struct {
config *XlsxParserConfig
}
// NewXlsxParser 创建一个新的 XlsxParser
func NewXlsxParser(ctx context.Context, config *XlsxParserConfig) (parser.Parser, error) {
if config == nil {
config = &XlsxParserConfig{}
}
return &XlsxParser{config: config}, nil
}
// generateID 根据配置生成文档 ID
func (p *XlsxParser) generateID(sheetName string, rowIdx int) string {
if p.config.IDPrefix != "" {
return fmt.Sprintf("%s%s:%d", p.config.IDPrefix, sheetName, rowIdx)
}
return fmt.Sprintf("%s:%d", sheetName, rowIdx)
}
// buildRowContent 将行数据构建为可读的文本内容
func (p *XlsxParser) buildRowContent(row []string, headers []string) string {
if !p.config.NoHeader && len(headers) > 0 {
// 有表头时,使用 "表头: 值" 的格式
var parts []string
for j, header := range headers {
if j < len(row) {
value := strings.TrimSpace(row[j])
if value != "" {
parts = append(parts, fmt.Sprintf("%s: %s", header, value))
}
}
}
return strings.Join(parts, "\n")
}
// 无表头时,直接用制表符连接
contentParts := make([]string, len(row))
for j, cell := range row {
contentParts[j] = strings.TrimSpace(cell)
}
return strings.Join(contentParts, "\t")
}
// Parse 解析 XLSX 内容
// 将所有工作表的数据解析为文档列表
func (p *XlsxParser) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
// 先读取所有内容到内存,bytes.Reader 实现了 io.ReaderAt 接口
dataReader, err := readAllFromReader(reader)
if err != nil {
return nil, fmt.Errorf("读取 Excel 数据失败: %w", err)
}
// 使用 OpenReader 打开 Excel 文件(bytes.Reader 实现了 io.ReaderAt 接口)
xlFile, err := excelize.OpenReader(dataReader)
if err != nil {
return nil, fmt.Errorf("打开 Excel 文件失败: %w", err)
}
defer xlFile.Close()
// 获取所有工作表
sheets := xlFile.GetSheetList()
if len(sheets) == 0 {
return nil, nil
}
var ret []*schema.Document
// 确定要解析的工作表
var targetSheets []string
if p.config.SheetName != "" {
targetSheets = []string{p.config.SheetName}
} else {
targetSheets = sheets
}
for _, sheetName := range targetSheets {
// 获取该工作表的所有行(表头 + 数据行)
rows, err := xlFile.GetRows(sheetName)
if err != nil {
// 如果某个工作表读取失败,记录错误但继续处理其他工作表
continue
}
if len(rows) == 0 {
continue
}
// 处理表头
startIdx := 0
var headers []string
if !p.config.NoHeader && len(rows) > 0 {
headers = rows[0]
startIdx = 1
}
// 处理数据行
for i := startIdx; i < len(rows); i++ {
row := rows[i]
if len(row) == 0 {
continue
}
// 检查行是否全为空
allEmpty := true
for _, cell := range row {
if strings.TrimSpace(cell) != "" {
allEmpty = false
break
}
}
if allEmpty {
continue
}
// 构建内容
content := p.buildRowContent(row, headers)
if content == "" {
continue
}
// 构建元数据
// 注意:Redis 无法序列化嵌套的 map[string]any,所以元数据只能包含简单类型(string, int, float64 等)
meta := make(map[string]any)
meta[xlsxMetaDataSheet] = sheetName
meta[xlsxMetaDataRow] = i - startIdx + 1
// 如果有表头,将表头信息存储为字符串(而不是嵌套map)
if !p.config.NoHeader && len(headers) > 0 {
meta["_headers"] = strings.Join(headers, ", ")
// 将每个表头-值对存储为单独的元数据字段
for j, header := range headers {
if j < len(row) {
// 使用 _col_ 前缀避免冲突
meta["_col_"+header] = strings.TrimSpace(row[j])
}
}
}
// 创建文档
doc := &schema.Document{
ID: p.generateID(sheetName, i),
Content: content,
MetaData: meta,
}
ret = append(ret, doc)
}
}
return ret, nil
}
// ParseAllSheetsToString 将整个 XLSX 文件的所有工作表解析为单个字符串内容
// 这种方式适合将整个 Excel 文件作为单个文档处理
func ParseAllSheetsToString(ctx context.Context, reader io.Reader) (string, error) {
// 先读取所有内容到内存,bytes.Reader 实现了 io.ReaderAt 接口
dataReader, err := readAllFromReader(reader)
if err != nil {
return "", fmt.Errorf("读取 Excel 数据失败: %w", err)
}
// 使用 OpenReader 打开 Excel 文件(bytes.Reader 实现了 io.ReaderAt 接口)
xlFile, err := excelize.OpenReader(dataReader)
if err != nil {
return "", fmt.Errorf("打开 Excel 文件失败: %w", err)
}
defer xlFile.Close()
sheets := xlFile.GetSheetList()
if len(sheets) == 0 {
return "", nil
}
var result strings.Builder
for _, sheetName := range sheets {
rows, err := xlFile.GetRows(sheetName)
if err != nil {
continue
}
if len(rows) == 0 {
continue
}
// 写入工作表标题
result.WriteString(fmt.Sprintf("【工作表: %s】\n", sheetName))
for i, row := range rows {
if len(row) == 0 {
continue
}
// 检查行是否全为空
allEmpty := true
for _, cell := range row {
if strings.TrimSpace(cell) != "" {
allEmpty = false
break
}
}
if allEmpty {
continue
}
if i == 0 {
// 第一行作为表头,用制表符连接
contentParts := make([]string, len(row))
for j, cell := range row {
contentParts[j] = strings.TrimSpace(cell)
}
result.WriteString(strings.Join(contentParts, "\t"))
result.WriteString("\n")
// 添加分隔线
result.WriteString(strings.Repeat("-", len(row)*10) + "\n")
} else {
// 数据行,用制表符连接
contentParts := make([]string, len(row))
for j, cell := range row {
contentParts[j] = strings.TrimSpace(cell)
}
result.WriteString(strings.Join(contentParts, "\t"))
result.WriteString("\n")
}
}
result.WriteString("\n")
}
return result.String(), nil
}
// XlsxParserForWholeFile 将整个 XLSX 文件解析为单个文档的解析器
// 这种方式更适合 RAG 场景,因为整个表格的上下文被保留
type XlsxParserForWholeFile struct{}
// NewXlsxParserForWholeFile 创建一个新的整文件解析器
func NewXlsxParserForWholeFile() parser.Parser {
return &XlsxParserForWholeFile{}
}
// Parse 实现 parser.Parser 接口
func (p *XlsxParserForWholeFile) Parse(ctx context.Context, reader io.Reader, opts ...parser.Option) ([]*schema.Document, error) {
content, err := ParseAllSheetsToString(ctx, reader)
if err != nil {
return nil, err
}
if content == "" {
return nil, nil
}
// 注意:Redis 无法序列化嵌套的 map[string]any,所以元数据只能包含简单类型
meta := make(map[string]any)
doc := &schema.Document{
ID: "1",
Content: content,
MetaData: meta,
}
return []*schema.Document{doc}, nil
}
service
indexer_service.go
Go
复制代码
package service
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"eino_redis_rag/config"
"eino_redis_rag/rag"
"github.com/cloudwego/eino/components/document"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
)
// IndexerService 索引服务
// 负责处理文档索引相关的操作
type IndexerService struct {
ragEngine *rag.RAGEngine // RAG 引擎实例
config *config.Config // 配置实例
}
// NewIndexerService 创建索引服务
// 参数:
// - ragEngine: RAG 引擎实例
// - config: 配置实例
//
// 返回:
// - *IndexerService: 索引服务实例
func NewIndexerService(ragEngine *rag.RAGEngine, config *config.Config) *IndexerService {
return &IndexerService{
ragEngine: ragEngine,
config: config,
}
}
// IndexFolder 遍历指定文件夹下的所有文件,将可支持读取的文件读取并写入向量数据库中
// 参数:
// - ctx: 上下文
// - folderPath: 文件夹路径
//
// 返回:
// - error: 错误信息
func (s *IndexerService) IndexFolder(ctx context.Context, folderPath string) error {
// 检查文件夹是否存在
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
return fmt.Errorf("文件夹不存在: %s", folderPath)
}
// 获取文件夹下所有支持的文件
files, err := s.getFilesInFolder(folderPath)
if err != nil {
return err
}
// 如果没有文件,直接返回
if len(files) == 0 {
logrus.Info("文件夹中没有文件")
return nil
}
logrus.Infof("找到 %d 个文件,开始索引", len(files))
// 遍历文件并索引
for _, file := range files {
if err := s.indexFile(ctx, file); err != nil {
logrus.Errorf("索引文件失败: %s, error: %v", file, err)
continue // 继续处理其他文件
}
logrus.Infof("索引文件成功: %s", file)
}
return nil
}
// getFilesInFolder 获取文件夹下的所有文件
// 参数:
// - folderPath: 文件夹路径
//
// 返回:
// - []string: 文件路径列表
// - error: 错误信息
func (s *IndexerService) getFilesInFolder(folderPath string) ([]string, error) {
var files []string
// 递归遍历文件夹
err := filepath.Walk(folderPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 只处理文件,跳过目录
if info.IsDir() {
return nil
}
// 检查文件扩展名是否在支持的列表中
ext := strings.ToLower(filepath.Ext(path))
for _, supportedExt := range s.config.FileLoader.SupportedExtensions {
if ext == supportedExt {
files = append(files, path)
break
}
}
return nil
})
return files, err
}
// calculateFileMD5 计算文件的 MD5 值
// 参数:
// - filePath: 文件路径
//
// 返回:
// - string: MD5 值的十六进制表示
// - error: 错误信息
func (s *IndexerService) calculateFileMD5(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := md5.Sum(data)
return hex.EncodeToString(hash[:]), nil
}
// normalizePath 规范化文件路径,将反斜杠转换为正斜杠
// 确保在不同系统上路径一致
// 参数:
// - path: 文件路径
//
// 返回:
// - string: 规范化后的路径
func normalizePath(path string) string {
return strings.ReplaceAll(path, "\\", "/")
}
// getDocumentKey 获取文档在 Redis 中的键名
// 参数:
// - filePath: 文件路径
// - chunkIndex: 文档块索引
//
// 返回:
// - string: Redis 键名
func (s *IndexerService) getDocumentKey(filePath string, chunkIndex int) string {
// 使用文件路径和块索引生成唯一键名
// 注意:Indexer 会自动添加 KeyPrefix,所以这里只需要提供文件路径和块索引
// 格式:文件路径:块索引
// 例如:test_txt/mysql-1.md:0
// 使用规范化路径,确保在不同系统上键名一致
normalizedPath := normalizePath(filePath)
return fmt.Sprintf("%s:%d", normalizedPath, chunkIndex)
}
// getMetadataKey 获取文件元数据在 Redis 中的键名
// 参数:
// - filePath: 文件路径
//
// 返回:
// - string: Redis 键名
func (s *IndexerService) getMetadataKey(filePath string) string {
// 使用文件路径生成元数据键名
// 格式:keyPrefix_metadata:文件路径
// 例如:OuterCyrex:metadata:test_txt/mysql-1.md
// 使用规范化路径,确保在不同系统上键名一致
normalizedPath := normalizePath(filePath)
return fmt.Sprintf("%s_metadata:%s", s.config.Index.KeyPrefix, normalizedPath)
}
// isFileAlreadyIndexed 检查文件是否已经索引过
// 通过检查 Redis 中是否存在该文件的元数据键(存储 MD5 值)来判断
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
//
// 返回:
// - bool: 是否已索引
func (s *IndexerService) isFileAlreadyIndexed(ctx context.Context, filePath string) (bool, error) {
// 通过检查元数据键是否存在来判断文件是否已索引
metadataKey := s.getMetadataKey(filePath)
_, err := s.ragEngine.Redis.Exists(ctx, metadataKey).Result()
if err != nil {
return false, err
}
// 如果键存在,说明文件已索引
// Exists 返回 1 表示键存在,返回 0 表示键不存在
result, err := s.ragEngine.Redis.Exists(ctx, metadataKey).Result()
if err != nil {
return false, err
}
return result > 0, nil
}
// DeleteDocumentByKey 直接通过 Redis key 删除文档
// 参数:
// - ctx: 上下文
// - key: Redis key
//
// 返回:
// - error: 错误信息
func (s *IndexerService) DeleteDocumentByKey(ctx context.Context, key string) error {
// 同时尝试正斜杠和反斜杠两种格式
normalizedKey := normalizePath(key)
deletedKeys := []string{}
// 检查原始 key 是否存在
exists, _ := s.ragEngine.Redis.Exists(ctx, key).Result()
if exists > 0 {
_, err := s.ragEngine.Redis.Del(ctx, key).Result()
if err != nil {
logrus.Warnf("删除键失败: %s, error: %v", key, err)
} else {
deletedKeys = append(deletedKeys, key)
}
}
// 如果规范化后的 key 不同,也尝试删除
if normalizedKey != key {
exists, _ = s.ragEngine.Redis.Exists(ctx, normalizedKey).Result()
if exists > 0 {
_, err := s.ragEngine.Redis.Del(ctx, normalizedKey).Result()
if err != nil {
logrus.Warnf("删除键失败: %s, error: %v", normalizedKey, err)
} else {
deletedKeys = append(deletedKeys, normalizedKey)
}
}
}
if len(deletedKeys) > 0 {
logrus.Infof("成功删除键: %v", deletedKeys)
} else {
logrus.Warnf("未找到键: %s 和 %s", key, normalizedKey)
}
return nil
}
// deleteOldDocuments 删除文件中所有已索引的文档块
// 在文件内容变更需要重新索引时,先删除旧文档,避免残留过时数据
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
//
// 返回:
// - error: 错误信息
func (s *IndexerService) deleteOldDocuments(ctx context.Context, filePath string) error {
// 同时尝试正斜杠和反斜杠两种路径格式,确保兼容历史数据
normalizedPath := normalizePath(filePath)
originalPath := filePath
keyPatterns := []string{
s.config.Index.KeyPrefix + normalizedPath + ":*",
}
// 如果原始路径与规范化路径不同,也添加原始路径的模式
if originalPath != normalizedPath {
keyPatterns = append(keyPatterns, s.config.Index.KeyPrefix + originalPath + ":*")
}
logrus.Infof("删除旧文档: %s, patterns: %v", filePath, keyPatterns)
deletedCount := 0
// 对每种路径模式进行 SCAN 扫描
scan_loop:
for _, keyPattern := range keyPatterns {
cursor := uint64(0)
for {
result, err := s.ragEngine.Redis.Do(ctx, "SCAN", cursor, "MATCH", keyPattern, "COUNT", 100).Result()
if err != nil {
logrus.Warnf("SCAN 命令执行失败: %v", err)
break
}
// SCAN 返回一个包含两个元素的 slice:[cursor, [keys...]]
resultSlice, ok := result.([]interface{})
if !ok || len(resultSlice) < 2 {
logrus.Debugf("SCAN 结果无法解析: ok=%v, len=%d", ok, func() int { if r, ok := result.([]interface{}); ok { return len(r) }; return 0 }())
break
}
// 更新 cursor - 需要处理多种类型:uint64, float64, string, int64
var newCursor uint64
switch v := resultSlice[0].(type) {
case uint64:
newCursor = v
case float64:
newCursor = uint64(v)
case string:
// Redis RESP3 协议可能返回字符串类型的 cursor
fmt.Sscanf(v, "%d", &newCursor)
case int64:
newCursor = uint64(v)
case int:
newCursor = uint64(v)
default:
logrus.Warnf("无法解析 cursor: type=%T, value=%v", resultSlice[0], resultSlice[0])
break scan_loop
}
cursor = newCursor
// 获取键列表 - 需要处理多种类型
var keys []interface{}
switch k := resultSlice[1].(type) {
case []interface{}:
keys = k
case []string:
for _, s := range k {
keys = append(keys, s)
}
case string:
// 单个键的情况
keys = []interface{}{k}
case nil:
keys = []interface{}{}
default:
logrus.Warnf("无法解析 keys: type=%T, value=%v", resultSlice[1], resultSlice[1])
break
}
logrus.Infof("SCAN pattern=%s 找到 %d 个键: %v", keyPattern, len(keys), func() interface{} { return resultSlice[1] }())
// 删除所有匹配的键
if len(keys) > 0 {
keysStr := make([]string, len(keys))
for i, k := range keys {
keysStr[i] = fmt.Sprintf("%v", k)
}
_, err := s.ragEngine.Redis.Del(ctx, keysStr...).Result()
if err != nil {
logrus.Warnf("删除键失败: %v, error: %v", keysStr, err)
} else {
deletedCount += len(keysStr)
logrus.Infof("成功删除 %d 个键: %v", len(keysStr), keysStr)
}
}
// 如果 cursor 为 0,说明扫描完成
if cursor == 0 {
break
}
}
}
logrus.Infof("删除旧文档完成: %s, 删除了 %d 个文档块", filePath, deletedCount)
return nil
}
// indexFile 索引单个文件
// 1. 计算文件 MD5 值
// 2. 从 Redis 中获取已存储的 MD5 值
// 3. 如果 MD5 值相同,则跳过索引
// 4. 如果 MD5 值不同或不存在,则删除旧文档后重新索引
// 5. 存储新的 MD5 值到 Redis
//
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
//
// 返回:
// - error: 错误信息
func (s *IndexerService) indexFile(ctx context.Context, filePath string) error {
// 步骤1: 计算当前文件的 MD5 值
fileMD5, err := s.calculateFileMD5(filePath)
if err != nil {
return fmt.Errorf("计算文件 MD5 失败: %w", err)
}
// 步骤2: 获取元数据键(用于存储/读取 MD5 值)
metadataKey := s.getMetadataKey(filePath)
// 步骤3: 从 Redis 中获取已存储的 MD5 值
storedMD5, err := s.ragEngine.Redis.Get(ctx, metadataKey).Result()
if err != nil {
if err == redis.Nil {
// 键不存在,说明文件未索引过,需要执行完整索引
logrus.Infof("文件未索引过,开始索引: %s", filePath)
} else {
return fmt.Errorf("获取文件 MD5 失败: %w", err)
}
}
// 步骤4: 如果 MD5 值相同,说明文件内容未变化,跳过索引
if storedMD5 != "" && storedMD5 == fileMD5 {
logrus.Infof("文件未更改,跳过索引: %s (MD5: %s)", filePath, fileMD5)
return nil
}
// 步骤5: 文件内容已变化(或首次索引),需要先删除旧文档
// 如果之前已索引过,删除旧文档以避免残留过时数据
if storedMD5 != "" {
logrus.Infof("文件内容已变化,删除旧文档并重新索引: %s (旧MD5: %s, 新MD5: %s)", filePath, storedMD5, fileMD5)
if err := s.deleteOldDocuments(ctx, filePath); err != nil {
return fmt.Errorf("删除旧文档失败: %w", err)
}
} else {
logrus.Infof("首次索引文件: %s", filePath)
}
// 步骤6: 加载文档
docs, err := s.ragEngine.Loader.Load(ctx, document.Source{
URI: filePath,
})
if err != nil {
return fmt.Errorf("加载文档失败: %w", err)
}
if len(docs) == 0 {
return fmt.Errorf("没有加载到文档内容")
}
// 步骤7: 分割文档为小块
docs, err = s.ragEngine.Splitter.Transform(ctx, docs)
if err != nil {
return fmt.Errorf("分割文档失败: %w", err)
}
if len(docs) == 0 {
return fmt.Errorf("文档分割后没有产生任何块")
}
// 步骤8: 为每个文档块生成基于文件路径的唯一 ID
// 使用规范化路径和块索引生成唯一 ID,确保相同文件的相同块有相同的 ID
for i, d := range docs {
d.ID = s.getDocumentKey(filePath, i)
}
// 步骤9: 存储文档到 Redis
// Indexer.Store 会自动将文档转换为向量并存储到 Redis
_, err = s.ragEngine.Indexer.Store(ctx, docs)
if err != nil {
return fmt.Errorf("存储文档失败: %w", err)
}
// 步骤10: 存储新的 MD5 值到 Redis,用于下次对比
_, err = s.ragEngine.Redis.Set(ctx, metadataKey, fileMD5, 0).Result()
if err != nil {
// 如果存储 MD5 失败,尝试回滚已存储的文档
logrus.Errorf("存储文件 MD5 失败,尝试回滚: %v", err)
s.deleteOldDocuments(ctx, filePath)
return fmt.Errorf("存储文件 MD5 失败: %w", err)
}
logrus.Infof("文件索引成功: %s (MD5: %s, 文档块数: %d)", filePath, fileMD5, len(docs))
return nil
}
// AutoIndex 定时自动索引
// 参数:
// - ctx: 上下文
// - folderPath: 文件夹路径
// - interval: 间隔时间
func (s *IndexerService) AutoIndex(ctx context.Context, folderPath string, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return // 上下文取消,退出
case <-ticker.C:
logrus.Info("开始自动索引...")
if err := s.IndexFolder(ctx, folderPath); err != nil {
logrus.Errorf("自动索引失败: %v", err)
}
}
}
}
retriever_service.go
Go
复制代码
package service
import (
"context"
"eino_redis_rag/config"
"eino_redis_rag/rag"
"github.com/cloudwego/eino/schema"
)
// RetrieverService 检索服务
// 负责处理文档检索和回答生成相关的操作
type RetrieverService struct {
ragEngine *rag.RAGEngine // RAG 引擎实例
config *config.Config // 配置实例
}
// NewRetrieverService 创建检索服务
// 参数:
// - ragEngine: RAG 引擎实例
// - config: 配置实例
//
// 返回:
// - *RetrieverService: 检索服务实例
func NewRetrieverService(ragEngine *rag.RAGEngine, config *config.Config) *RetrieverService {
return &RetrieverService{
ragEngine: ragEngine,
config: config,
}
}
// RetrieveOptions 检索选项
type RetrieveOptions struct {
// TopK 返回的最相似文档数量,0 表示使用配置默认值
TopK int
// DistanceThreshold 距离阈值,nil 表示使用配置默认值
DistanceThreshold *float64
}
// Retrieve 检索相关文档
// 使用查询字符串从 Redis 中检索相关文档
// 参数:
// - ctx: 上下文
// - query: 查询字符串
//
// 返回:
// - []*schema.Document: 检索到的文档列表
// - error: 错误信息
func (s *RetrieverService) Retrieve(ctx context.Context, query string) ([]*schema.Document, error) {
return s.RetrieveWithOptions(ctx, query, &RetrieveOptions{})
}
// RetrieveWithOptions 使用自定义选项检索相关文档
// 参数:
// - ctx: 上下文
// - query: 查询字符串
// - opts: 检索选项
//
// 返回:
// - []*schema.Document: 检索到的文档列表
// - error: 错误信息
func (s *RetrieverService) RetrieveWithOptions(ctx context.Context, query string, opts *RetrieveOptions) ([]*schema.Document, error) {
// 如果需要动态设置 TopK 或 DistanceThreshold,临时修改检索器配置
if opts == nil {
opts = &RetrieveOptions{}
}
// 保存原始值
originalTopK := 0
originalThreshold := s.ragEngine.RetrieverConfig.DistanceThreshold
// 如果指定了 TopK,临时修改
if opts.TopK > 0 {
originalTopK = s.ragEngine.RetrieverConfig.TopK
s.ragEngine.RetrieverConfig.TopK = opts.TopK
}
// 如果指定了 DistanceThreshold,临时修改
if opts.DistanceThreshold != nil {
originalThreshold = s.ragEngine.RetrieverConfig.DistanceThreshold
s.ragEngine.RetrieverConfig.DistanceThreshold = opts.DistanceThreshold
}
// 调用 RAG 引擎的检索器进行检索
docs, err := s.ragEngine.Retriever.Retrieve(ctx, query)
// 恢复原始值
if opts.TopK > 0 {
s.ragEngine.RetrieverConfig.TopK = originalTopK
}
if opts.DistanceThreshold != nil {
s.ragEngine.RetrieverConfig.DistanceThreshold = originalThreshold
}
if err != nil {
return nil, err
}
return docs, nil
}
// Generate 根据查询生成回答
// 先检索相关文档,然后使用聊天模型生成回答
// 参数:
// - ctx: 上下文
// - query: 查询字符串
//
// 返回:
// - *schema.StreamReader[*schema.Message]: 流式回答生成器
// - error: 错误信息
func (s *RetrieverService) Generate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
return s.GenerateWithOptions(ctx, query, &GenerateOptions{})
}
// GenerateOptions 生成选项
type GenerateOptions struct {
// TopK 返回的最相似文档数量,0 表示使用配置默认值
TopK int
// DistanceThreshold 距离阈值,nil 表示使用配置默认值
DistanceThreshold *float64
}
// GenerateWithOptions 使用自定义选项生成回答
// 参数:
// - ctx: 上下文
// - query: 查询字符串
// - opts: 生成选项
//
// 返回:
// - *schema.StreamReader[*schema.Message]: 流式回答生成器
// - error: 错误信息
func (s *RetrieverService) GenerateWithOptions(ctx context.Context, query string, opts *GenerateOptions) (*schema.StreamReader[*schema.Message], error) {
if opts == nil {
opts = &GenerateOptions{}
}
// 检索相关文档,使用指定的 TopK
docs, err := s.RetrieveWithOptions(ctx, query, &RetrieveOptions{
TopK: opts.TopK,
DistanceThreshold: opts.DistanceThreshold,
})
if err != nil {
return nil, err
}
// 如果指定了距离阈值,过滤掉不相关的文档
if opts.DistanceThreshold != nil {
filtered := make([]*schema.Document, 0)
for _, doc := range docs {
// 检查 metadata 中的距离值
if dist, ok := doc.MetaData["distance"]; ok {
if distFloat, ok := dist.(float64); ok && distFloat <= *opts.DistanceThreshold {
filtered = append(filtered, doc)
}
} else {
// 如果没有距离信息,保留文档
filtered = append(filtered, doc)
}
}
docs = filtered
}
// 调用 RAG 引擎的生成方法,使用过滤后的文档
return s.ragEngine.GenerateWithDocs(ctx, query, docs)
}
// RetrieveAndGenerate 先检索再生成回答
// 与 Generate 方法相同,保留此方法是为了 API 的一致性
// 参数:
// - ctx: 上下文
// - query: 查询字符串
//
// 返回:
// - *schema.StreamReader[*schema.Message]: 流式回答生成器
// - error: 错误信息
func (s *RetrieverService) RetrieveAndGenerate(ctx context.Context, query string) (*schema.StreamReader[*schema.Message], error) {
return s.ragEngine.Generate(ctx, query)
}
watcher_service.go
Go
复制代码
package service
import (
"context"
"path/filepath"
"strings"
"sync"
"time"
"eino_redis_rag/config"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
)
// WatcherService 文件监听服务
// 使用 fsnotify 监听目录下的文件变化,自动触发索引或删除操作
type WatcherService struct {
indexerService *IndexerService // 索引服务实例
config *config.Config // 配置实例
watcher *fsnotify.Watcher // fsnotify 监听器
mu sync.Mutex // 互斥锁,防止并发索引
debounceTimer *time.Timer // 防抖定时器
isWatching bool // 是否正在监听
}
// NewWatcherService 创建文件监听服务
// 参数:
// - indexerService: 索引服务实例
// - config: 配置实例
//
// 返回:
// - *WatcherService: 文件监听服务实例
func NewWatcherService(indexerService *IndexerService, config *config.Config) *WatcherService {
return &WatcherService{
indexerService: indexerService,
config: config,
}
}
// Start 启动文件监听服务
// 创建 fsnotify 监听器,注册目录监听,进入事件处理循环
// 参数:
// - ctx: 上下文,用于控制监听服务的生命周期
//
// 返回:
// - error: 错误信息
func (s *WatcherService) Start(ctx context.Context) error {
// 检查配置
if !s.config.Watcher.Enable {
logrus.Info("文件监听服务未启用")
return nil
}
watchDir := s.config.Watcher.WatchDirectory
if watchDir == "" {
logrus.Warn("文件监听目录未配置,跳过启动监听服务")
return nil
}
// 获取防抖时间
debounceSeconds := s.config.Watcher.DebounceSeconds
if debounceSeconds <= 0 {
debounceSeconds = 2 // 默认 2 秒
}
logrus.Infof("文件监听服务启动,监听目录: %s, 防抖时间: %d秒", watchDir, debounceSeconds)
// 创建 fsnotify 监听器
watcher, err := fsnotify.NewWatcher()
if err != nil {
logrus.Errorf("创建文件监听器失败: %v", err)
return err
}
defer watcher.Close()
s.watcher = watcher
s.isWatching = true
// 注册目录监听(启用递归监听)
err = watcher.Add(watchDir)
if err != nil {
logrus.Errorf("注册目录监听失败: %v", err)
return err
}
logrus.Infof("开始监听目录: %s", watchDir)
// 事件处理循环
for {
select {
case <-ctx.Done():
logrus.Info("文件监听服务停止")
s.isWatching = false
return nil
case event, ok := <-watcher.Events:
if !ok {
return nil
}
s.handleEvent(ctx, event, debounceSeconds)
case err, ok := <-watcher.Errors:
if !ok {
return nil
}
logrus.Errorf("文件监听错误: %v", err)
}
}
}
// handleEvent 处理文件事件
// 根据事件类型执行相应的操作,并使用防抖机制避免重复触发
// 参数:
// - ctx: 上下文
// - event: fsnotify 事件
// - debounceSeconds: 防抖时间(秒)
func (s *WatcherService) handleEvent(ctx context.Context, event fsnotify.Event, debounceSeconds int) {
// 过滤临时文件(以 ~$ 开头的文件是 Office 临时文件)
if strings.HasPrefix(filepath.Base(event.Name), "~$") {
return
}
// 检查文件扩展名是否支持
ext := strings.ToLower(filepath.Ext(event.Name))
if !s.isSupportedExtension(ext) {
return
}
logrus.Infof("检测到文件事件: %s, 文件: %s", event.Op.String(), event.Name)
// 防抖处理:等待 debounceSeconds 秒后再执行操作
s.mu.Lock()
if s.debounceTimer != nil {
s.debounceTimer.Stop()
}
s.debounceTimer = time.AfterFunc(time.Duration(debounceSeconds)*time.Second, func() {
s.processFileEvent(ctx, event)
})
s.mu.Unlock()
}
// processFileEvent 处理文件事件的核心逻辑
// 根据事件类型执行索引或删除操作
// 参数:
// - ctx: 上下文
// - event: fsnotify 事件
func (s *WatcherService) processFileEvent(ctx context.Context, event fsnotify.Event) {
s.mu.Lock()
defer s.mu.Unlock()
// 判断事件类型
if event.Op&fsnotify.Remove == fsnotify.Remove || event.Op&fsnotify.Rename == fsnotify.Rename {
// 文件删除或重命名:删除 Redis 中的文档
s.handleFileRemove(ctx, event.Name)
} else if event.Op&fsnotify.Write == fsnotify.Write {
// 文件修改:重新索引文件
s.handleFileWrite(ctx, event.Name)
} else if event.Op&fsnotify.Create == fsnotify.Create {
// 文件创建:索引新文件
s.handleFileCreate(ctx, event.Name)
}
}
// handleFileRemove 处理文件删除/重命名事件
// 删除 Redis 中该文件对应的所有文档块和元数据
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
func (s *WatcherService) handleFileRemove(ctx context.Context, filePath string) {
logrus.Infof("处理文件删除事件: %s", filePath)
// 删除旧文档
if err := s.indexerService.deleteOldDocuments(ctx, filePath); err != nil {
logrus.Errorf("删除文档失败: %s, error: %v", filePath, err)
}
// 删除元数据
metadataKey := s.indexerService.getMetadataKey(filePath)
if _, err := s.indexerService.ragEngine.Redis.Del(ctx, metadataKey).Result(); err != nil {
logrus.Errorf("删除元数据失败: %s, error: %v", metadataKey, err)
}
logrus.Infof("文件删除处理完成: %s", filePath)
}
// handleFileWrite 处理文件修改事件
// 重新索引文件(利用 MD5 去重机制,内容未变则跳过)
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
func (s *WatcherService) handleFileWrite(ctx context.Context, filePath string) {
logrus.Infof("处理文件修改事件: %s", filePath)
if err := s.indexerService.indexFile(ctx, filePath); err != nil {
logrus.Errorf("索引文件失败: %s, error: %v", filePath, err)
} else {
logrus.Infof("文件修改处理完成: %s", filePath)
}
}
// handleFileCreate 处理文件创建事件
// 索引新文件
// 参数:
// - ctx: 上下文
// - filePath: 文件路径
func (s *WatcherService) handleFileCreate(ctx context.Context, filePath string) {
logrus.Infof("处理文件创建事件: %s", filePath)
if err := s.indexerService.indexFile(ctx, filePath); err != nil {
logrus.Errorf("索引新文件失败: %s, error: %v", filePath, err)
} else {
logrus.Infof("新文件索引完成: %s", filePath)
}
}
// isSupportedExtension 检查文件扩展名是否在支持的列表中
// 参数:
// - ext: 文件扩展名(如 .md, .pdf)
//
// 返回:
// - bool: 是否支持
func (s *WatcherService) isSupportedExtension(ext string) bool {
for _, supportedExt := range s.config.FileLoader.SupportedExtensions {
if ext == supportedExt {
return true
}
}
return false
}
// IsWatching 返回当前是否正在监听
// 返回:
// - bool: 是否正在监听
func (s *WatcherService) IsWatching() bool {
return s.isWatching
}
config.yaml
Go
复制代码
# RAG 系统配置文件
# Redis 数据库配置
redis:
host: "100.102.39.213"
port: 6379
db: 0
password: "redis2026"
protocol: 2
unstable_resp3: true
# 向量嵌入模型配置
embedding:
api_key: "dsdsdsds"
base_url: "http://100.102.39.213:11434/v1"
model: "ollama.rnd.huawei.com/library/bge-m3:latest"
timeout_seconds: 30
dimension: 1024
# 聊天模型配置
chat_model:
api_key: "dsdsdsds"
base_url: "http://100.102.39.213:8000/v1"
model: "Qwen/Qwen3-Coder-Next"
timeout_seconds: 30
# Redis 索引配置
index:
name: "OuterIndex"
key_prefix: "OuterCyrex:"
vector_field: "vector_content"
# 文件加载配置
file_loader:
# 支持的文件扩展名
supported_extensions:
- ".md"
- ".txt"
- ".pdf"
- ".xlsx"
# 是否使用文件名作为文档 ID
use_name_as_id: true
# 检索配置
retriever:
# 返回最相似的文档数量
top_k: 5
# 距离阈值,nil 表示不限制
distance_threshold: null
# 使用的 Dialect 版本
dialect: 2
# 要返回的字段
return_fields:
- "vector_content"
- "content"
# 服务配置
server:
# 服务端口
port: 8080
# 是否启用定时任务
enable_cron: true
# 定时任务执行时间(cron 表达式)
cron_schedule: "0 0 2 * * ?" # 每天凌晨 2 点执行
# 文件监听配置
watcher:
# 是否启用文件监听
enable: true
# 监听的目录路径
watch_directory: "./test_txt"
# 防抖时间(秒),避免短时间内多次触发
debounce_seconds: 2
main.go
Go
复制代码
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"eino_redis_rag/config"
"eino_redis_rag/rag"
"eino_redis_rag/service"
"github.com/cloudwego/eino/schema"
"github.com/sirupsen/logrus"
)
func main() {
// 加载配置
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
// 初始化 RAG 引擎
ctx := context.Background()
ragEngine, err := rag.NewRAGEngine(ctx, cfg)
if err != nil {
log.Fatalf("初始化 RAG 引擎失败: %v", err)
}
// 初始化组件
if err := ragEngine.InitLoader(ctx, cfg); err != nil {
log.Fatalf("初始化文件加载器失败: %v", err)
}
if err := ragEngine.InitSplitter(ctx); err != nil {
log.Fatalf("初始化文档分割器失败: %v", err)
}
if err := ragEngine.InitIndexer(ctx); err != nil {
log.Fatalf("初始化索引器失败: %v", err)
}
if err := ragEngine.InitRetriever(ctx, cfg); err != nil {
log.Fatalf("初始化检索器失败: %v", err)
}
if err := ragEngine.InitVectorIndex(ctx); err != nil {
log.Fatalf("初始化向量索引失败: %v", err)
}
// 创建服务
indexerService := service.NewIndexerService(ragEngine, cfg)
retrieverService := service.NewRetrieverService(ragEngine, cfg)
watcherService := service.NewWatcherService(indexerService, cfg)
// 创建 HTTP 服务器
mux := http.NewServeMux()
// 注册路由
mux.HandleFunc("/api/v1/index", func(w http.ResponseWriter, r *http.Request) {
indexHandler(w, r, indexerService)
})
mux.HandleFunc("/api/v1/retrieve", func(w http.ResponseWriter, r *http.Request) {
retrieveHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/generate", func(w http.ResponseWriter, r *http.Request) {
generateHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/retrieve-and-generate", func(w http.ResponseWriter, r *http.Request) {
retrieveAndGenerateHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/document/delete", func(w http.ResponseWriter, r *http.Request) {
deleteDocumentHandler(w, r, indexerService)
})
// 创建退出信号通道
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// 启动定时任务
if cfg.Server.EnableCron {
go func() {
interval := time.Hour * 24 // 每天执行一次
indexerService.AutoIndex(ctx, "./test_txt", interval)
}()
}
// 启动文件监听服务
var watcherCancel context.CancelFunc
if cfg.Watcher.Enable {
watcherCtx, cancel := context.WithCancel(ctx)
watcherCancel = cancel
go func() {
if err := watcherService.Start(watcherCtx); err != nil {
logrus.Errorf("文件监听服务启动失败: %v", err)
}
}()
logrus.Info("文件监听服务已启动")
}
// 启动服务器
port := fmt.Sprintf(":%d", cfg.Server.Port)
logrus.Infof("服务器启动在端口 %d", cfg.Server.Port)
srv := &http.Server{
Addr: port,
Handler: mux,
}
// 启动服务器 goroutine
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("启动服务器失败: %v", err)
}
}()
// 等待退出信号
<-quit
logrus.Info("正在关闭服务器...")
// 停止文件监听服务
if watcherCancel != nil {
watcherCancel()
}
// 关闭服务器
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
logrus.Errorf("服务器关闭失败: %v", err)
}
logrus.Info("服务器已关闭")
}
// IndexFolderRequest 索引文件夹请求
type IndexFolderRequest struct {
FolderPath string `json:"folder_path"`
}
// IndexFolderResponse 索引文件夹响应
type IndexFolderResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}
func indexHandler(w http.ResponseWriter, r *http.Request, indexerService *service.IndexerService) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req IndexFolderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.FolderPath == "" {
http.Error(w, "folder_path is required", http.StatusBadRequest)
return
}
ctx := context.Background()
if err := indexerService.IndexFolder(ctx, req.FolderPath); err != nil {
logrus.Errorf("索引文件夹失败: %v", err)
http.Error(w, "Indexing failed: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(IndexFolderResponse{
Code: 200,
Message: "Indexing successful",
})
}
// RetrieveRequest 检索请求
type RetrieveRequest struct {
Query string `json:"query"`
}
// RetrieveResponse 检索响应
type RetrieveResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Docs []*schema.Document `json:"docs"`
}
func retrieveHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req RetrieveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Query == "" {
http.Error(w, "query is required", http.StatusBadRequest)
return
}
ctx := context.Background()
docs, err := retrieverService.Retrieve(ctx, req.Query)
if err != nil {
logrus.Errorf("检索失败: %v", err)
http.Error(w, "Retrieval failed: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
encoder.Encode(RetrieveResponse{
Code: 200,
Message: "Retrieval successful",
Docs: docs,
})
}
// GenerateRequest 生成回答请求
type GenerateRequest struct {
// Query 查询字符串(必填)
Query string `json:"query"`
// TopK 返回的最相似文档数量,可选,默认为配置文件中的值(5)
TopK int `json:"top_k"`
}
// GenerateResponse 生成回答响应
type GenerateResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Content string `json:"content"`
}
func generateHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req GenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Query == "" {
http.Error(w, "query is required", http.StatusBadRequest)
return
}
ctx := context.Background()
// 使用带选项的生成方法,支持动态 TopK
stream, err := retrieverService.GenerateWithOptions(ctx, req.Query, &service.GenerateOptions{
TopK: req.TopK,
})
if err != nil {
logrus.Errorf("生成回答失败: %v", err)
http.Error(w, "Generation failed: "+err.Error(), http.StatusInternalServerError)
return
}
// 读取流式输出
var content string
for {
msg, err := stream.Recv()
if err != nil {
break
}
content += msg.Content
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(GenerateResponse{
Code: 200,
Message: "Generation successful",
Content: content,
})
}
// DeleteDocumentRequest 删除文档请求
type DeleteDocumentRequest struct {
Key string `json:"key"`
}
// DeleteDocumentResponse 删除文档响应
type DeleteDocumentResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}
func deleteDocumentHandler(w http.ResponseWriter, r *http.Request, indexerService *service.IndexerService) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req DeleteDocumentRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Key == "" {
http.Error(w, "key is required", http.StatusBadRequest)
return
}
ctx := context.Background()
if err := indexerService.DeleteDocumentByKey(ctx, req.Key); err != nil {
logrus.Errorf("删除文档失败: %v", err)
http.Error(w, "Delete failed: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(DeleteDocumentResponse{
Code: 200,
Message: "Document deleted successfully",
})
}
// RetrieveAndGenerateRequest 检索并生成回答请求
type RetrieveAndGenerateRequest struct {
Query string `json:"query"`
}
// RetrieveAndGenerateResponse 检索并生成回答响应
type RetrieveAndGenerateResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Content string `json:"content"`
}
func retrieveAndGenerateHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req RetrieveAndGenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Query == "" {
http.Error(w, "query is required", http.StatusBadRequest)
return
}
ctx := context.Background()
stream, err := retrieverService.RetrieveAndGenerate(ctx, req.Query)
if err != nil {
logrus.Errorf("检索并生成回答失败: %v", err)
http.Error(w, "Retrieval and generation failed: "+err.Error(), http.StatusInternalServerError)
return
}
// 读取流式输出
var content string
for {
msg, err := stream.Recv()
if err != nil {
break
}
content += msg.Content
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(RetrieveAndGenerateResponse{
Code: 200,
Message: "Retrieval and generation successful",
Content: content,
})
}
v2的main.go
Go
复制代码
package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"eino_redis_rag/config"
"eino_redis_rag/rag"
"eino_redis_rag/service"
"github.com/sirupsen/logrus"
)
// globalConfig 全局配置,用于健康检查等
var globalConfig *config.Config
func main() {
// 初始化统一日志
initLogger()
logrus.Info("========================================")
logrus.Info("Eino Redis RAG 服务启动中...")
logrus.Info("========================================")
// 加载配置
cfg, err := config.LoadConfig()
if err != nil {
logrus.Fatalf("加载配置失败: %v", err)
}
globalConfig = cfg
logrus.Infof("配置加载成功: 端口=%d, Redis=%s:%d, 索引=%s",
cfg.Server.Port, cfg.Redis.Host, cfg.Redis.Port, cfg.Index.Name)
// 初始化 RAG 引擎
ctx := context.Background()
ragEngine, err := rag.NewRAGEngine(ctx, cfg)
if err != nil {
logrus.Fatalf("初始化 RAG 引擎失败: %v", err)
}
// 初始化组件
logrus.Info("初始化文件加载器...")
if err := ragEngine.InitLoader(ctx, cfg); err != nil {
logrus.Fatalf("初始化文件加载器失败: %v", err)
}
logrus.Info("初始化文档分割器...")
if err := ragEngine.InitSplitter(ctx, cfg); err != nil {
logrus.Fatalf("初始化文档分割器失败: %v", err)
}
logrus.Info("初始化索引器...")
if err := ragEngine.InitIndexer(ctx); err != nil {
logrus.Fatalf("初始化索引器失败: %v", err)
}
logrus.Info("初始化检索器...")
if err := ragEngine.InitRetriever(ctx, cfg); err != nil {
logrus.Fatalf("初始化检索器失败: %v", err)
}
logrus.Info("初始化向量索引...")
if err := ragEngine.InitVectorIndex(ctx, cfg); err != nil {
logrus.Fatalf("初始化向量索引失败: %v", err)
}
// 创建服务
indexerService := service.NewIndexerService(ragEngine, cfg)
retrieverService := service.NewRetrieverService(ragEngine, cfg)
watcherService := service.NewWatcherService(indexerService, cfg)
// 创建 HTTP 路由
mux := http.NewServeMux()
// 健康检查
mux.HandleFunc("/health", healthHandler)
// API v1 路由
mux.HandleFunc("/api/v1/index", func(w http.ResponseWriter, r *http.Request) {
indexHandler(w, r, indexerService)
})
mux.HandleFunc("/api/v1/retrieve", func(w http.ResponseWriter, r *http.Request) {
retrieveHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/generate", func(w http.ResponseWriter, r *http.Request) {
generateHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/generate/stream", func(w http.ResponseWriter, r *http.Request) {
generateStreamHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/retrieve-and-generate", func(w http.ResponseWriter, r *http.Request) {
retrieveAndGenerateHandler(w, r, retrieverService)
})
mux.HandleFunc("/api/v1/document/delete", func(w http.ResponseWriter, r *http.Request) {
deleteDocumentHandler(w, r, indexerService)
})
// 创建退出信号通道
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// 启动定时任务
if cfg.Server.EnableCron {
logrus.Info("定时索引任务已启用 (每天执行一次)")
go func() {
interval := time.Hour * 24
indexerService.AutoIndex(ctx, "./test_txt", interval)
}()
}
// 启动文件监听服务
var watcherCancel context.CancelFunc
if cfg.Watcher.Enable {
watcherCtx, cancel := context.WithCancel(ctx)
watcherCancel = cancel
go func() {
if err := watcherService.Start(watcherCtx); err != nil {
logrus.Errorf("文件监听服务启动失败: %v", err)
}
}()
logrus.Info("文件监听服务已启动")
}
// 启动服务器
port := fmt.Sprintf(":%d", cfg.Server.Port)
logrus.Info("========================================")
logrus.Infof("服务器启动在端口 %d", cfg.Server.Port)
logrus.Info("API 文档:")
logrus.Info(" POST /api/v1/index - 索引文件夹")
logrus.Info(" POST /api/v1/retrieve - 检索文档")
logrus.Info(" POST /api/v1/generate - 生成回答 (JSON)")
logrus.Info(" POST /api/v1/generate/stream - 生成回答 (SSE流式)")
logrus.Info(" POST /api/v1/retrieve-and-generate - 检索并生成")
logrus.Info(" POST /api/v1/document/delete - 删除文档")
logrus.Info(" GET /health - 健康检查")
logrus.Info("========================================")
srv := &http.Server{
Addr: port,
Handler: mux,
}
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logrus.Fatalf("启动服务器失败: %v", err)
}
}()
// 等待退出信号
<-quit
logrus.Info("收到退出信号,正在关闭服务器...")
// 停止文件监听服务
if watcherCancel != nil {
watcherCancel()
logrus.Info("文件监听服务已停止")
}
// 优雅关闭服务器
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
logrus.Errorf("服务器关闭失败: %v", err)
}
logrus.Info("服务器已关闭")
}
// initLogger 初始化统一日志配置
func initLogger() {
logrus.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: time.RFC3339,
})
logrus.SetOutput(os.Stdout)
logrus.SetLevel(logrus.InfoLevel)
}
// ==================== 健康检查 ====================
// HealthResponse 健康检查响应
type HealthResponse struct {
Status string `json:"status"`
Timestamp string `json:"timestamp"`
Service string `json:"service"`
Version string `json:"version"`
}
func healthHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(HealthResponse{
Status: "ok",
Timestamp: time.Now().Format(time.RFC3339),
Service: "eino-redis-rag",
Version: "1.1.0",
})
}
// ==================== 索引处理 ====================
// IndexFolderRequest 索引文件夹请求
type IndexFolderRequest struct {
FolderPath string `json:"folder_path"`
}
// APIResponse 通用 API 响应
type APIResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
func indexHandler(w http.ResponseWriter, r *http.Request, indexerService *service.IndexerService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req IndexFolderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.FolderPath == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "folder_path is required"})
return
}
logrus.Infof("开始索引文件夹: %s", req.FolderPath)
ctx := context.Background()
if err := indexerService.IndexFolder(ctx, req.FolderPath); err != nil {
logrus.Errorf("索引文件夹失败: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Indexing failed: " + err.Error()})
return
}
logrus.Infof("文件夹索引成功: %s", req.FolderPath)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(APIResponse{Code: 200, Message: "Indexing successful"})
}
// ==================== 检索处理 ====================
// RetrieveRequest 检索请求
type RetrieveRequest struct {
Query string `json:"query"`
TopK int `json:"top_k,omitempty"`
Filter map[string]string `json:"filter,omitempty"`
Hybrid bool `json:"hybrid,omitempty"`
}
func retrieveHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req RetrieveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.Query == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "query is required"})
return
}
logrus.Infof("检索请求: query=%s, topK=%d, hybrid=%v", req.Query, req.TopK, req.Hybrid)
ctx := context.Background()
docs, err := retrieverService.RetrieveWithOptions(ctx, req.Query, &service.RetrieveOptions{
TopK: req.TopK,
Filter: req.Filter,
UseHybrid: req.Hybrid,
})
if err != nil {
logrus.Errorf("检索失败: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Retrieval failed: " + err.Error()})
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
encoder.Encode(APIResponse{
Code: 200,
Message: "Retrieval successful",
Data: map[string]interface{}{
"count": len(docs),
"docs": docs,
},
})
}
// ==================== 生成回答处理 ====================
// GenerateRequest 生成回答请求
type GenerateRequest struct {
Query string `json:"query"`
TopK int `json:"top_k,omitempty"`
Hybrid bool `json:"hybrid,omitempty"`
}
// GenerateResponse 生成回答响应
type GenerateResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Content string `json:"content"`
}
func generateHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req GenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.Query == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "query is required"})
return
}
logrus.Infof("生成回答请求: query=%s", req.Query)
ctx := context.Background()
stream, err := retrieverService.GenerateWithOptions(ctx, req.Query, &service.GenerateOptions{
TopK: req.TopK,
UseHybrid: req.Hybrid,
})
if err != nil {
logrus.Errorf("生成回答失败: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Generation failed: " + err.Error()})
return
}
// 读取流式输出
var content string
for {
msg, err := stream.Recv()
if err != nil {
break
}
content += msg.Content
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(GenerateResponse{
Code: 200,
Message: "Generation successful",
Content: content,
})
}
// ==================== SSE 流式输出 ====================
// SSEEvent SSE 事件
type SSEEvent struct {
Event string `json:"event,omitempty"`
ID string `json:"id,omitempty"`
Data string `json:"data"`
Timestamp int64 `json:"timestamp,omitempty"`
}
func generateStreamHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req GenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.Query == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "query is required"})
return
}
logrus.Infof("SSE 流式生成请求: query=%s", req.Query)
// 设置 SSE 响应头
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no") // 禁用 Nginx 缓冲
flusher, ok := w.(http.Flusher)
if !ok {
logrus.Error("流式刷新不支持")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Streaming not supported"})
return
}
ctx := context.Background()
stream, err := retrieverService.GenerateWithOptions(ctx, req.Query, &service.GenerateOptions{
TopK: req.TopK,
UseHybrid: req.Hybrid,
})
if err != nil {
logrus.Errorf("生成回答失败: %v", err)
writeSSEEvent(w, flusher, "error", fmt.Sprintf("Generation failed: %v", err))
return
}
// 发送开始事件
writeSSEEvent(w, flusher, "start", "")
eventID := 0
for {
msg, err := stream.Recv()
if err != nil {
if err == io.EOF {
break
}
writeSSEEvent(w, flusher, "error", err.Error())
return
}
eventID++
writeSSEEvent(w, flusher, "chunk", msg.Content)
}
// 发送结束事件
writeSSEEvent(w, flusher, "done", "")
logrus.Info("SSE 流式输出完成")
}
// writeSSEEvent 写入 SSE 事件
func writeSSEEvent(w http.ResponseWriter, flusher http.Flusher, eventType, data string) {
sseEvent := SSEEvent{
Event: eventType,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
jsonData, _ := json.Marshal(sseEvent)
fmt.Fprintf(w, "data: %s\n\n", jsonData)
flusher.Flush()
}
// ==================== 删除文档处理 ====================
// DeleteDocumentRequest 删除文档请求
type DeleteDocumentRequest struct {
Key string `json:"key"`
}
func deleteDocumentHandler(w http.ResponseWriter, r *http.Request, indexerService *service.IndexerService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req DeleteDocumentRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.Key == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "key is required"})
return
}
logrus.Infof("删除文档请求: key=%s", req.Key)
ctx := context.Background()
if err := indexerService.DeleteDocumentByKey(ctx, req.Key); err != nil {
logrus.Errorf("删除文档失败: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Delete failed: " + err.Error()})
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(APIResponse{Code: 200, Message: "Document deleted successfully"})
}
// ==================== 检索并生成处理 ====================
// RetrieveAndGenerateRequest 检索并生成回答请求
type RetrieveAndGenerateRequest struct {
Query string `json:"query"`
TopK int `json:"top_k,omitempty"`
Hybrid bool `json:"hybrid,omitempty"`
}
func retrieveAndGenerateHandler(w http.ResponseWriter, r *http.Request, retrieverService *service.RetrieverService) {
if r.Method != http.MethodPost {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusMethodNotAllowed)
json.NewEncoder(w).Encode(APIResponse{Code: 405, Message: "Method not allowed"})
return
}
var req RetrieveAndGenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "Invalid request body"})
return
}
if req.Query == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(APIResponse{Code: 400, Message: "query is required"})
return
}
logrus.Infof("检索并生成请求: query=%s", req.Query)
ctx := context.Background()
stream, err := retrieverService.GenerateWithOptions(ctx, req.Query, &service.GenerateOptions{
TopK: req.TopK,
UseHybrid: req.Hybrid,
})
if err != nil {
logrus.Errorf("检索并生成回答失败: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(APIResponse{Code: 500, Message: "Retrieval and generation failed: " + err.Error()})
return
}
// 读取流式输出
var content string
for {
msg, err := stream.Recv()
if err != nil {
break
}
content += msg.Content
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(GenerateResponse{
Code: 200,
Message: "Retrieval and generation successful",
Content: content,
})
}
// 确保 bufio 被引用
var _ = bufio.NewWriter