FastGPT 引申:基于 Python 版本实现 Java 版本 RRF

文章目录

    • [FastGPT 引申:基于 Python 版本实现 Java 版本 RRF](#FastGPT 引申:基于 Python 版本实现 Java 版本 RRF)

FastGPT 引申:基于 Python 版本实现 Java 版本 RRF

函数定义

使用 Java 实现 RRF 相关的两个函数:合并结果、过滤结果

java 复制代码
import java.util.*;

// 搜索结果类型定义
public class SearchDataResponseItem {
    private String id;
    private String q;
    private String a;
    private List<Score> score;
    private double rrfScore;  // 临时存储RRF分数
    
    // 其他字段...
    
    // getter和setter方法
}

// 分数类型定义
public class Score {
    private String type;
    private double value;
    private int index;
    
    // getter和setter方法
}

// 搜索结果合并工具类
public class DatasetSearchUtils {
    
    /**
     * RRF搜索结果合并
     * @param searchResults 搜索结果列表,包含k值和结果列表
     * @return 合并后的结果
     */
    public static List<SearchDataResponseItem> datasetSearchResultConcat(
            List<SearchResultGroup> searchResults) {
            
        // 过滤空结果
        searchResults = searchResults.stream()
                .filter(item -> !item.getList().isEmpty())
                .collect(Collectors.toList());
                
        // 处理边界情况
        if (searchResults.isEmpty()) {
            return new ArrayList<>();
        }
        if (searchResults.size() == 1) {
            return searchResults.get(0).getList();
        }
        
        // 用Map存储合并结果
        Map<String, SearchDataResponseItem> resultMap = new HashMap<>();
        
        // RRF算法实现
        for (SearchResultGroup group : searchResults) {
            int k = group.getK();
            List<SearchDataResponseItem> list = group.getList();
            
            for (int i = 0; i < list.size(); i++) {
                SearchDataResponseItem data = list.get(i);
                int rank = i + 1;
                double score = 1.0 / (k + rank);
                
                SearchDataResponseItem record = resultMap.get(data.getId());
                if (record != null) {
                    // 合并分数
                    List<Score> concatScore = new ArrayList<>(record.getScore());
                    for (Score dataScore : data.getScore()) {
                        Optional<Score> sameScore = concatScore.stream()
                                .filter(s -> s.getType().equals(dataScore.getType()))
                                .findFirst();
                                
                        if (sameScore.isPresent()) {
                            sameScore.get().setValue(
                                Math.max(sameScore.get().getValue(), dataScore.getValue())
                            );
                        } else {
                            concatScore.add(dataScore);
                        }
                    }
                    
                    // 更新记录
                    record.setScore(concatScore);
                    record.setRrfScore(record.getRrfScore() + score);
                    resultMap.put(data.getId(), record);
                } else {
                    // 新记录
                    data.setRrfScore(score);
                    resultMap.put(data.getId(), data);
                }
            }
        }
        
        // 排序
        List<SearchDataResponseItem> results = new ArrayList<>(resultMap.values());
        results.sort((a, b) -> Double.compare(b.getRrfScore(), a.getRrfScore()));
        
        // 格式化结果
        for (int i = 0; i < results.size(); i++) {
            SearchDataResponseItem item = results.get(i);
            
            Optional<Score> rrfScore = item.getScore().stream()
                    .filter(s -> s.getType().equals("rrf"))
                    .findFirst();
                    
            if (rrfScore.isPresent()) {
                rrfScore.get().setValue(item.getRrfScore());
                rrfScore.get().setIndex(i);
            } else {
                Score newScore = new Score();
                newScore.setType("rrf");
                newScore.setValue(item.getRrfScore());
                newScore.setIndex(i);
                item.getScore().add(newScore);
            }
            
            // 清除临时RRF分数
            item.setRrfScore(0);
        }
        
        return results;
    }
    
    /**
     * 按最大Token数过滤结果
     * @param list 搜索结果列表
     * @param maxTokens 最大token限制
     * @return 过滤后的结果
     */
    public static List<SearchDataResponseItem> filterSearchResultsByMaxChars(
            List<SearchDataResponseItem> list, 
            int maxTokens) {
            
        List<SearchDataResponseItem> results = new ArrayList<>();
        int totalTokens = 0;
        
        for (SearchDataResponseItem item : list) {
            // 注意:这里需要实现countPromptTokens方法
            int tokens = countPromptTokens(item.getQ() + item.getA());
            totalTokens += tokens;
            
            if (totalTokens > maxTokens + 500) {
                break;
            }
            
            results.add(item);
            
            if (totalTokens > maxTokens) {
                break;
            }
        }
        
        // 确保至少返回一条结果
        if (results.isEmpty() && !list.isEmpty()) {
            results.add(list.get(0));
        }
        
        return results;
    }
    
    /**
     * 计算文本的token数量
     * 注意:这是一个示例实现,实际需要根据具体的分词算法来实现
     */
    private static int countPromptTokens(String text) {
        // 这里需要实现实际的token计算逻辑
        // 可以使用各种NLP库或自定义的分词算法
        return text.length(); // 示例实现
    }
}

// 搜索结果分组类
class SearchResultGroup {
    private int k;
    private List<SearchDataResponseItem> list;
    
    // getter和setter方法
}

使用示例

java 复制代码
// 使用示例
List<SearchResultGroup> searchResults = new ArrayList<>();
// ... 添加搜索结果

// 合并结果
List<SearchDataResponseItem> mergedResults = 
    DatasetSearchUtils.datasetSearchResultConcat(searchResults);

// 过滤结果
List<SearchDataResponseItem> filteredResults = 
    DatasetSearchUtils.filterSearchResultsByMaxChars(mergedResults, 1500);
相关推荐
冬奇Lab11 小时前
每日一个开源项目(第147篇):HyperGraphRAG - 用超图表示 N 元关系,RAG 的第三代范式
人工智能·开源·graphql
网易云信13 小时前
Cursor点燃个人开发者,企业级AI为何频频受挫?Agent工厂从提效工具到AI员工的跃迁
人工智能·开源
ZzT15 小时前
在 GitHub 上 @一下 claude,它自己把 issue 改成 PR
人工智能·开源
饼干哥哥16 小时前
最强视频创作工作流:Image2 + Seedance 2.0,Topview一键闭环|跨境电商版
开源·产品·设计
ApacheSeaTunnel17 小时前
当多表数据涌入,Apache SeaTunnel 如何巧妙化解主键冲突?
大数据·开源·数据集成·seatunnel·技术分享·数据同步
稀土熊猫君17 小时前
一个人能做出什么开源项目?
vue.js·后端·开源
狂师1 天前
比 Playwright 更给力,推荐一个AI Agent的浏览器自动化开源项目!
前端·开源·测试
AI袋鼠帝1 天前
开源「仓颉.Skill」2.0,你现在可以蒸馏任何视频!
开源·aigc
冬奇Lab1 天前
每日一个开源项目(第146篇):openpilot - 开源自动驾驶辅助系统,曾在 Consumer Reports 评测中超过特斯拉 Autopilot
人工智能·开源·自动驾驶
她的男孩2 天前
后台接口加密别只会 HTTPS,ForgeAdmin 的 RSA + SM4/AES 源码拆解
后端·面试·开源