PaSa - 大型语言模型提供支持的高级论文搜索代理

一种由大型语言模型提供支持的高级 PaperSearch 代理。PaSa 可以自主做出一系列决策,包括调用搜索工具、阅读论文和选择相关参考文献,最终为复杂的学术查询获得全面准确的结果。使用合成数据集 AutoScholarQuery 使用强化学习来优化 PaSa,其中包括 35k 个细粒度的学术查询和来自顶级 AI 会议出版物的相应论文。此外,我们还开发了 RealScholarQuery,这是一个收集真实学术查询的基准测试,用于评估 PaSa 在更现实场景中的性能。尽管是在合成数据上进行训练的,但 PaSa 的性能明显优于 RealScholarQuery 上的现有基线,包括 Google、Google Scholar、Google with GPT-4 for paraphrased queries、chatGPT(支持搜索的 GPT-4o)、GPT-o1 和 PaSa-GPT-4o(通过提示 GPT-4o 实现的 PaSa)。值得注意的是,PaSa-7B 在 recall@20 年和 recall@50 年分别超过了基于 Google 的最佳基线 37.78% 和 39.90% 的 GPT-4o。它的召回率也比 PaSa-GPT-4o 高出 30.36%,准确率高出 4.25%。

162 Stars 16 Forks 1 Issues 0 贡献者 Apache-2.0 License Python 语言

代码: https://github.com/bytedance/pasa

更多AI开源软件:AI开源 - 小众AI

快速开始

您可以准备学术搜索需求的详细描述,并在 https://pasa-agent.ai 上搜索论文

架构

PaSa 系统由两个 LLM 代理组成,即 Crawler 和 Selector。Crawler 处理用户查询,并可以访问纸张队列中的论文。它可以自主调用搜索工具、扩展引文或停止处理当前论文。Crawler 收集的所有论文都将附加到论文队列中。选择器会读取纸张队列中的每篇论文,以确定它是否满足用户查询中指定的条件。

数据

所有数据集都可以在 pasa-dataset 中找到

AutoScholarQuery 自动学术查询

AutoScholarQuery 是一个综合但高质量的学术查询和相关论文数据集,专为 AI 领域策划。

RealScholarQuery 查询

RealScholarQuery 是一个测试数据集,由 AI 研究人员为使用该系统而提出的 50 个真实世界的精细研究查询组成。专业注释者通过各种检索方法尽可能全面地确定每个查询的答案。

期限

基线

我们在 AutoScholarQuery 和 RealScholarQuery 的测试集上评估我们的论文检索代理。我们将 PaSa-7b 与以下基线进行比较:

  • **谷歌。**使用 Google 直接搜索查询。
  • **谷歌学术。**查询将直接提交给 Google Scholar。
  • **谷歌与 GPT-4o。**我们首先使用 GPT-4o 来解释 scholar 查询。然后在 Google 上搜索释义的查询。
  • **聊天GPT。**我们将学者查询提交给 ChatGPT,它由支持搜索的 GPT-4o 提供支持。由于需要手动提交查询,我们只评估 AutoScholarQuery 测试集中的 100 个随机采样实例。
  • **GPT-o1 的。**提示 GPT-o1 处理 scholar 查询。
  • **PaSa-GPT-4o 的 Git-4o 中。**在 PaSa 框架内提示 GPT-4o。它可以执行多次搜索、论文阅读和引文网络爬取。
主要结果

如表 5 所示,PaSa-7b 在 AutoScholarQuery 测试集上的表现优于所有基线。具体来说,与最强的基线 PaSa-GPT-4o 相比,PaSa-7b 的召回率提高了 9.64%,精度相当。此外,PaSa-7b 中 Crawler 的召回率比 PaSa-GPT-4o 中高出 3.66%。与基于 Google 的最佳基线相比,使用 GPT-4o 的 Google,PaSa-7b 在 Recall@20、Recall@50 和 Recall@100 方面分别提高了 33.80%、38.83% 和 42.64%。

我们观察到,在推理过程中使用 Crawler 的多个 ensembles 可以提高性能。具体来说,在推理期间运行两次 Crawler 使 AutoScholarQuery 上的 Crawler 召回率提高了 3.34%,导致最终召回率提高了 1.51%,精度保持相似。

为了在更现实的环境中评估 PaSa,我们在 RealScholarQuery 上评估了它的有效性。如表 6 所示,PaSa-7b 在现实世界的学术搜索场景中表现出更大的优势。与 PaSa-GPT-4o 相比,PaSa-7b 的召回率提高了 30.36%,准确率提高了 4.25%。相对于 RealScholarQuery 上基于 Google 的最佳基线,使用 GPT-4o、PaSa-7b 的 Google 在 recall@20、recall@50 和 recall@100 方面的表现分别比 Google 高出 37.78%、39.90% 和 39.83%。此外,PaSa-7b-ensemble 将爬行程序召回率进一步提高了 4.32%,有助于整个代理系统的召回率总体提高 3.52%。

在本地运行

数据准备

pasa-dataset 下载数据集并保存在 data 文件夹中。

pasa/data
├── AutoScholarQuery
│   ├── dev.jsonl
│   ├── test.jsonl
│   └── train.jsonl
├── paper_database
│   ├── cs_paper_2nd.zip
│   └── id2paper.json
├── RealScholarQuery
│   └── test.jsonl
├── sft_crawler
│   └── train.jsonl
└── sft_selector
    ├── test.jsonl
    └── train.jsonl
模型准备

下载模型 checkpoints pasa-7b-crawlerpasa-7b-selector 并将其保存在 checkpoints 文件夹中。

pasa/checkpoints
├── pasa-7b-crawler
└── pasa-7b-selector
运行 Pasa
git clone git@github.com:hyc2026/transformers.git
cd transformers
pip3 install -e .
cd ..
pip install -r requirements.txt

您需要先在 serper.dev 申请 Google Search API 密钥,然后替换 中的"您的 google 密钥"。utils.py​

python run_paper_agent.py
  • 将从用户查询中生成搜索查询,并从论文的 ll 次要部分名称中选择扩展部分。crawler
  • 它将论文的标题和摘要作为输入,并生成一个分数,该分数表示论文与用户查询之间的相关性。selector
  • 我们还使用 google search api 来搜索 生成的查询,并使用 arxiv/ar5iv search api 来获取完整的论文。crawler

训练您自己的代理

我们修改了 和 的代码,您可以在克隆安装后进行 SFT 和 PPO 训练。trltransformers​

https://github.com/hyc2026/trl https://github.com/hyc2026/transformers

安装依赖项
git clone git@github.com:hyc2026/trl.git
cd trl
pip3 install -e .
cd ..
git clone git@github.com:hyc2026/transformers.git
cd transformers
pip3 install -e .
cd ..
pip install -r requirements.txt
Selector SFT 培训
cd trl
accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ 
    --num_processes 8 \
    --main_process_port 2501 \
    --machine_rank 0 \
    --main_process_ip 127.0.0.1 \
    examples/scripts/sft.py \
    --model_name_or_path Qwen2.5-7B-Instruct \
    --dataset_name ../data/sft_selector/train.jsonl \
    --learning_rate 1.0e-5 \
    --num_train_epochs 1 \
    --bf16 True \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --logging_steps 50 \
    --save_steps 2000 \
    --max_seq_length 1024 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --output_dir ../results/sft_selector \
    --attn_implementation "flash_attention_2"
SFT 培训
cd trl
accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ 
    --num_processes 8 \
    --main_process_port 2501 \
    --machine_rank 0 \
    --main_process_ip 127.0.0.1 \
    examples/scripts/sft.py \
    --model_name_or_path Qwen2.5-7B-Instruct \
    --dataset_name ../data/sft_crawler/train.jsonl \
    --learning_rate 1.0e-5 \
    --num_train_epochs 1 \
    --bf16 True \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --logging_steps 50 \
    --save_steps 2000 \
    --max_seq_length 1024 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --output_dir ../results/sft_crawler \
    --attn_implementation "flash_attention_2"
Crawler PPO 培训

**培训前:**

  1. 您需要先在 serper.dev 申请 Google Search API 密钥,然后替换 中的"您的 google 密钥"。trl/custom_agent/search_tools.py

  2. 如果设置,则需要部署其他选择器模型,这些模型可在训练期间访问。请修改函数 in 以调用 selector 并获取 select 结果。use_selector=Truecall_selectortrl/custom_agent/utils.py

    cd trl
    accelerate launch
    --config_file examples/accelerate_configs/deepspeed_zero3_multi.yaml
    --main_process_port 2501
    --machine_rank 0
    --main_process_ip 127.0.0.1
    examples/scripts/ppo/ppo_tldr.py
    --dataset_name ../data/AutoScholarQuery/train.jsonl
    --dataset_test_split validation
    --output_dir ../results/ppo_crawler
    --learning_rate 1e-6
    --per_device_train_batch_size 1
    --gradient_accumulation_steps 4
    --total_episodes 16000
    --paper_db ../data/paper_database/cs_paper_2nd.zip
    --paper_id ../data/paper_database/id2paper.json
    --model_name_or_path ../output/sft_crawler
    --sft_model_path ../output/sft_crawler
    --reward_model_path ../output/sft_crawler
    --local_rollout_forward_batch_size 4
    --num_sample_generations 0
    --attn_implementation "flash_attention_2"
    --response_length 1024
    --stop_token eos
    --gamma1 0.1
    --save_steps 10
    --rounds 3
    --use_vm True
    --use_selector True
    --vf_coef 10.0
    --expand_select_score 1.5
    --expand_cost 0.1
    --search_select_score 1.5
    --search_cost 0.1
    --num_ppo_epochs 2
    --kl_coef 0.1

相关推荐
刀客12313 分钟前
python3+TensorFlow 2.x(四)反向传播
人工智能·python·tensorflow
SpikeKing19 分钟前
LLM - 大模型 ScallingLaws 的设计 100B 预训练方案(PLM) 教程(5)
人工智能·llm·预训练·scalinglaws·100b·deepnorm·egs
小枫@码43 分钟前
免费GPU算力,不花钱部署DeepSeek-R1
人工智能·语言模型
liruiqiang0544 分钟前
机器学习 - 初学者需要弄懂的一些线性代数的概念
人工智能·线性代数·机器学习·线性回归
Icomi_1 小时前
【外文原版书阅读】《机器学习前置知识》1.线性代数的重要性,初识向量以及向量加法
c语言·c++·人工智能·深度学习·神经网络·机器学习·计算机视觉
微学AI1 小时前
GPU算力平台|在GPU算力平台部署可图大模型Kolors的应用实战教程
人工智能·大模型·llm·gpu算力
西猫雷婶1 小时前
python学opencv|读取图像(四十六)使用cv2.bitwise_or()函数实现图像按位或运算
人工智能·opencv·计算机视觉
IT古董1 小时前
【深度学习】常见模型-生成对抗网络(Generative Adversarial Network, GAN)
人工智能·深度学习·生成对抗网络
Jackilina_Stone1 小时前
【论文阅读笔记】“万字”关于深度学习的图像和视频阴影检测、去除和生成的综述笔记 | 2024.9.3
论文阅读·人工智能·笔记·深度学习·ai
远洋录1 小时前
AI Agent的安全实践:权限控制与数据保护
人工智能·ai·ai agent