论文:Visual Instruction Tuning
arXiv: 2304.08485 | NeurIPS 2023 Oral
作者:Haotian Liu, Chunyuan Li, Qingyang Wu, Yong Jae Lee
机构:University of Wisconsin--Madison, Microsoft Research, Columbia University
目录
- 论文逐段精读(原文 + 解析)
- Abstract
- Introduction
- Related Work
- GPT-assisted Visual Instruction Data Generation
- Visual Instruction Tuning(核心架构章节)
- Experiments
- Conclusion
- LLaVA 模型网络结构深度解析(源码级)
- 整体架构概览
- 模块一:CLIP 视觉编码器
- 模块二:MLP Projection 层
- 模块三:多模态 Token 拼接
- 模块四:LLaMA 语言解码器
- 完整前向传播流程图
一、论文逐段精读
Abstract(摘要)
原文:
Instruction tuning large language models (LLMs) using machine-generated instruction-following data has improved zero-shot capabilities on new tasks, but the idea is less explored in the multimodal field. In this paper, we present the first attempt to use language-only GPT-4 to generate multimodal language-image instruction-following data. By instruction tuning on such generated data, we introduce LLaVA (Large Language and Vision Assistant): an end-to-end trained large multimodal model that connects a vision encoder and LLM for general-purpose visual and language understanding.
Our early experiments show that LLaVA demonstrates impressive multimodal chat abilities, sometimes exhibiting the behaviors of multimodal GPT-4 on unseen images/instructions, and yields a 85.1% relative score compared with GPT-4 on a synthetic multimodal instruction-following dataset. When fine-tuned on Science QA, the synergy of LLaVA and GPT-4 achieves a new state-of-the-art accuracy of 92.53%. We make GPT-4 generated visual instruction tuning data, model and code publicly available.
精读解析:
摘要只有两段,但信息量极高。
第一段点出了本文的核心动机和方法:instruction tuning(指令微调)在纯文本LLM上已经被验证非常有效(如InstructGPT、Alpaca、Vicuna),但在多模态领域几乎没人做过。LLaVA的贡献在于:第一,用语言版GPT-4(text-only)来生成图像-语言的指令跟随数据;第二,基于这些数据训练了一个叫LLaVA的多模态模型,架构上非常简洁------把CLIP视觉编码器和LLM连起来。
第二段汇报了两个核心实验结果:(1) 在自建的多模态指令跟随数据集上,LLaVA达到GPT-4得分的85.1%;(2) 在Science QA上,LLaVA+GPT-4集成达到92.53%的新SOTA。这两个数字在2023年4月(论文首发时)是相当惊艳的。
关键词理解:
- instruction-following data(指令跟随数据):一种特定格式的数据,通常是"Human: [问题] \n Assistant: [回答]",用于训练模型理解和遵循指令。
- end-to-end trained:端到端训练,视觉编码器和语言模型在同一训练目标下联合优化(部分参数冻结,但整体是一个系统)。
- connects a vision encoder and LLM:这句话高度概括了LLaVA的架构------CLIP + Projection + LLM,三个模块串联。
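上面提到的"Human: [问题] \n Assistant: [回答]"格式,可以用几行代码直观表示(函数名与样例文本均为示意,非官方数据格式):

```python
# 指令跟随数据格式的最小构造示意
def build_sample(question: str, answer: str) -> str:
    return f"Human: {question}\nAssistant: {answer}"

sample = build_sample("图里有什么动物?", "图中有一只猫。")
```

训练时,这类样本会被整体送入模型,但只有 Assistant 之后的部分计入损失(详见 4.2 节)。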
Section 1: Introduction(引言)
原文段落一:
Humans interact with the world through many channels such as vision and language, as each individual channel has a unique advantage in representing and communicating certain concepts, and thus facilitates a better understanding of the world. One of the core aspirations in artificial intelligence is to develop a general-purpose assistant that can effectively follow multi-modal vision-and-language instructions, aligned with human intent to complete various real-world tasks in the wild.
解析:
这是标准的"宏大背景"开头。核心逻辑是:人类通过视觉+语言理解世界,AI也应该如此。关键词"general-purpose assistant"预示着本文不是在做某个特定的视觉任务(比如分类、检测),而是要做一个通用的多模态助手。
原文段落二:
To this end, the community has witnessed an emergent interest in developing language-augmented foundation vision models, with strong capabilities in open-world visual understanding such as classification, detection, segmentation and captioning, as well as visual generation and editing. In this line of work, each task is solved independently by one single large vision model, with the task instruction implicitly considered in the model design. Further, language is only utilized to describe the image content. While this allows language to play an important role in mapping visual signals to language semantics---a common channel for human communication, it leads to models that usually have a fixed interface with limited interactivity and adaptability to the user's instructions.
解析:
这段是在批评现有工作的局限性。现有的视觉-语言模型(如CLIP、BLIP等)虽然能做视觉理解,但存在两个根本缺陷:
- 每个任务一个专用模型,缺乏通用性;
- 语言只是被用来描述图像内容(比如image captioning),而不是用来下达指令、控制模型行为。
这引出了LLaVA要解决的核心问题:如何让模型理解并遵循用户的多模态指令。
原文段落三:
Large language models (LLM), on the other hand, have shown that language can play a wider role: a universal interface for a general-purpose assistant, where various task instructions can be explicitly represented in language and guide the end-to-end trained neural assistant to switch to the task of interest to solve it. For example, the recent success of ChatGPT and GPT-4 have demonstrated the power of aligned LLMs in following human instructions, and have stimulated tremendous interest in developing open-source LLMs. Among them, LLaMA is an open-source LLM that matches the performance of GPT-3. Alpaca, Vicuna, GPT-4-LLM utilize various machine-generated high-quality instruction-following samples to improve the LLM's alignment ability, reporting impressive performance compared with proprietary LLMs. Importantly, this line of work is text-only.
解析:
这段是说"灵感来源"。LLM那边的指令微调已经做得很好了(ChatGPT、GPT-4、Vicuna都是例子),关键是它们都是pure text(纯文本)的。那一个自然的问题就是:能不能把这套方法移植到多模态领域?本文最后一句"Importantly, this line of work is text-only"是一个点睛之笔------它既是对现有工作的精准定位,也是对本文工作的铺垫。
原文段落四(核心贡献):
In this paper, we present visual instruction-tuning, the first attempt to extend instruction-tuning to the language-image multimodal space, to pave the way towards building a general-purpose visual assistant. In particular, our paper makes the following contributions:
Multimodal instruction-following data. One key challenge is the lack of vision-language instruction-following data. We present a data reformation perspective and pipeline to convert image-text pairs into an appropriate instruction-following format, using ChatGPT/GPT-4.
Large multimodal models. We develop a large multimodal model (LMM), by connecting the open-set visual encoder of CLIP with the language decoder Vicuna, and fine-tuning end-to-end on our generated instructional vision-language data. Our empirical study validates the effectiveness of using generated data for LMM instruction-tuning, and suggests practical tips for building a general-purpose instruction-following visual agent. When ensembled with GPT-4, our approach achieves SoTA on the Science QA multimodal reasoning dataset.
Multimodal instruction-following benchmark. We present LLaVA-Bench with two challenging benchmarks, with a diverse selection of paired images, instructions and detailed annotations.
Open-source. We release the following assets to the public: the generated multimodal instruction data, the codebase, the model checkpoints, and a visual chat demo.
解析:
四个贡献,依次对应四个关键点:
- 数据:用GPT-4把现有的图文对数据转成指令跟随格式------这是整个工作的基础,解决了多模态指令数据几乎为零的问题。
- 模型(架构):CLIP + Vicuna,中间用一个Linear Projection连接------架构极简,但效果很好,这也是LLaVA最大的贡献之一。
- 评估基准:提出LLaVA-Bench,填补了多模态指令跟随评测的空白。
- 开源:数据+代码+权重全部开源,对后续工作影响巨大。
对于VLA研究方向来说,第2点(架构设计)是最值得深究的,直接影响了OpenVLA等后续工作如何设计多模态模型。
Section 2: Related Work(相关工作)
原文段落一:
Multimodal Instruction-following Agents. In computer vision, existing works that build instruction-following agents can be broadly categorized into two classes: (i) End-to-end trained models, which are separately explored for each specific research topic. For example, the vision-language navigation task and Habitat require the embodied AI agent to follow natural language instructions and take a sequence of actions to complete goals in visual environments. In the image editing domain, given an input image and a written instruction that tells the agent what to do, InstructPix2Pix edits images by following the human instructions. (ii) A system that coordinates various models via LangChain / LLMs, such as Visual ChatGPT, X-GPT, MM-REACT, VisProg, and ViperGPT. While sharing the same goal in building instruction-following agents, we focus on developing an end-to-end trained language-vision multimodal model for multiple tasks.
解析:
这段把"多模态指令跟随智能体"分成了两类:
- 端到端训练模型:每个任务训一个模型(如导航、图像编辑),指令隐式包含在模型结构里;
- 多模型协调系统:用LangChain/LLM把多个专用模型串起来,类似"工具调用"。
LLaVA属于第一类,但不同于以往每个任务专门训的做法,它追求的是一个模型解决多个任务(通用助手)。这个定位对于后来的VLA系统至关重要------RT-2、OpenVLA也走的是这条路(一个大模型处理所有任务)。
原文段落二:
Instruction Tuning. In the natural language processing (NLP) community, to enable LLMs such as GPT-3, T5, PaLM, and OPT to follow natural language instructions and complete real-world tasks, researchers have explored methods for LLM instruction-tuning, leading to instruction-tuned counterparts such as InstructGPT/ChatGPT, FLAN-T5, FLAN-PaLM, and OPT-IML, respectively. It turns out that this simple approach can effectively improve the zero- and few-shot generalization abilities of LLMs. It is thus natural to borrow the idea from NLP to computer vision. Flamingo can be viewed as the GPT-3 moment in the multimodal domain, due to its strong performance on zero-shot task transfer and in-context-learning. Other LMMs trained on image-text pairs include BLIP-2, FROMAGe, and KOSMOS-1. PaLM-E is an LMM for embodied AI. Based on the recent "best" open-source LLM LLaMA, OpenFlamingo and LLaMA-Adapter are open-source efforts that enable LLaMA to use image inputs, paving the way to build open-source multimodal LLMs. While these models present promising task transfer generalization performance, they are not explicitly tuned with vision-language instruction data, and their performance in multimodal tasks usually falls short compared to language-only tasks. In this paper, we aim to fill this gap and study its effectiveness. Finally, note that visual instruction tuning is different from visual prompt tuning: the former aims to improve the model's instruction-following abilities, while the latter aims to improve the parameter-efficiency in model adaptation.
解析:
这段梳理了LLM侧的指令微调历史,然后指出多模态侧的现有模型(Flamingo、BLIP-2、OpenFlamingo等)虽然有图像理解能力,但没有用指令数据做微调,所以在遵循用户指令方面表现差。最后一句话特别重要:visual instruction tuning(LLaVA的方法)和visual prompt tuning是两回事------前者改变模型行为,后者提升参数效率,不要混淆。
Section 3: GPT-assisted Visual Instruction Data Generation(数据生成)
原文段落一:
The community has witnessed a surge in the amount of public multimodal data such as image-text pairs, ranging from CC to LAION. However, when it comes to multimodal instruction-following data, the available amount is limited, partially because the process for creating such data is time-consuming and less well-defined when human crowd-scouring is considered. Inspired by the success of recent GPT models in text-annotation tasks, we propose to leverage ChatGPT/GPT-4 for multimodal instruction-following data collection, based on the widely existing image-pair data.
解析:
指出了核心问题:图文对数据很多,但指令跟随格式的数据几乎没有,人工标注太贵。解决方案是用GPT-4来做自动转换------这是一个非常聪明的工程技巧。
原文段落二:
For an image Xv and its associated caption Xc, it is natural to create a set of questions Xq with the intent to instruct the assistant to describe the image content. We prompt GPT-4 to curate such a list of questions (see details in Appendix). Therefore, a simple way to expand an image-text pair to its instruction-following version is Human: Xq Xv Assistant: Xc . Though cheap to construct, this simple expanded version lacks diversity and in-depth reasoning in both the instructions and responses.
To mitigate this issue, we leverage language-only GPT-4 or ChatGPT as the strong teacher (both accept only text as input), to create instruction-following data involving visual content. Specifically, in order to encode an image into its visual features to prompt a text-only GPT, we use two types of symbolic representations: (i) Captions typically describe the visual scene from various perspectives; (ii) Bounding boxes usually localize the objects in the scene, and each box encodes the object concept and its spatial location.
解析:
数据生成的关键技巧在这里:由于GPT-4是纯文本模型(2023年4月,GPT-4V还没有),无法直接"看"图像。作者的做法是用两种文本来表示图像:图像描述(caption,从不同角度描述视觉场景)和边界框(bounding box,定位物体及其空间位置)。这样就把一个视觉问题转化为了一个文本问题,让text-only的GPT-4能够理解图像内容并生成对话数据。
这个技巧非常值得学习:当你没有合适的工具直接处理某种模态时,可以先把该模态转换成另一种可处理的表示形式。
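把图像转成"caption + bounding box"符号化文本再喂给 text-only GPT 的过程,可以写成如下最小示意(prompt 措辞与坐标均为假设,非论文附录原文):

```python
# 用纯文本(caption + 归一化边界框)表示一张图像
def image_to_symbolic(captions, boxes):
    lines = ["Captions:"]
    lines += [f"- {c}" for c in captions]
    lines.append("Objects (normalized [x1, y1, x2, y2]):")
    lines += [f"- {name}: {coords}" for name, coords in boxes]
    return "\n".join(lines)

prompt = image_to_symbolic(
    ["A man rides a bike down a city street."],
    [("person", [0.21, 0.10, 0.48, 0.95]), ("bicycle", [0.18, 0.40, 0.55, 0.98])],
)
```

随后把这段文本连同问题一起发给 GPT-4,就能得到"仿佛看过图"的对话数据。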
原文段落三(三种数据类型):
We collect 158K unique language-image instruction-following samples in total, including 58K in conversations, 23K in detailed description, and 77k in complex reasoning, respectively.
Conversation. We design a conversation between the assistant and a person asking questions about this photo. The answers are in a tone as if the assistant is seeing the image and answering the question. A diverse set of questions are asked about the visual content of the image, including the object types, counting the objects, object actions, object locations, relative positions between objects. Only questions that have definite answers are considered.
Detailed description. To include a rich and comprehensive description for an image, we create a list of questions with such an intent, and prompt GPT-4 to curate the list. For each image, we randomly sample one question from the list to ask GPT-4 to generate the detailed description.
Complex reasoning. The above two types focus on the visual content itself, based on which we further create in-depth reasoning questions. The answers typically require a step-by-step reasoning process by following rigorous logic.
解析:
三种数据类型覆盖了不同层次的视觉理解需求:
- 对话(58K):基础问答,关注具体可观察的事实(物体类型、数量、位置)
- 详细描述(23K):综合描述图像,类似image captioning但更丰富
- 复杂推理(77K):需要多步推理,类似"这张图里的人可能在做什么?"这类需要推断的问题
比例上复杂推理最多(77K),这说明作者非常重视推理能力,不只是让模型学会描述图像。
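三类数据的数量关系可以直接核算一下(数字取自论文):

```python
# 158K 指令数据构成的核算
counts = {
    "conversation": 58_000,
    "detailed_description": 23_000,
    "complex_reasoning": 77_000,
}
total = sum(counts.values())
reasoning_share = counts["complex_reasoning"] / total  # 复杂推理占比接近一半
```

复杂推理接近总量的一半,印证了上文"作者非常重视推理能力"的判断。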
Section 4: Visual Instruction Tuning(核心章节)
4.1 Architecture(网络架构)
原文:
The primary goal is to effectively leverage the capabilities of both the pre-trained LLM and visual model. We choose Vicuna as our LLM fφ(·) parameterized by φ, as it has the best instruction following capabilities in language tasks among publicly available checkpoints.
For an input image Xv, we consider the pre-trained CLIP visual encoder ViT-L/14, which provides the visual feature Zv = g(Xv). The grid features before and after the last Transformer layer are considered in our experiments. We consider a simple linear layer to connect image features into the word embedding space. Specifically, we apply a trainable projection matrix W to convert Zv into language embedding tokens Hv, which have the same dimensionality as the word embedding space in the language model:
Hv = W · Zv, with Zv = g(Xv)
Thus, we have a sequence of visual tokens Hv. Note that our simple projection scheme is lightweight, which allows us to iterate data centric experiments quickly. More sophisticated schemes to connect the image and language representations can also be considered, such as gated cross-attention in Flamingo and Q-former in BLIP-2. We leave exploring possibly more effective and sophisticated architecture designs for LLaVA as future work.
精读解析:
这一段是整篇论文最值得反复品读的段落。架构设计体现了"奥卡姆剃刀"原则:
整个架构可以用一个公式表达:
Hv = W · Zv, 其中 Zv = g(Xv)
- Xv:原始图像输入
- g(·):CLIP ViT-L/14 视觉编码器(参数冻结)
- Zv:CLIP输出的视觉特征,shape为[N_patches, 1024](对于336×336输入,N_patches=576)
- W:可训练的线性投影矩阵,实现 1024 → 4096 的映射
- Hv:输出的视觉token序列,shape为[576, 4096],维度与LLM的词嵌入对齐
为什么选择一个简单的Linear层?
作者明确说了:这是为了快速迭代数据实验。Linear层参数量极少,不会主导性能,使得实验结论更多反映的是"数据"的作用,而不是"架构"的作用。这是一种很严谨的科学态度。
作者也承认了局限性:Flamingo的gated cross-attention、BLIP-2的Q-former是更复杂的架构,但LLaVA把这些留作"future work"。事实上,LLaVA 1.5就把Linear换成了2层MLP,性能显著提升。
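Hv = W · Zv 这一步就是一个普通的矩阵乘法,用纯 Python 写一个缩小维度的数值示意(真实维度是 1024→4096,这里缩为 4→6 便于演示,矩阵取值仅作演示):

```python
# 线性投影 Hv = W · Zv 的最小数值示意
def linear_project(W, z):
    # W: d_l 行 × d_v 列;z: 长度 d_v 的向量;返回长度 d_l 的向量
    return [sum(row[j] * z[j] for j in range(len(z))) for row in W]

d_v, d_l = 4, 6
W = [[0.1] * d_v for _ in range(d_l)]   # 全 0.1 的投影矩阵
Zv = [1.0, 2.0, 3.0, 4.0]               # 单个 patch 的"视觉特征"
Hv = linear_project(W, Zv)              # 投影到"语言嵌入空间"
```

真实实现中这就是一个 `nn.Linear(1024, 4096)`,对 576 个 patch 特征逐个(批量)应用。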
4.2 Training(训练策略)
原文段落一(指令格式):
For each image Xv, we generate multi-turn conversation data (Xq¹, Xa¹, ···, Xq^T, Xa^T), where T is the total number of turns. We organize them as a sequence, by treating all answers as the assistant's response, and the instruction Xinstruct^t at the t-th turn as:
Xinstruct^t = { Randomly choose [Xq¹, Xv] or [Xv, Xq¹], the first turn t=1; Xq^t, the remaining turns t>1 }
解析:
这描述了多轮对话的数据格式。图像只在第一轮被放入(要么放在问题前,要么放在问题后,随机选择------这增加了数据多样性),之后的轮次只有文本。这样设计的原因是:在对话中图像是上下文的一部分,模型需要学会在整个对话过程中"记住"图像信息,而不是每轮都重新处理图像。
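"首轮随机决定图像在问题前还是问题后、其余轮次只有文本"的组织方式,可以用几行代码示意(占位符与函数名为示意用):

```python
import random

IMAGE_TOKEN = "<image>"

# 构造第 t 轮的指令 Xinstruct^t
def build_instruct(turn_idx, question, rng):
    if turn_idx == 0:
        # 首轮:随机选择 [Xq, Xv] 或 [Xv, Xq]
        return rng.choice([f"{question}\n{IMAGE_TOKEN}", f"{IMAGE_TOKEN}\n{question}"])
    return question  # 其余轮次:只有文本

rng = random.Random(0)
first = build_instruct(0, "What is in the photo?", rng)
later = build_instruct(1, "What color is it?", rng)
```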
原文段落二(损失函数):
We perform instruction-tuning of the LLM on the prediction tokens, using its original auto-regressive training objective.
Specifically, for a sequence of length L, we compute the probability of the target answers Xa by:
p(Xa | Xv, Xinstruct) = ∏_{i=1}^{L} pθ(xi | Xv, Xinstruct,<i, Xa,<i)
解析:
损失函数是标准的自回归语言建模目标(next token prediction)。但有一个关键细节:只对"助手的回答"部分计算损失,对"Human"的提问和图像token部分不计算损失。
这和普通的语言模型训练的区别在于:普通LM训练要预测所有token,而指令微调只监督模型的"回答"部分。这个设计使得模型学会"给定问题和图像,生成合适的回答",而不是学习去预测问题本身。
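"只对回答部分计损失"在实现上通常靠把非回答位置的 label 置为 ignore_index 来完成,下面是一个最小示意(token id 为虚构;-100 是 HuggingFace/PyTorch CrossEntropyLoss 的常见 ignore_index 约定):

```python
IGNORE_INDEX = -100  # CrossEntropyLoss 会跳过该位置的损失

# 按"是否为 Assistant 回答"构造标签序列
def make_labels(segments):
    # segments: [(token_ids, is_answer), ...]
    labels = []
    for ids, is_answer in segments:
        labels += list(ids) if is_answer else [IGNORE_INDEX] * len(ids)
    return labels

labels = make_labels([
    ([101, 7592, 2129], False),  # Human 提问 → 全部 IGNORE
    ([2023, 2003, 102], True),   # Assistant 回答 → 保留真实 id
])
```

图像token的label同样全部置为 IGNORE_INDEX(见后文 prepare_inputs_labels_for_multimodal 源码)。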
原文段落三(两阶段训练):
Stage 1: Pre-training for Feature Alignment. To strike a balance between concept coverage and training efficiency, we filter CC3M to 595K image-text pairs. These pairs are converted to the instruction-following data using the naive expansion method described in Section 3. Each sample can be treated as a single-turn conversation. To construct the input Xinstruct, for an image Xv, a question Xq is randomly sampled, which is a language instruction to request the assistant to describe the image briefly. The ground-truth prediction answer Xa is the original caption. In training, we keep both the visual encoder and LLM weights frozen, and maximize the likelihood with trainable parameters θ = W (the projection matrix) only. In this way, the image features Hv can be aligned with the pre-trained LLM word embedding. This stage can be understood as training a compatible visual tokenizer for the frozen LLM.
解析:
Stage 1:特征对齐预训练
这一阶段只训练投影矩阵W,所有其他参数全部冻结。
- 数据:595K图文对(从CC3M过滤)
- 训练参数:只有W(投影矩阵)
- 学习率:2e-3(较大,因为W是随机初始化的)
- 训练轮次:1 epoch
目的是什么?
让投影矩阵学会把CLIP的视觉特征空间"翻译"到LLM的词嵌入空间。作者把这个阶段比喻为"训练一个与冻结LLM兼容的视觉分词器"(visual tokenizer)。这个类比很精准------就像文本分词器把单词转换成token ID,投影矩阵把视觉特征转换成语言模型能理解的"视觉token"。
原文段落四:
Stage 2: Fine-tuning End-to-End. We always keep the visual encoder weights frozen, and continue to update both the pre-trained weights of the projection layer and LLM in LLaVA; i.e., the trainable parameters are θ = {W, φ}.
We consider two specific use case scenarios:
Multimodal Chatbot. We develop a Chatbot by fine-tuning on the 158K language-image instruction-following data. Among the three types of responses, conversation is multi-turn while the other two are single-turn. They are uniformly sampled in training.
Science QA. We study our method on the ScienceQA benchmark, the first large-scale multimodal science question dataset that annotates the answers with detailed lectures and explanations. Each question is provided a context in the form of natural language or an image. The assistant provides the reasoning process in natural language and selects the answer among multiple choices. For training, we organize the data as a single turn conversation, the question & context as Xinstruct, and reasoning & answer as Xa.
解析:
Stage 2:端到端微调
这一阶段解冻LLM(Vicuna),投影矩阵和LLM一起训练,CLIP视觉编码器继续保持冻结。
- 训练参数:W + LLM(Vicuna)参数
- 学习率:2e-5(较小,精细微调LLM)
- 训练轮次:3 epochs
为什么CLIP始终冻结?
因为CLIP已经在巨量数据上训练好了,它的特征质量很高,在小数据上微调反而可能破坏泛化性。而且冻结CLIP可以大幅减少计算量(CLIP ViT-L/14是很大的模型)。
两阶段训练策略是LLaVA的一个工程亮点:先对齐特征空间,再联合微调,分而治之,效果稳定。
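两阶段的可训练参数与超参可以整理成一张速查表(数值来自论文 Section 5,模块名为示意):

```python
# 两阶段训练配置速查
TRAIN_CFG = {
    1: {"trainable": ("mm_projector",),       "lr": 2e-3, "epochs": 1, "batch_size": 128},
    2: {"trainable": ("mm_projector", "llm"), "lr": 2e-5, "epochs": 3, "batch_size": 32},
}

def is_trainable(module, stage):
    # CLIP 视觉编码器两个阶段都不在 trainable 里 → 始终冻结
    return module in TRAIN_CFG[stage]["trainable"]
```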
Section 5: Experiments(实验)
原文段落一:
We assess the performance of LLaVA in instruction-following and visual reasoning capabilities with two primary experimental settings: multimodal chatbot and the ScienceQA dataset, respectively. We train all models with 8× A100s, following Vicuna's hyperparameters. We pre-train our model on the filtered CC-595K subset for 1 epoch with a learning rate of 2e-3 and a batch size of 128, and fine-tune on the proposed LLaVA-Instruct-158K dataset for 3 epochs, with a learning rate of 2e-5 and a batch size of 32.
解析:
实验配置:8张A100,Stage 1大约4小时,Stage 2大约10小时(当时这在学术界已经是相对平民化的训练配置)。这个结果告诉我们:LLaVA的训练效率非常高,后来开放复现的成本很低,这是它社区影响力大的重要原因之一。
原文段落二(定量评估方法):
To gain a systematic understanding of the performance of LLaVA, we propose a quantitative metric to measure the model's instruction-following capability on multimodal data. Inspired by prior work, we leverage GPT-4 to measure the quality of generated responses. Specifically, we create triplets consisting of image, ground-truth textual descriptions, and question. The candidate models predict the answers based on the question and the image. To provide an approximate theoretical upper bound, we create a reference prediction based on the question and the ground-truth textual descriptions, using the text-only GPT-4.
LLaVA-Bench (COCO). We randomly select 30 images from COCO-Val-2014, and for each image, we generate three types of questions (conversation, detailed description, complex reasoning) using the proposed data generation pipeline, totaling 90 questions. This benchmark studies the model's alignment behavior and capabilities with consistent visual inputs.
LLaVA-Bench (In-the-Wild). To evaluate the model's capability in more challenging tasks and generalizability to novel domains, we collect a diverse set of 24 images with 60 questions in total, including indoor and outdoor scenes, memes, paintings, sketches, etc., and associate each image with a highly-detailed and manually-curated description and a proper selection of questions.
解析:
这里的评估方法很有创意:用text-only GPT-4作为"理论上界"------给GPT-4看图像的文字描述(而不是图像本身),让它回答同样的问题。然后用GPT-4来对比LLaVA的回答和这个"上界"的回答,打分。
这个评估框架有一个隐含假设:text-only GPT-4在看到图像文字描述后的回答质量 ≥ 真正能看图的模型(因为文字描述是GT,信息更完整)。这个假设基本成立,所以这个"理论上界"是合理的。
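摘要里"85.1% relative score"这种相对得分,大致就是候选模型得分与 text-only GPT-4 参考得分的比值(打分细节以论文为准,下面只是计算方式的示意):

```python
# 相对得分:候选模型 GPT-4 评分之和 / 参考上界评分之和 × 100
def relative_score(candidate, reference):
    return 100.0 * sum(candidate) / sum(reference)

score = relative_score([7, 8, 8], [9, 9, 9])  # 三道题的示意评分
```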
原文段落三(消融实验):
Ablations. We ablate several design choices: (i) Visual features: using the last layer feature yields 89.96% and is 0.96% lower than the feature before the last layer. (ii) Chain-of-thought: answer-first reports the best number 89.77% accuracy, while reasoning-first can quickly reach 89.77% accuracy in 6 epochs. (iii) Pre-training: skipping pre-training and directly training on Science QA from scratch -- performance drops to 85.81% accuracy. The 5.11% absolute degradation indicates the importance of our pre-training stage. (iv) Model size: training a 7B model yields 89.84% accuracy, which is 1.08% lower than 90.92%, demonstrating the importance of model scale.
解析:
消融实验给出了几个重要结论:
- 取倒数第二层而不是最后一层的视觉特征(差0.96%):CLIP最后一层的特征经过了更多抽象,可能丢失了一些细节信息,倒数第二层的特征更适合用于视觉理解任务。
- Chain-of-thought顺序:answer-first最终精度最好,reasoning-first收敛更快但最终精度相当,说明CoT主要影响收敛速度而非最终性能。
- Pre-training很重要(差5.11%):跳过Stage 1直接做Stage 2,性能大幅下降。这证明了特征对齐预训练的必要性。
- 模型规模有影响(7B vs 13B差1.08%):更大的LLM确实带来性能提升,但差距不大,说明数据质量可能比模型规模更重要。
Section 6: Conclusion(结论)
原文:
This paper demonstrated the effectiveness of visual instruction tuning. We presented an automatic pipeline to create language-image instruction-following data, based on which we train LLaVA, a multimodal model to follow human intent to complete visual tasks. It achieves the new SoTA accuracy when fine-tuned on ScienceQA, and excellent visual chat capabilities when fine-tuned on multimodal chat data. Besides, we present the first benchmark to study multimodal instruction-following capability. This paper is an initial step in visual instruction tuning, and mainly focuses on real-life tasks. For more quantitative results of LLaVA on academic benchmarks, please refer to the improved baselines with visual instruction tuning. We hope our work can inspire future research on building more capable multimodal models.
解析:
结论部分非常克制,没有过度吹嘘。最后一句"This paper is an initial step"表明作者清楚地认识到这项工作的局限性------LLaVA只是一个起点,架构上有很多值得改进的地方(后来的LLaVA 1.5、LLaVA-Next等版本都对此做了大量改进)。
二、LLaVA 模型网络结构深度解析
基于 LLaVA 官方仓库 haotian-liu/LLaVA 源代码分析
分析版本:LLaVA v1.5(代码库最新版,架构与原始论文一致,Projector升级为MLP)
2.1 整体架构数据流
LLaVA 完整数据流
┌──────────────────────────────────────────────────────────────────┐
│ │
│ 图像输入 文本输入 │
│ images input_ids │
│ [B, 3, 336, 336] [B, L_txt] │
│ (CLIP 336px 输入) (含1个 <IMAGE> 占位符) │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────────────┐ │ │
│ │ Step 1: CLIP ViT-L/14 │ │ │
│ │ 视觉编码器(冻结) │ │ │
│ │ 336×336 → 576个patch │ │ │
│ └────────────┬────────────┘ │ │
│ │ [B, 576, 1024] │ │
│ ▼ │ │
│ ┌─────────────────────────┐ │ │
│ │ Step 2: MLP Projector │ │ │
│ │ (2层线性+GELU) │ │ │
│ │ 1024 → 4096 │ │ │
│ └────────────┬────────────┘ │ │
│ │ [B, 576, 4096] │ │
│ └──────────────┬───────────┘ │
│ ▼ │
│ Step 3: prepare_inputs_labels_for_multimodal() │
│ 把<IMAGE>占位符替换为576个视觉token │
│ 拼接: [System] + [前缀文本] + [576×视觉token] + [后缀]│
│ │ [B, L_total, 4096] │
│ │ L_total = L_txt - 1 + 576 │
│ ▼ │
│ ┌─────────────────────────────┐ │
│ │ Step 4: LLaMA (7B/13B) │ │
│ │ 32层 Transformer Decoder │ │
│ │ hidden_size = 4096 │ │
│ └───────────────┬─────────────┘ │
│ │ [B, L_total, 4096] │
│ ▼ │
│ ┌─────────────────────────────┐ │
│ │ Step 5: lm_head Linear │ │
│ │ 4096 → 32000 (vocab_size) │ │
│ └───────────────┬─────────────┘ │
│ │ [B, L_total, 32000] │
│ ▼ │
│ logits / loss │
└──────────────────────────────────────────────────────────────────┘
2.2 模块一:CLIP 视觉编码器(clip_encoder.py)
源码(核心部分)
python
# 文件:llava/model/multimodal_encoder/clip_encoder.py
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower # "openai/clip-vit-large-patch14-336"
self.select_layer = args.mm_vision_select_layer # 默认 -2(倒数第二层)
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') # 'patch' 表示去掉CLS token
if not delay_load:
self.load_model()
def load_model(self, device_map=None):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False) # ← 冻结所有参数!
self.is_loaded = True
def feature_select(self, image_forward_outs):
# 选取指定层的特征
image_features = image_forward_outs.hidden_states[self.select_layer] # 取倒数第二层
if self.select_feature == 'patch':
image_features = image_features[:, 1:] # 去掉第一个token(CLS token),保留patch tokens
elif self.select_feature == 'cls_patch':
image_features = image_features # 保留CLS token + patch tokens
return image_features
@torch.no_grad() # ← 推理时不计算梯度(编码器冻结)
def forward(self, images):
# images: [B, 3, 336, 336]
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True # 需要所有层的隐藏状态,以便选择特定层
)
# 选择特征层并去掉CLS token
image_features = self.feature_select(image_forward_outs).to(images.dtype)
# image_features: [B, 576, 1024]
return image_features
@property
def hidden_size(self):
return self.config.hidden_size # 1024(ViT-L的隐藏层维度)
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
# = (336 // 14)² = 24² = 576
维度变化详解
输入图像:[B, 3, 336, 336]
↓
将图像切成 patch(patch_size=14):
行方向:336 / 14 = 24 个 patch
列方向:336 / 14 = 24 个 patch
共 24 × 24 = 576 个 patch
↓
加入 CLS token(位置0):总序列长度 = 577
↓
ViT 内部处理(注意力机制,ViT-L 共 24 层 Transformer):
输入:[B, 577, 1024](含CLS)
输出(hidden_states):每层都是 [B, 577, 1024]
↓
取倒数第二层(select_layer = -2):
[B, 577, 1024]
↓
去掉 CLS token(index 0),保留 patch token:
[B, 576, 1024]
↓
最终输出:[B, 576, 1024]
--- 576个视觉特征向量,每个维度为1024
关键设计决策:
- 使用倒数第二层而非最后一层:消融实验显示最后一层的特征更"全局化",丢失了局部细节,倒数第二层特征对下游视觉问答更有利。
- 去掉 CLS token:CLS token是CLIP做对比学习时用于整图语义的全局表示,对于描述图像细节的VQA任务,patch-level的特征更重要。
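576 这个数字来自 patch 划分的简单算术,可以直接验证(纯算术,不依赖 torch):

```python
# CLIP ViT-L/14@336 的 patch 数与序列长度核算
image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2   # 24 × 24
seq_len_with_cls = num_patches + 1              # ViT 内部序列:加上 CLS token
num_visual_tokens = seq_len_with_cls - 1        # select_feature='patch' 去掉 CLS
```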
2.3 模块二:MLP Projector(multimodal_projector/builder.py)
源码
python
# 文件:llava/model/multimodal_projector/builder.py
import torch.nn as nn
import re
class IdentityMap(nn.Module):
"""恒等映射,直接返回输入(用于测试或特殊场景)"""
def forward(self, x, *args, **kwargs):
return x
class SimpleResBlock(nn.Module):
"""残差块(在当前代码中作为备用,未被build_vision_projector调用)"""
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
# ① 单线性层(LLaVA v1 原版)
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
# mm_hidden_size = 1024(CLIP输出)
# hidden_size = 4096(LLaMA隐藏层)
# 参数量 = 1024 × 4096 = 4,194,304 ≈ 400万
# ② MLP(LLaVA v1.5 改进版,默认 mlp2x_gelu)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1)) # mlp2x_gelu → depth=2
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
# 2层MLP结构:Linear(1024→4096) → GELU → Linear(4096→4096)
# ③ 恒等映射(调试用)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
维度变化详解
CLIP 输出:[B, 576, 1024]
↓
Projector 内部(以 mlp2x_gelu 为例):
Layer 1: nn.Linear(1024, 4096)
输入:[B, 576, 1024] → 输出:[B, 576, 4096]
参数量:1024 × 4096 + 4096 = 4,198,400
Activation: GELU
输入:[B, 576, 4096] → 输出:[B, 576, 4096](形状不变)
Layer 2: nn.Linear(4096, 4096)
输入:[B, 576, 4096] → 输出:[B, 576, 4096]
参数量:4096 × 4096 + 4096 = 16,781,312
↓
Projector 输出:[B, 576, 4096]
--- 576个视觉token,每个维度为4096(与LLaMA的词嵌入维度相同)
原版(linear)vs 改进版(mlp2x_gelu):
| 版本 | 结构 | 参数量 | ScienceQA准确率 |
|---|---|---|---|
| LLaVA v1(论文原版) | Linear(1024→4096) | ~400万 | 90.92% |
| LLaVA v1.5 | Linear+GELU+Linear | ~2100万 | 更高(LLaVA-1.5 论文报告 MLP 投影整体优于单线性层) |
改用2层MLP的核心原因:单线性层的表达能力不足,很难把1024维的视觉特征完美映射到4096维的语言空间。加一个非线性激活函数后,投影层能学到更复杂的视觉-语言对齐关系。
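表格里的参数量可以按"权重 + bias"逐层核算(与源码注释中的数字一致):

```python
# linear vs mlp2x_gelu 投影层参数量核算
def linear_params(d_in, d_out):
    return d_in * d_out + d_out  # 权重矩阵 + bias

linear_v1 = linear_params(1024, 4096)                              # v1 单线性层
mlp_v15 = linear_params(1024, 4096) + linear_params(4096, 4096)    # v1.5 两层 MLP
```

约 2100 万参数相对 7B/13B 的 LLM 依然可以忽略不计,所以换成 MLP 几乎不增加训练成本。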
2.4 模块三:多模态 Token 拼接(llava_arch.py 核心函数)
这是整个LLaVA中最复杂、最关键的函数,负责把视觉token和文本token拼接成一个统一的序列。
encode_images(图像编码入口)
python
# 文件:llava/model/llava_arch.py
def encode_images(self, images):
# 调用CLIP视觉编码器
image_features = self.get_model().get_vision_tower()(images)
# image_features: [B, 576, 1024]
# 调用MLP Projector
image_features = self.get_model().mm_projector(image_features)
# image_features: [B, 576, 4096]
return image_features
prepare_inputs_labels_for_multimodal(核心融合函数)
这个函数实现了"把 <IMAGE> 占位符替换为真实视觉 token"的逻辑。
python
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
images, image_sizes=None
):
# ── 输入 ──
# input_ids: [B, L_txt],其中某个位置是 IMAGE_TOKEN_INDEX(特殊占位符,值为-200)
# attention_mask: [B, L_txt]
# labels: [B, L_txt],Human部分是IGNORE_INDEX,Assistant部分是真实token id
# images: [B, 3, 336, 336]
vision_tower = self.get_vision_tower()
# --- Step A: 图像编码 ---
# 如果图像是列表(多图),先拼起来统一编码,再拆分
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images) # [total_images, 576, 4096]
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
# image_features: list of [n_images_per_sample, 576, 4096]
        # flat模式:把多个图像的patch token展平成一个序列
        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
        if mm_patch_merge_type == 'flat':
image_features = [x.flatten(0, 1) for x in image_features]
# 每个: [576*n_images, 4096]
else:
# 单图情况
image_features = self.encode_images(images) # [B, 576, 4096]
# --- Step B: 去除padding,逐样本处理 ---
# 用attention_mask去掉padding token
input_ids = [cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0

    # --- Step C: replace the <IMAGE> placeholder, sample by sample ---
    for batch_idx, cur_input_ids in enumerate(input_ids):
        cur_labels = labels[batch_idx]
        # find every position of IMAGE_TOKEN_INDEX (value -200)
        image_token_indices = [-1] + \
            torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + \
            [cur_input_ids.shape[0]]
        # example: input_ids = [w1, w2, -200, w3, w4]
        # image_token_indices = [-1, 2, 5]  (-1 is a virtual start, 5 the end)

        # split the text into the segments before/after each image token
        cur_input_ids_noim = []  # text segments with the IMAGE_TOKEN removed
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            # slice out the text between two adjacent IMAGE_TOKENs
            cur_input_ids_noim.append(
                cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
            )
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
            )
        # cur_input_ids_noim = [[w1, w2], [w3, w4]]

        # concatenate all text segments and embed them in a single call
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        # cur_input_embeds: [L_text_total, 4096]
        # then split back into the individual segments
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        # cur_input_embeds_no_im = [[L1, 4096], [L2, 4096]]

        # --- Step D: insert the visual tokens ---
        cur_new_input_embeds = []
        cur_new_labels = []
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        for i in range(num_images + 1):
            # first append the text segment
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                # then insert the visual tokens
                cur_image_features = image_features[cur_image_idx]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                # set every image-token label to IGNORE_INDEX (no loss on them)
                cur_new_labels.append(
                    torch.full(
                        (cur_image_features.shape[0],),  # 576
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype
                    )
                )

        # concatenate into the full sequence
        # example: [text-prefix embeddings] + [576 visual tokens] + [text-suffix embeddings]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)  # [L_total, 4096]
        cur_new_labels = torch.cat(cur_new_labels)              # [L_total]
        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # --- Step E: pad to the longest sequence in the batch ---
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    # allocate the padded tensors
    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, ...)
    attention_mask = torch.zeros((batch_size, max_len), ...)
    position_ids = torch.zeros((batch_size, max_len), ...)
    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        # right padding (the default) or left padding
        new_input_embeds_padded.append(
            torch.cat([cur_new_embed, torch.zeros((max_len - cur_len, ...))])
        )
        new_labels_padded[i, :cur_len] = cur_new_labels
        attention_mask[i, :cur_len] = True
        position_ids[i, :cur_len] = torch.arange(0, cur_len)
    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # new_input_embeds: [B, max_len, 4096] ← the final input fed into LLaMA

    # ── outputs ──
    # input_ids:      None (replaced by inputs_embeds)
    # position_ids:   [B, max_len]
    # attention_mask: [B, max_len]
    # inputs_embeds:  [B, max_len, 4096]
    # labels:         [B, max_len] (IGNORE_INDEX at image positions, excluded from the loss)
    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
Token concatenation diagram
Original input (with the placeholder):
input_ids = [101, 2023, 2003, -200, 1037, 3307, ...]
             SOS  "This" "is" <IMG>  "a"  "cat" ...
             └── text tokens ──┘ ↑ └── text tokens ──┘
                       IMAGE_TOKEN_INDEX
After prepare_inputs_labels_for_multimodal():
inputs_embeds = [
    embed("This"),        # [4096] ← text prefix
    embed("is"),          # [4096]
    proj(clip_feat_1),    # [4096] ← visual token 1
    proj(clip_feat_2),    # [4096] ← visual token 2
    ... (576 visual tokens in total) ...
    proj(clip_feat_576),  # [4096] ← visual token 576
    embed("a"),           # [4096] ← text suffix
    embed("cat"),         # [4096]
    ...
]
# overall shape: [B, L_txt - 1 + 576, 4096]
The corresponding labels (at training time):
labels = [
    IGNORE,   # "This" (human input, no loss)
    IGNORE,   # "is"
    IGNORE,   # visual token 1 (no loss)
    ... (576 IGNOREs in total) ...
    IGNORE,   # visual token 576
    IGNORE,   # "a" (human input, no loss)
    IGNORE,   # "cat"
    tok_id("Response starts here"),  # ← the loss is computed from here on
    tok_id("..."),
    ...
]
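The splice-and-mask mechanism above can be sketched with plain Python lists (a simplified model: the hypothetical `splice_image_tokens` helper uses tagged tuples as stand-ins for embeddings, while the real code operates on tensors):

```python
IMAGE_TOKEN_INDEX = -200  # placeholder value used by LLaVA
IGNORE_INDEX = -100       # label value excluded from the loss

def splice_image_tokens(input_ids, labels, image_features):
    """Replace each IMAGE_TOKEN_INDEX in input_ids with a run of visual
    tokens, mirroring prepare_inputs_labels_for_multimodal on lists."""
    new_seq, new_labels, img_idx = [], [], 0
    for tok, lab in zip(input_ids, labels):
        if tok == IMAGE_TOKEN_INDEX:
            feats = image_features[img_idx]
            img_idx += 1
            new_seq.extend(feats)                           # insert visual tokens
            new_labels.extend([IGNORE_INDEX] * len(feats))  # no loss on them
        else:
            new_seq.append(("embed", tok))                  # stand-in for embed_tokens
            new_labels.append(lab)
    return new_seq, new_labels

# [w1, w2, <IMG>, w3] with a 3-token "image":
seq, labs = splice_image_tokens(
    [101, 2023, -200, 1037],
    [-100, -100, -100, 1037],
    [[("vis", 0), ("vis", 1), ("vis", 2)]],
)
# len(seq) == 4 - 1 + 3 == 6; only the final label survives as a loss target
```

This makes the length bookkeeping obvious: one placeholder token is traded for `len(feats)` visual tokens, and the labels list grows in lockstep so positions stay aligned.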
2.5 Module 4: LLaMA Language Decoder (llava_llama.py)
Class inheritance
python
# file: llava/model/language_model/llava_llama.py
# LLaVA model trunk (multiple inheritance)
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    # LlavaMetaModel: initializes the vision encoder and the projector
    # LlamaModel: the standard LLaMA Transformer trunk (32 layers)
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)

# the full LLaVA CausalLM (with lm_head)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    # LlamaForCausalLM: provides lm_head and the loss computation
    # LlavaMetaForCausalLM: handles the multimodal input processing
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)  # LLaMA trunk + vision modules
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # lm_head: Linear(4096, 32000)
        self.post_init()

    def get_model(self):
        return self.model
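The cooperative `super().__init__` chain above relies on Python's method resolution order (MRO); a minimal stand-alone sketch (the `*Stub` classes are hypothetical stand-ins for the real transformers/LLaVA classes):

```python
class LlamaModelStub:                 # stands in for transformers' LlamaModel
    def __init__(self, config):
        self.config = config

class LlavaMetaModelStub:             # stands in for LlavaMetaModel
    def __init__(self, config):
        super().__init__(config)      # cooperative: forwards to the next class in the MRO
        self.vision_tower = "clip-vit-large-patch14-336"

class LlavaLlamaModelStub(LlavaMetaModelStub, LlamaModelStub):
    pass

# MRO: LlavaLlamaModelStub -> LlavaMetaModelStub -> LlamaModelStub -> object,
# so the vision-side __init__ runs first, then the LLaMA-side __init__.
m = LlavaLlamaModelStub(config={"hidden_size": 4096})
```

The same ordering in the real classes is what lets `LlavaMetaModel` bolt the vision tower and projector onto an otherwise untouched `LlamaModel` without either parent knowing about the other.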
The forward function (the complete flow)
python
def forward(
    self,
    input_ids: torch.LongTensor = None,               # text token IDs [B, L]
    attention_mask: Optional[torch.Tensor] = None,    # [B, L]
    position_ids: Optional[torch.LongTensor] = None,  # [B, L]
    past_key_values: Optional[List[torch.FloatTensor]] = None,  # KV cache for fast inference
    inputs_embeds: Optional[torch.FloatTensor] = None,  # direct embedding input
    labels: Optional[torch.LongTensor] = None,          # [B, L]
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,       # ⭐ image input [B, 3, 336, 336]
    image_sizes: Optional[List[List[int]]] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    # ── multimodal preprocessing stage ──
    if inputs_embeds is None:
        (
            input_ids,        # → None (superseded by inputs_embeds)
            position_ids,     # → [B, L_total]
            attention_mask,   # → [B, L_total]
            past_key_values,  # → unchanged
            inputs_embeds,    # → [B, L_total, 4096] ⭐ the key output
            labels            # → [B, L_total]
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids, position_ids, attention_mask,
            past_key_values, labels, images, image_sizes
        )
    # at this point inputs_embeds is the full sequence with the visual tokens fused in

    # ── standard LLaMA forward pass ──
    return super().forward(
        input_ids=input_ids,            # None (inputs_embeds is used instead)
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,    # [B, L_total, 4096], fed into LLaMA
        labels=labels,                  # [B, L_total]; loss only on the assistant's answer
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # internally:
    # the 32 LLaMA Transformer decoder layers process inputs_embeds
    # → hidden_states: [B, L_total, 4096]
    # → lm_head(hidden_states) → logits: [B, L_total, 32000]
    # → CrossEntropyLoss computed over the labeled positions
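The "loss only on the answer" behavior falls out of cross-entropy's ignore index combined with the standard next-token shift (logits at position t predict the label at t+1). A toy pure-Python sketch, no torch (the `causal_lm_loss` helper is hypothetical):

```python
import math

IGNORE_INDEX = -100

def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy over one sequence, skipping every
    position whose *target* label is IGNORE_INDEX."""
    losses = []
    for t in range(len(labels) - 1):
        target = labels[t + 1]          # the shift: logits[t] predicts labels[t+1]
        if target == IGNORE_INDEX:
            continue                    # image tokens / human turns contribute nothing
        row = logits[t]
        log_z = math.log(sum(math.exp(x) for x in row))
        losses.append(log_z - row[target])   # -log softmax(row)[target]
    return sum(losses) / len(losses)

# 4 positions, vocab of 3; only the last label is a real target
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, 1.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 2]
loss = causal_lm_loss(logits, labels)   # only position t=2 contributes
```

This mirrors why `prepare_inputs_labels_for_multimodal` writes `IGNORE_INDEX` over the 576 visual positions and the human turns: those rows simply never enter the averaged loss.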
2.6 Full forward-pass shape table
| Step | Operation | Input shape | Output shape | Notes |
|---|---|---|---|---|
| 0 | Image preprocessing | PIL Image | [B, 3, 336, 336] | CLIP ImageProcessor resize + normalize |
| 1 | CLIP ViT-L/14 | [B, 3, 336, 336] | [B, 577, 1024] | CLS token + 576 patch tokens |
| 2 | Drop CLS token | [B, 577, 1024] | [B, 576, 1024] | select_layer=-2, select_feature='patch' |
| 3 | Projector (linear) | [B, 576, 1024] | [B, 576, 4096] | original LLaVA: a single linear layer |
| 3' | Projector (mlp2x) | [B, 576, 1024] | [B, 576, 4096] | v1.5: Linear→GELU→Linear |
| 4 | Text embed_tokens | [B, L_txt] | [B, L_txt, 4096] | LLaMA word embedding (32000-token vocab) |
| 5 | Token splicing | [B, L_txt, 4096] + [B, 576, 4096] | [B, L_total, 4096] | L_total = L_txt - 1 + 576 |
| 6 | LLaMA decoder (32 layers) | [B, L_total, 4096] | [B, L_total, 4096] | standard causal language model |
| 7 | lm_head | [B, L_total, 4096] | [B, L_total, 32000] | vocabulary logits |
| 8 | CrossEntropyLoss | [B, L_total, 32000] | scalar | computed only on the assistant's answer |
Typical numbers (one image, a 10-token question, a 50-token answer):
L_txt = 1 (system) + 10 (question) + 1 (IMAGE) + 50 (answer) + special tokens ≈ 70
L_total = 70 - 1 (IMAGE placeholder removed) + 576 (visual tokens) = 645
Sequence length fed into LLaMA: 645 tokens,
of which 576 are visual tokens and ~69 are text tokens.
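The numbers above follow directly from the ViT patch grid; a quick arithmetic check (hypothetical helper names):

```python
def num_visual_tokens(image_size=336, patch_size=14):
    # ViT-L/14 at 336px: (336 / 14)^2 = 24 * 24 = 576 patch tokens (CLS dropped)
    return (image_size // patch_size) ** 2

def total_seq_len(text_len, num_images=1, tokens_per_image=576):
    # each <IMAGE> placeholder (1 token) is replaced by 576 visual tokens
    return text_len - num_images + num_images * tokens_per_image

assert num_visual_tokens() == 576
assert total_seq_len(70) == 645
```

The same formula explains the context-length pressure of multi-image prompts: every extra image costs 575 net tokens of the LLM's window.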
2.7 Parameter-freezing strategy by training stage
| Module | Stage 1 | Stage 2 | Parameters |
|---|---|---|---|
| CLIP ViT-L/14 (vision encoder) | ❄️ frozen | ❄️ frozen | ~307M |
| MLP projector | 🔥 trained | 🔥 trained | ~4M (linear) / ~21M (mlp2x) |
| LLaMA 7B (language decoder) | ❄️ frozen | 🔥 trained | ~7B |
| Total trainable | ~4M | ~7B | |
Stage 1 trains only about 4M parameters (the projection matrix), so it is fast (~4 hours); Stage 2 unfreezes the entire 7B-parameter LLM, and training takes about 10 hours.
2.8 Inference flow (the generate function)
python
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,       # input_ids [B, L]
    images: Optional[torch.Tensor] = None,       # [B, 3, 336, 336]
    image_sizes: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    position_ids = kwargs.pop("position_ids", None)
    attention_mask = kwargs.pop("attention_mask", None)
    # the same multimodal preprocessing as at training time
    if images is not None:
        (
            inputs, position_ids, attention_mask,
            _, inputs_embeds, _
        ) = self.prepare_inputs_labels_for_multimodal(
            inputs, position_ids, attention_mask,
            None, None, images, image_sizes=image_sizes
        )
    else:
        # text-only mode
        inputs_embeds = self.get_model().embed_tokens(inputs)

    # call the standard LLaMA autoregressive generate
    return super().generate(
        position_ids=position_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs  # generation parameters: max_new_tokens, temperature, top_p, etc.
    )
Inference follows exactly the same path as training (prepare_inputs_labels_for_multimodal is the same function); the only difference is that the result is handed to generate for autoregressive decoding instead of a loss computation.
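The forward/generate split can be illustrated with a toy greedy decoding loop: one forward pass per new token, appending the argmax each time (the `greedy_generate` helper and its `next_logits` callback are hypothetical, not HuggingFace's `generate`):

```python
def greedy_generate(prompt_ids, next_logits, max_new_tokens=5, eos_id=0):
    """Toy autoregressive decoding: repeatedly take the argmax of the
    next-token distribution and append it to the sequence."""
    seq = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_logits(seq)                 # one forward pass per step
        nxt = max(range(len(logits)), key=logits.__getitem__)
        seq.append(nxt)
        if nxt == eos_id:                         # stop early at end-of-sequence
            break
    return seq

# toy "model" over a 4-token vocab: always predicts (last token + 1) mod 4
out = greedy_generate(
    [1],
    lambda s: [1.0 if i == (s[-1] + 1) % 4 else 0.0 for i in range(4)],
    max_new_tokens=3,
)
# 1 -> 2 -> 3 -> 0 (eos): out == [1, 2, 3, 0]
```

In LLaVA the 576 visual tokens sit in the prompt prefix only; each decoding step then behaves like this loop (with a KV cache so the prefix is not recomputed).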
Summary
LLaVA's core design philosophy
LLaVA's architecture succeeds because of its "minimal connector, data-driven" approach:
- Vision encoder (CLIP ViT-L/14): turns the image into 576 high-quality feature vectors; it stays frozen throughout and never needs retraining.
- Projection layer (Linear/MLP): only a few million to a few tens of millions of parameters, acting as a "translator" that maps the vision feature space into the language space.
- LLM (Vicuna/LLaMA): understands the joint image-text context and generates the answer; it is the main source of the model's capability.
- Data format: by converting image-text pairs into the instruction format "Human: [text + visual tokens] \n Assistant: [answer]", the LLM learns multimodal dialogue.
The whole model can be read as: CLIP is the eyes, the projection is the nerve connecting them, and the LLM is the brain. Visual information is "translated" by the projector into a form the LLM can consume, and the LLM then applies its language understanding and reasoning to answer the question.