An earlier post, [RAG & Multimodal] Multimodal RAG - ColPali: Efficient Document Retrieval with Vision Language Models, already introduced the method (see that post for background); this time we put ColPali into practice.
Required weights:

- Multimodal QA model: Qwen2-VL-72B-Instruct, https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct
- Visual retriever based on PaliGemma-3B and the ColBERT strategy:
  - ColPali (LoRA): https://huggingface.co/vidore/colpali
  - ColPali (base): https://huggingface.co/vidore/colpaligemma-3b-mix-448-base
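To keep everything local, the two Hugging Face checkpoints can be fetched ahead of time. A minimal sketch, assuming the illustrative target directories `./colpali` and `./colpaligemma-3b-mix-448-base` (the Qwen weights come from ModelScope separately):

```python
from huggingface_hub import snapshot_download

# Illustrative local layout; any directories work as long as the paths used
# later (adapter_config.json and from_pretrained) point at the same places.
snapshot_download("vidore/colpali", local_dir="./colpali")
snapshot_download("vidore/colpaligemma-3b-mix-448-base",
                  local_dir="./colpaligemma-3b-mix-448-base")
```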
Multimodal retrieval QA in practice:

- In the LoRA adapter's adapter_config.json, change the base_model_name_or_path field to the local storage path of the ColPali base model; a small sketch of this edit follows the list.
- qwen_vl_utils download: https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils/src/qwen_vl_utils
- byaldi installation: https://github.com/AnswerDotAI/byaldi
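A minimal sketch of the adapter_config.json edit, assuming the LoRA weights sit in `./colpali` and the base model in `./colpaligemma-3b-mix-448-base` (both paths are illustrative):

```python
import json

adapter_config_path = "./colpali/adapter_config.json"  # illustrative path

with open(adapter_config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Point the LoRA adapter at the locally stored ColPali base model.
config["base_model_name_or_path"] = "./colpaligemma-3b-mix-448-base"

with open(adapter_config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2, ensure_ascii=False)
```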
Complete code:
```python
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from pdf2image import convert_from_path


class DocumentQA:
    def __init__(self, rag_model_name: str, vlm_model_name: str, device: str = 'cuda', system_prompt: str = None):
        # Visual retriever: ColPali loaded through byaldi
        self.rag_engine = RAGMultiModalModel.from_pretrained(rag_model_name)
        # Vision-language model that answers questions about the retrieved page
        self.vlm = Qwen2VLForConditionalGeneration.from_pretrained(
            vlm_model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(vlm_model_name, trust_remote_code=True)
        self.device = device
        if system_prompt is None:
            self.system_prompt = (
                "You are an AI research assistant specializing in computer science and machine learning. "
                "Your task is to analyze academic papers, especially research on document retrieval and "
                "multimodal models. Carefully analyze the provided image and text, and give in-depth "
                "insights and explanations."
            )
        else:
            self.system_prompt = system_prompt

    def index_document(self, pdf_path: str, index_name: str = 'index', overwrite: bool = True):
        self.pdf_path = pdf_path
        # Build the retrieval index over the PDF pages
        self.rag_engine.index(
            input_path=pdf_path,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=overwrite
        )
        # Render each page as an image so the top hit can be passed to the VLM
        self.images = convert_from_path(pdf_path)

    def query(self, text_query: str, k: int = 3) -> str:
        results = self.rag_engine.search(text_query, k=k)
        print("Search results:", results)
        if not results:
            print("No relevant results found for the query.")
            return None
        try:
            # page_num is 1-indexed; the rendered image list is 0-indexed
            page_num = results[0]["page_num"]
            image_index = page_num - 1
            image = self.images[image_index]
        except (KeyError, IndexError) as e:
            print("Error fetching the page image:", e)
            return None
        messages = [
            {
                "role": "system",
                "content": self.system_prompt
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_query},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # Prepare model inputs
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)
        generated_ids = self.vlm.generate(**inputs, max_new_tokens=1024)
        # Strip the prompt tokens, keeping only the newly generated answer
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]


if __name__ == "__main__":
    # Initialize a DocumentQA instance
    document_qa = DocumentQA(
        rag_model_name="./colpali",
        vlm_model_name="./Qwen2-VL-7B-Instruct",
        device='cuda'
    )
    # Index the PDF document
    document_qa.index_document("test.pdf")
    # Define the query
    text_query = (
        "On which dataset does the model in the paper show the largest advantage over other models, "
        "and how large is the improvement?"
    )
    # Run the query and print the answer
    answer = document_qa.query(text_query)
    print("Answer:", answer)
```
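Two notes on the code above. First, store_collection_with_index=False keeps the byaldi index lightweight, which is why index_document re-renders the pages with pdf2image and query maps the 1-indexed page_num of the top hit back into the rendered image list. Second, attn_implementation="flash_attention_2" requires the flash-attn package; if it is not installed, a fallback sketch (same illustrative model path as above) is to drop that argument and let transformers pick its default attention backend:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

# Fallback loading without FlashAttention 2; transformers selects its
# default attention backend automatically.
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    "./Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```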