一种由大型语言模型提供支持的高级 PaperSearch 代理。PaSa 可以自主做出一系列决策,包括调用搜索工具、阅读论文和选择相关参考文献,最终为复杂的学术查询获得全面准确的结果。使用合成数据集 AutoScholarQuery 使用强化学习来优化 PaSa,其中包括 35k 个细粒度的学术查询和来自顶级 AI 会议出版物的相应论文。此外,我们还开发了 RealScholarQuery,这是一个收集真实学术查询的基准测试,用于评估 PaSa 在更现实场景中的性能。尽管是在合成数据上进行训练的,但 PaSa 的性能明显优于 RealScholarQuery 上的现有基线,包括 Google、Google Scholar、Google with GPT-4 for paraphrased queries、chatGPT(支持搜索的 GPT-4o)、GPT-o1 和 PaSa-GPT-4o(通过提示 GPT-4o 实现的 PaSa)。值得注意的是,PaSa-7B 在 recall@20 年和 recall@50 年分别超过了基于 Google 的最佳基线 37.78% 和 39.90% 的 GPT-4o。它的召回率也比 PaSa-GPT-4o 高出 30.36%,准确率高出 4.25%。
162 Stars 16 Forks 1 Issues 0 贡献者 Apache-2.0 License Python 语言
代码: https://github.com/bytedance/pasa
更多AI开源软件:AI开源 - 小众AI
快速开始
您可以准备学术搜索需求的详细描述,并在 https://pasa-agent.ai 上搜索论文
架构
PaSa 系统由两个 LLM 代理组成,即 Crawler 和 Selector。Crawler 处理用户查询,并可以访问纸张队列中的论文。它可以自主调用搜索工具、扩展引文或停止处理当前论文。Crawler 收集的所有论文都将附加到论文队列中。选择器会读取纸张队列中的每篇论文,以确定它是否满足用户查询中指定的条件。
数据
所有数据集都可以在 pasa-dataset 中找到
AutoScholarQuery 自动学术查询
AutoScholarQuery 是一个综合但高质量的学术查询和相关论文数据集,专为 AI 领域策划。
RealScholarQuery 查询
RealScholarQuery 是一个测试数据集,由 AI 研究人员为使用该系统而提出的 50 个真实世界的精细研究查询组成。专业注释者通过各种检索方法尽可能全面地确定每个查询的答案。
期限
基线
我们在 AutoScholarQuery 和 RealScholarQuery 的测试集上评估我们的论文检索代理。我们将 PaSa-7b 与以下基线进行比较:
- **谷歌。**使用 Google 直接搜索查询。
- **谷歌学术。**查询将直接提交给 Google Scholar。
- **谷歌与 GPT-4o。**我们首先使用 GPT-4o 来解释 scholar 查询。然后在 Google 上搜索释义的查询。
- **聊天GPT。**我们将学者查询提交给 ChatGPT,它由支持搜索的 GPT-4o 提供支持。由于需要手动提交查询,我们只评估 AutoScholarQuery 测试集中的 100 个随机采样实例。
- **GPT-o1 的。**提示 GPT-o1 处理 scholar 查询。
- **PaSa-GPT-4o 的 Git-4o 中。**在 PaSa 框架内提示 GPT-4o。它可以执行多次搜索、论文阅读和引文网络爬取。
主要结果
如表 5 所示,PaSa-7b 在 AutoScholarQuery 测试集上的表现优于所有基线。具体来说,与最强的基线 PaSa-GPT-4o 相比,PaSa-7b 的召回率提高了 9.64%,精度相当。此外,PaSa-7b 中 Crawler 的召回率比 PaSa-GPT-4o 中高出 3.66%。与基于 Google 的最佳基线相比,使用 GPT-4o 的 Google,PaSa-7b 在 Recall@20、Recall@50 和 Recall@100 方面分别提高了 33.80%、38.83% 和 42.64%。
我们观察到,在推理过程中使用 Crawler 的多个 ensembles 可以提高性能。具体来说,在推理期间运行两次 Crawler 使 AutoScholarQuery 上的 Crawler 召回率提高了 3.34%,导致最终召回率提高了 1.51%,精度保持相似。
为了在更现实的环境中评估 PaSa,我们在 RealScholarQuery 上评估了它的有效性。如表 6 所示,PaSa-7b 在现实世界的学术搜索场景中表现出更大的优势。与 PaSa-GPT-4o 相比,PaSa-7b 的召回率提高了 30.36%,准确率提高了 4.25%。相对于 RealScholarQuery 上基于 Google 的最佳基线,使用 GPT-4o、PaSa-7b 的 Google 在 recall@20、recall@50 和 recall@100 方面的表现分别比 Google 高出 37.78%、39.90% 和 39.83%。此外,PaSa-7b-ensemble 将爬行程序召回率进一步提高了 4.32%,有助于整个代理系统的召回率总体提高 3.52%。
在本地运行
数据准备
从 pasa-dataset 下载数据集并保存在 data 文件夹中。
pasa/data
├── AutoScholarQuery
│ ├── dev.jsonl
│ ├── test.jsonl
│ └── train.jsonl
├── paper_database
│ ├── cs_paper_2nd.zip
│ └── id2paper.json
├── RealScholarQuery
│ └── test.jsonl
├── sft_crawler
│ └── train.jsonl
└── sft_selector
├── test.jsonl
└── train.jsonl
模型准备
下载模型 checkpoints pasa-7b-crawler 和 pasa-7b-selector 并将其保存在 checkpoints 文件夹中。
pasa/checkpoints
├── pasa-7b-crawler
└── pasa-7b-selector
运行 Pasa
git clone git@github.com:hyc2026/transformers.git
cd transformers
pip3 install -e .
cd ..
pip install -r requirements.txt
您需要先在 serper.dev 申请 Google Search API 密钥,然后替换 中的"您的 google 密钥"。utils.py
python run_paper_agent.py
- 将从用户查询中生成搜索查询,并从论文的 ll 次要部分名称中选择扩展部分。crawler
- 它将论文的标题和摘要作为输入,并生成一个分数,该分数表示论文与用户查询之间的相关性。selector
- 我们还使用 google search api 来搜索 生成的查询,并使用 arxiv/ar5iv search api 来获取完整的论文。crawler
训练您自己的代理
我们修改了 和 的代码,您可以在克隆安装后进行 SFT 和 PPO 训练。trltransformers
https://github.com/hyc2026/trl https://github.com/hyc2026/transformers
安装依赖项
git clone git@github.com:hyc2026/trl.git
cd trl
pip3 install -e .
cd ..
git clone git@github.com:hyc2026/transformers.git
cd transformers
pip3 install -e .
cd ..
pip install -r requirements.txt
Selector SFT 培训
cd trl
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 8 \
--main_process_port 2501 \
--machine_rank 0 \
--main_process_ip 127.0.0.1 \
examples/scripts/sft.py \
--model_name_or_path Qwen2.5-7B-Instruct \
--dataset_name ../data/sft_selector/train.jsonl \
--learning_rate 1.0e-5 \
--num_train_epochs 1 \
--bf16 True \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing \
--logging_steps 50 \
--save_steps 2000 \
--max_seq_length 1024 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--output_dir ../results/sft_selector \
--attn_implementation "flash_attention_2"
SFT 培训
cd trl
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 8 \
--main_process_port 2501 \
--machine_rank 0 \
--main_process_ip 127.0.0.1 \
examples/scripts/sft.py \
--model_name_or_path Qwen2.5-7B-Instruct \
--dataset_name ../data/sft_crawler/train.jsonl \
--learning_rate 1.0e-5 \
--num_train_epochs 1 \
--bf16 True \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing \
--logging_steps 50 \
--save_steps 2000 \
--max_seq_length 1024 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--output_dir ../results/sft_crawler \
--attn_implementation "flash_attention_2"
Crawler PPO 培训
**培训前:**
-
您需要先在 serper.dev 申请 Google Search API 密钥,然后替换 中的"您的 google 密钥"。trl/custom_agent/search_tools.py
-
如果设置,则需要部署其他选择器模型,这些模型可在训练期间访问。请修改函数 in 以调用 selector 并获取 select 结果。use_selector=Truecall_selectortrl/custom_agent/utils.py
cd trl
accelerate launch
--config_file examples/accelerate_configs/deepspeed_zero3_multi.yaml
--main_process_port 2501
--machine_rank 0
--main_process_ip 127.0.0.1
examples/scripts/ppo/ppo_tldr.py
--dataset_name ../data/AutoScholarQuery/train.jsonl
--dataset_test_split validation
--output_dir ../results/ppo_crawler
--learning_rate 1e-6
--per_device_train_batch_size 1
--gradient_accumulation_steps 4
--total_episodes 16000
--paper_db ../data/paper_database/cs_paper_2nd.zip
--paper_id ../data/paper_database/id2paper.json
--model_name_or_path ../output/sft_crawler
--sft_model_path ../output/sft_crawler
--reward_model_path ../output/sft_crawler
--local_rollout_forward_batch_size 4
--num_sample_generations 0
--attn_implementation "flash_attention_2"
--response_length 1024
--stop_token eos
--gamma1 0.1
--save_steps 10
--rounds 3
--use_vm True
--use_selector True
--vf_coef 10.0
--expand_select_score 1.5
--expand_cost 0.1
--search_select_score 1.5
--search_cost 0.1
--num_ppo_epochs 2
--kl_coef 0.1