【BLIP论文阅读】:统一视觉语言理解与生成的自举式预训练范式

论文信息


引言:多模态模型的"偏科"与"挑食"难题

想象一下,你有两个AI助手:一个能精准判断"这张图里有没有猫",但让它写一句描述图片的话就只会说"a cat";另一个能写出优美的图片描述,但让它从100张图里找出猫的图片就会频频出错。这就是2022年之前多模态预训练(VLP)模型的真实写照------要么擅长理解类任务(如CLIP、ALBEF),要么擅长生成类任务(如SimVLM),没有一个能做到全能

更糟糕的是,这些模型都有严重的"挑食"问题:它们只能吃"干净"的人工标注数据,但人工标注太贵了,所以大家只能去网上爬海量的图文对。可这些网页数据就像路边摊的小吃,虽然量大但满是"地沟油"------很多文本根本不描述图片内容,比如一张风景图配文"今天心情真好"。之前的方法只能用简单的规则过滤掉最明显的垃圾,剩下的噪声还是会严重影响模型性能。

BLIP的作者们一次性解决了这两个难题:

  1. 提出了多模态混合编码器-解码器(MED)架构,让同一个模型既能做理解又能做生成,彻底告别偏科
  2. 发明了字幕生成与过滤(CapFilt)技术,能把网上的噪声数据"提纯"成高质量训练数据,让模型吃得香又长得壮

一、MED架构:一个模型打天下

BLIP的核心是多模态混合编码器-解码器(Multimodal Mixture of Encoder-Decoder, MED) ,它就像一个变形金刚,能根据不同任务切换成三种不同的工作模式。

图片1:BLIP预训练架构与目标

出处:论文图2

1.1 三种工作模式

模式1:单模态编码器(Unimodal Encoder)
  • 功能:分别编码图像和文本,让它们的特征空间对齐
  • 输入 :图像 + 文本(开头加[CLS] token)
  • 训练目标:图像-文本对比损失(ITC)
  • 通俗解释:就像让两个说不同语言的人先学一些基础词汇,知道"cat"对应猫的图片,"dog"对应狗的图片
模式2:图像引导的文本编码器(Image-grounded Text Encoder)
  • 功能:融合图像和文本信息,学习细粒度的跨模态交互
  • 输入 :图像 + 文本(开头加[Encode] token)
  • 特殊设计 :在文本编码器的每个Transformer块里加了一个交叉注意力层,让文本token能"看到"图像token
  • 训练目标:图像-文本匹配损失(ITM)
  • 通俗解释:就像让两个人坐下来聊天,不仅知道对方说的词是什么意思,还能理解整句话的含义
模式3:图像引导的文本解码器(Image-grounded Text Decoder)
  • 功能:根据图像生成文本描述
  • 输入 :图像 + 文本(开头加[Decode] token)
  • 特殊设计 :把双向自注意力换成了因果自注意力(只能看到前面的token,不能看到后面的),这样才能自回归生成文本
  • 训练目标:语言建模损失(LM)
  • 通俗解释:就像让一个人看着图片讲故事,只能一个字一个字地说,不能提前剧透后面的内容

1.2 三大预训练目标

BLIP同时训练三个目标,总损失是三个损失的和:
L=Litc+Litm+Llm\mathcal{L} = \mathcal{L}{itc} + \mathcal{L}{itm} + \mathcal{L}_{lm}L=Litc+Litm+Llm

目标1:图像-文本对比损失(ITC)

ITC损失的目标是让匹配的图文对在特征空间里靠得更近,不匹配的离得更远。它借鉴了ALBEF的动量编码器思想,用动量模型生成软标签来处理潜在的正样本。

公式1:ITC损失
Litc=12E(I,T)∼D[H(yi2t,pi2t)+H(yt2i,pt2i)]\mathcal{L}{itc} = \frac{1}{2} \mathbb{E}{(I, T) \sim D} \left[ H(y^{i2t}, p^{i2t}) + H(y^{t2i}, p^{t2i}) \right]Litc=21E(I,T)∼D[H(yi2t,pi2t)+H(yt2i,pt2i)]

  • Litc\mathcal{L}_{itc}Litc:ITC总损失,是图像到文本和文本到图像损失的平均值
  • HHH:交叉熵损失函数
  • yi2ty^{i2t}yi2t:图像III对应的真实文本的one-hot标签
  • pi2tp^{i2t}pi2t:模型预测的图像III对应各个文本的概率分布
  • yt2iy^{t2i}yt2i:文本TTT对应的真实图像的one-hot标签
  • pt2ip^{t2i}pt2i:模型预测的文本TTT对应各个图像的概率分布
目标2:图像-文本匹配损失(ITM)

ITM损失是一个二分类任务,让模型判断一个图文对是匹配的还是不匹配的。BLIP用了硬负样本挖掘策略:选择那些语义相似但不匹配的负样本,这样训练出来的模型更鲁棒。

公式2:ITM损失
Litm=E(I,T)∼DH(yitm,pitm(I,T))\mathcal{L}{itm} = \mathbb{E}{(I, T) \sim D} H(y^{itm}, p^{itm}(I, T))Litm=E(I,T)∼DH(yitm,pitm(I,T))

  • Litm\mathcal{L}_{itm}Litm:ITM损失
  • yitmy^{itm}yitm:图文对的真实标签(匹配为1,不匹配为0)
  • pitm(I,T)p^{itm}(I, T)pitm(I,T):模型预测的图文对匹配的概率
目标3:语言建模损失(LM)

LM损失让模型学会根据图像生成文本描述。它优化的是交叉熵损失,训练模型最大化真实文本的似然概率。

公式3:LM损失
Llm=E(I,T)∼DH(ylm,plm(I,T))\mathcal{L}{lm} = \mathbb{E}{(I, T) \sim D} H(y^{lm}, p^{lm}(I, T))Llm=E(I,T)∼DH(ylm,plm(I,T))

  • Llm\mathcal{L}_{lm}Llm:LM损失
  • ylmy^{lm}ylm:真实文本的token序列
  • plm(I,T)p^{lm}(I, T)plm(I,T):模型自回归生成的token概率分布

1.3 参数共享策略

为了提高训练效率,BLIP的文本编码器和解码器共享除了自注意力层之外的所有参数

  • 共享的参数:嵌入层、交叉注意力层、前馈网络
  • 不共享的参数:自注意力层(编码器用双向,解码器用因果)

作者们通过实验证明,这种共享策略既能减少模型参数(从361M降到224M),又能提高性能。如果连自注意力层都共享,模型会因为编码和解码任务的冲突而性能下降。


二、CapFilt:把噪声数据变成黄金

网上爬来的图文对虽然量大,但质量参差不齐。BLIP的作者们发明了字幕生成与过滤 (Captioning and Filtering, CapFilt)技术,能把这些噪声数据"提纯"成高质量的训练数据。

图片2:CapFilt整体流程

出处:论文图3

2.1 CapFilt的工作流程

CapFilt分为三个步骤:

  1. 预训练基础模型:用原始的噪声图文对预训练一个MED模型
  2. 微调两个专家
    • Captioner(字幕生成器):用LM损失在高质量人工标注数据集(如COCO)上微调,学会给图片生成准确的描述
    • Filter(过滤器):用ITC和ITM损失在COCO上微调,学会判断一个文本是否准确描述了图片
  3. 提纯数据集
    • 用Captioner给所有网页图片生成合成字幕TsT_sTs
    • 用Filter过滤掉原始网页字幕TwT_wTw和合成字幕TsT_sTs中的噪声(ITM头预测为不匹配的)
    • 把过滤后的干净字幕和人工标注字幕合并,得到新的训练数据集
  4. 重新训练模型:用提纯后的数据集重新预训练一个MED模型,得到最终的BLIP模型

2.2 多样性比准确性更重要

在生成合成字幕时,作者们发现核采样(Nucleus Sampling)比束搜索(Beam Search)效果更好,尽管核采样生成的字幕噪声更多。

表格1:生成方法对比

生成方法 噪声比例 Retrieval-FT (COCO) TR@1 Caption-FT (COCO) CIDEr
Beam 19% 79.6 127.8
Nucleus 25% 80.6 129.7
出处:论文表2

分析:核采样生成的字幕更多样化,包含了更多新的语义信息,而束搜索倾向于生成安全、常见的字幕。就像老师教学生,不仅要教标准答案,还要教不同的解题思路,这样学生才能举一反三。

2.3 为什么要解耦Captioner和Filter

作者们还发现,Captioner和Filter不能共享参数,否则会出现"确认偏差"------Captioner生成的噪声字幕会被Filter错误地保留下来。

表格2:参数共享对比

Captioner & Filter 噪声比例 Retrieval-FT (COCO) TR@1
共享参数 8% 79.8
解耦 25% 80.6
出处:论文表4

分析:解耦后,Filter能更客观地评价Captioner生成的字幕,就像审稿人不能是作者本人一样,这样才能保证评审的公正性。


三、实验结果:用数据说话

BLIP在多个下游任务上都取得了state-of-the-art的结果,而且只用了比之前模型少得多的数据。

3.1 CapFilt真的有用吗?

表格3:CapFilt效果对比

C F 预训练数据集 Retrieval-FT (COCO) TR@1 Caption-FT (COCO) CIDEr
原始14M 78.4 127.8
原始+合成 79.1 128.2
过滤后原始 79.7 128.9
提纯后14M 80.6 129.7
出处:论文表1

分析

  • 只用Captioner(生成合成字幕),性能提升0.7%
  • 只用Filter(过滤原始字幕),性能提升1.3%
  • 两者一起用,性能提升2.2%
  • 当预训练数据增加到129M时,性能还能进一步提升到81.9%

这说明CapFilt的两个模块是互补的,生成和过滤缺一不可。

3.2 图文检索:用1/30的数据超越CLIP

表格4:零样本图文检索结果(Flickr30K)

Method 预训练图片数 TR R@1 IR R@1
CLIP 400M 88.0 68.7
ALIGN 1.8B 88.6 75.7
ALBEF 14M 94.1 82.8
BLIP 14M 94.8 84.9
BLIP 129M 96.0 85.0
出处:论文表6

分析:BLIP只用了14M预训练图片,就超过了用400M图片训练的CLIP和用1.8B图片训练的ALIGN!这说明BLIP的学习效率比之前的模型高得多,CapFilt功不可没。

3.3 其他下游任务:全面领先

表格5:VQA和NLVR2结果对比

Method 预训练图片数 VQA test-std NLVR2 test-P
ALBEF 14M 76.04 83.14
BLIP 14M 77.62 82.30
BLIP 129M 78.17 83.08
出处:论文表8

分析:在需要复杂推理的VQA和NLVR2任务上,BLIP也全面超越了之前的state-of-the-art方法ALBEF。而且BLIP不需要目标检测器,推理速度比需要检测器的方法快10倍以上。

3.4 零样本视频任务:惊人的泛化能力

BLIP不仅在图像任务上表现出色,还能直接零样本迁移到视频任务上!

表格6:零样本文本到视频检索结果(MSRVTT)

Method R@1 R@5 R@10
VideoCLIP 8.7 22.2 30.0
MIL-NCE 10.4 23.0 32.4
BLIP (零样本) 19.2 43.3 51.6
ClipBERT (微调) 22.0 46.8 59.9
出处:论文表10

分析:BLIP只用了图像数据预训练,零样本迁移到视频任务上就超过了很多专门在视频上微调的模型!这说明BLIP学到的视觉语言表征具有很强的泛化能力。


四、核心代码实现

下面是BLIP核心部分的代码实现,包括MED架构和三个预训练损失:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, ViTModel, ViTFeatureExtractor

class BLIP(nn.Module):
    def __init__(self, vit_name='google/vit-base-patch16-224', bert_name='bert-base-uncased', embed_dim=256):
        super().__init__()
        # 图像编码器
        self.vision_encoder = ViTModel.from_pretrained(vit_name)
        # 文本编码器/解码器(共享参数)
        self.text_encoder = BertModel.from_pretrained(bert_name, add_pooling_layer=False)
        self.text_decoder = BertModel.from_pretrained(bert_name, add_pooling_layer=False)
        
        # 投影层,用于对比学习
        self.vision_proj = nn.Linear(self.vision_encoder.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, embed_dim)
        
        # ITM分类头
        self.itm_head = nn.Linear(self.text_encoder.config.hidden_size, 2)
        
        # LM头
        self.lm_head = nn.Linear(self.text_decoder.config.hidden_size, self.text_decoder.config.vocab_size)
        
        # 温度参数
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0 / 0.07)))
        
        # 初始化解码器参数(共享除了自注意力之外的层)
        self._init_decoder()

    def _init_decoder(self):
        """初始化解码器,共享除了自注意力之外的参数"""
        for name, param in self.text_encoder.named_parameters():
            if 'attention.self' not in name:
                self.text_decoder.state_dict()[name].copy_(param.data)
        
        # 冻结共享参数的梯度
        for name, param in self.text_decoder.named_parameters():
            if 'attention.self' not in name:
                param.requires_grad = False

    def forward_vision(self, images):
        """编码图像"""
        vision_outputs = self.vision_encoder(images)
        vision_embeds = vision_outputs.last_hidden_state
        vision_cls = vision_embeds[:, 0, :]
        return vision_embeds, vision_cls

    def forward_text_encoder(self, input_ids, attention_mask, vision_embeds):
        """图像引导的文本编码器"""
        # 文本自注意力
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        text_embeds = text_outputs.last_hidden_state
        
        # 交叉注意力(简化版,实际实现需要更复杂的交叉注意力层)
        cross_attn_output = torch.bmm(text_embeds, vision_embeds.transpose(1, 2))
        cross_attn_output = F.softmax(cross_attn_output, dim=-1)
        cross_attn_output = torch.bmm(cross_attn_output, vision_embeds)
        text_embeds = text_embeds + cross_attn_output
        
        return text_embeds

    def forward_text_decoder(self, input_ids, attention_mask, vision_embeds):
        """图像引导的文本解码器(因果自注意力)"""
        # 因果自注意力掩码
        seq_length = input_ids.size(1)
        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=input_ids.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        
        # 文本因果自注意力
        text_outputs = self.text_decoder(
            input_ids=input_ids,
            attention_mask=causal_mask,
            output_hidden_states=True
        )
        text_embeds = text_outputs.last_hidden_state
        
        # 交叉注意力
        cross_attn_output = torch.bmm(text_embeds, vision_embeds.transpose(1, 2))
        cross_attn_output = F.softmax(cross_attn_output, dim=-1)
        cross_attn_output = torch.bmm(cross_attn_output, vision_embeds)
        text_embeds = text_embeds + cross_attn_output
        
        return text_embeds

    def forward(self, images, input_ids, attention_mask, labels=None):
        # 编码图像
        vision_embeds, vision_cls = self.forward_vision(images)
        vision_proj = F.normalize(self.vision_proj(vision_cls), dim=-1)
        
        # 编码文本(单模态)
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_cls = text_outputs.last_hidden_state[:, 0, :]
        text_proj = F.normalize(self.text_proj(text_cls), dim=-1)
        
        # 计算ITC损失
        loss_itc = self.itc_loss(vision_proj, text_proj)
        
        # 计算ITM损失
        loss_itm = self.itm_loss(vision_embeds, input_ids, attention_mask)
        
        # 计算LM损失
        loss_lm = 0.0
        if labels is not None:
            decoder_embeds = self.forward_text_decoder(input_ids, attention_mask, vision_embeds)
            logits = self.lm_head(decoder_embeds)
            loss_lm = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        # 总损失
        loss = loss_itc + loss_itm + loss_lm
        
        return loss

    def itc_loss(self, vision_proj, text_proj):
        """计算图像-文本对比损失"""
        logits_i2t = vision_proj @ text_proj.T * self.logit_scale.exp()
        logits_t2i = text_proj @ vision_proj.T * self.logit_scale.exp()
        
        labels = torch.arange(logits_i2t.size(0), device=logits_i2t.device)
        
        loss_i2t = F.cross_entropy(logits_i2t, labels)
        loss_t2i = F.cross_entropy(logits_t2i, labels)
        
        return (loss_i2t + loss_t2i) / 2

    def itm_loss(self, vision_embeds, input_ids, attention_mask):
        """计算图像-文本匹配损失(带硬负样本挖掘)"""
        batch_size = input_ids.size(0)
        
        # 生成硬负样本
        with torch.no_grad():
            vision_proj = F.normalize(self.vision_proj(vision_embeds[:, 0, :]), dim=-1)
            text_proj = F.normalize(self.text_proj(self.text_encoder(input_ids, attention_mask).last_hidden_state[:, 0, :]), dim=-1)
            sim_i2t = vision_proj @ text_proj.T
            sim_t2i = text_proj @ vision_proj.T
            
            # 为每个图像选择一个硬负文本
            weights_i2t = F.softmax(sim_i2t, dim=1)
            weights_i2t.fill_diagonal_(0)
            neg_text_ids = torch.multinomial(weights_i2t, 1).squeeze()
            
            # 为每个文本选择一个硬负图像
            weights_t2i = F.softmax(sim_t2i, dim=1)
            weights_t2i.fill_diagonal_(0)
            neg_image_ids = torch.multinomial(weights_t2i, 1).squeeze()
        
        # 构造正负样本
        pos_vision_embeds = vision_embeds
        neg_vision_embeds = vision_embeds[neg_image_ids]
        all_vision_embeds = torch.cat([pos_vision_embeds, neg_vision_embeds], dim=0)
        
        pos_input_ids = input_ids
        neg_input_ids = input_ids[neg_text_ids]
        all_input_ids = torch.cat([pos_input_ids, neg_input_ids], dim=0)
        
        pos_attention_mask = attention_mask
        neg_attention_mask = attention_mask[neg_text_ids]
        all_attention_mask = torch.cat([pos_attention_mask, neg_attention_mask], dim=0)
        
        # 编码
        text_embeds = self.forward_text_encoder(all_input_ids, all_attention_mask, all_vision_embeds)
        cls_embeds = text_embeds[:, 0, :]
        
        # 预测
        logits = self.itm_head(cls_embeds)
        labels = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)], dim=0).long().to(input_ids.device)
        
        return F.cross_entropy(logits, labels)

五、可视化:看看模型学到了什么

图片3:CapFilt示例

出处:论文图4

从图3可以看到:

  • 原始网页字幕TwT_wTw(红色)大多是噪声,比如"from bridge near my house"根本不描述图片内容
  • Captioner生成的合成字幕TsT_sTs(绿色)都准确描述了图片内容,比如"a flock of birds flying over a lake at sunset"
  • Filter成功过滤掉了所有噪声字幕,只保留了干净的合成字幕

这说明CapFilt确实能有效提纯噪声数据,让模型学到更准确的视觉语言对应关系。


六、总结与未来方向

BLIP是多模态预训练领域的一个里程碑式的工作,它的主要贡献有三点:

  1. 统一架构:提出了MED架构,让同一个模型既能做理解又能做生成,打破了之前模型的偏科问题
  2. 数据提纯:发明了CapFilt技术,能把网上的噪声数据变成高质量训练数据,大大提高了数据利用效率
  3. 全面领先:在多个下游任务上取得了state-of-the-art的结果,而且具有惊人的零样本泛化能力

未来的研究方向包括:

  • 多轮CapFilt:多次迭代提纯数据集,进一步提高数据质量
  • 多字幕生成:为每张图片生成多个不同的合成字幕,扩大训练语料
  • 模型集成:训练多个不同的Captioner和Filter,组合它们的力量

BLIP的思想影响了后来的很多多模态模型,比如BLIP-2、InstructBLIP等,它证明了好的数据和高效的架构同样重要,为多模态预训练指明了新的方向。


相关推荐
wujian83111 小时前
千问 文心 元宝 Kimi导出pdf方法
人工智能·ai·pdf·豆包·deepseek
这张生成的图像能检测吗1 小时前
(论文速读)MDRS-GAN:基于多尺度密集残余收缩GAN的旋转机械故障诊断
人工智能·深度学习·生成对抗网络·故障诊断
yyy(十一月限定版)1 小时前
问题解决策略搜索训练3
算法
Traving Yu1 小时前
LangChain
人工智能·langchain
青衫码上行1 小时前
如何接入AI大模型
java·人工智能·ai·langchain·ai编程
创意岛1 小时前
AI时代,你的品牌在城市发展中“被消失”了吗?
人工智能·python
吃好睡好便好1 小时前
在Matlab中绘制圆锥三维曲面图
开发语言·人工智能·学习·算法·matlab·信息可视化
沪漂阿龙1 小时前
AI大模型面试题:数据处理与特征工程详解——特征工程、缺失值、标准化、归一化、特征选择、数据不平衡、数据泄漏一次讲透
人工智能·机器学习
CORNERSTONE3651 小时前
如何理解工业软件 PLM、ERP、MES 的边界?
大数据·人工智能·plm·产品全生命周期管理