【图像描述生成】GAT:融合几何注意力与位置感知LSTM的Transformer模型详解

本文介绍一种改进的图像描述生成模型------Geometry Attention Transformer (GAT),通过在编码器中引入几何自注意力精炼器(GSR),在解码器中采用位置感知LSTM,显著提升了图像描述的准确性。

1. 引言

1.1 什么是图像描述生成?

图像描述生成(Image Captioning)是计算机视觉领域的一项核心任务,目标是让机器自动为图像生成自然语言描述。这不仅需要识别图像中包含哪些物体,还要准确捕捉物体之间的动作关系和空间位置关系。

举个例子,对于同一张图片:

  • "一个男孩站在滑板上"
  • "一个男孩手里举着滑板"

这两句话描述的是完全不同的场景,而区分它们的关键就在于几何位置关系

1.2 现有方法的局限

目前主流的图像描述模型大多采用Transformer的编码器-解码器架构,并结合注意力机制。然而,这些方法存在两个主要问题:

  1. 编码端:未能充分利用图像中各物体之间的几何空间关系
  2. 解码端:传统的正弦/余弦位置编码难以精确表达词序信息

2. GAT模型架构总览

GAT(Geometry Attention Transformer)的核心思想是在Transformer框架基础上,分别对编码器和解码器进行几何感知增强:

整体架构包含两个关键创新模块:

  • GSR(Geometry Self-attention Refiner):几何自注意力精炼器,用于编码器
  • Position-LSTM:位置感知LSTM,用于解码器

3. 核心技术详解

3.1 几何自注意力精炼器(GSR)

3.1.1 几何特征表示

对于图像中检测到的每个目标区域,除了传统的外观特征外,GAT还提取其几何特征。每个目标的几何信息用一个5维向量表示:

X g = ( x m i n , y m i n , x m a x , y m a x , S ) X_g = (x_{min}, y_{min}, x_{max}, y_{max}, S) Xg=(xmin,ymin,xmax,ymax,S)

其中:

  • ( x m i n , y m i n ) (x_{min}, y_{min}) (xmin,ymin):边界框左上角坐标
  • ( x m a x , y m a x ) (x_{max}, y_{max}) (xmax,ymax):边界框右下角坐标
  • S S S:目标区域相对于整幅图像的面积比例

所有坐标值都归一化到 ( 0 , 1 ) (0, 1) (0,1)区间。

3.1.2 几何-外观特征融合

GSR的核心是将几何信息与外观信息在Query和Key的计算中进行融合。具体做法是将两类特征拼接而非简单相加:

Q ′ = [ X A W Q A ; X G W Q G ] Q' = [X_A W_{Q_A} ; X_G W_{Q_G}] Q′=[XAWQA;XGWQG]

K ′ = [ X A W K A ; X G W K G ] K' = [X_A W_{K_A} ; X_G W_{K_G}] K′=[XAWKA;XGWKG]

其中:

  • X A X_A XA:外观特征
  • X G X_G XG:几何特征(由 X g X_g Xg经过嵌入层和ReLU得到)
  • W Q A , W K A W_{Q_A}, W_{K_A} WQA,WKA:外观特征的投影矩阵
  • W Q G , W K G W_{Q_G}, W_{K_G} WQG,WKG:几何特征的投影矩阵
  • ; \] \[;\] \[;\]:拼接操作

Ω ′ = Q ′ K ′ T 2 × d k \Omega' = \frac{Q'K'^T}{\sqrt{2 \times d_k}} Ω′=2×dk Q′K′T

A t t e n t i o n g ( X ) = s o f t m a x ( Ω ′ ) V A Attention_g(X) = softmax(\Omega')V_A Attentiong(X)=softmax(Ω′)VA

代码实现:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeometrySelfAttention(nn.Module):
    def __init__(self, d_model, d_geo, n_heads):
        super().__init__()
        self.d_model = d_model
        self.d_geo = d_geo
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 外观特征投影
        self.W_QA = nn.Linear(d_model, d_model)
        self.W_KA = nn.Linear(d_model, d_model)
        self.W_VA = nn.Linear(d_model, d_model)
        
        # 几何特征投影
        self.W_QG = nn.Linear(d_geo, d_model)
        self.W_KG = nn.Linear(d_geo, d_model)
        
        # 几何特征嵌入
        self.geo_embed = nn.Sequential(
            nn.Linear(5, d_geo),
            nn.ReLU()
        )
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, X_A, X_g):
        """
        X_A: 外观特征 [batch, N, d_model]
        X_g: 几何特征 [batch, N, 5]
        """
        batch_size, N, _ = X_A.shape
        
        # 嵌入几何特征
        X_G = self.geo_embed(X_g)  # [batch, N, d_geo]
        
        # 计算外观Q, K, V
        Q_A = self.W_QA(X_A)  # [batch, N, d_model]
        K_A = self.W_KA(X_A)
        V_A = self.W_VA(X_A)
        
        # 计算几何Q, K
        Q_G = self.W_QG(X_G)
        K_G = self.W_KG(X_G)
        
        # 拼接外观和几何特征
        Q = torch.cat([Q_A, Q_G], dim=-1)  # [batch, N, 2*d_model]
        K = torch.cat([K_A, K_G], dim=-1)
        
        # 计算注意力分数
        scale = (2 * self.d_k) ** 0.5
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attn_weights, V_A)
        output = self.out_proj(output)
        
        return output
3.1.3 门控线性单元(GLU)

为了进一步精炼注意力输出,GAT引入了门控线性单元:

G a t e C t r l = σ ( W g c ~ + b g ) GateCtrl = \sigma(W_g \tilde{c} + b_g) GateCtrl=σ(Wgc~+bg)

O u t p u t = G a t e C t r l ⊙ ( W i a ~ + b i ) Output = GateCtrl \odot (W_i \tilde{a} + b_i) Output=GateCtrl⊙(Wia~+bi)

其中 c ~ = [ X A ; X G ] \tilde{c} = [X_A; X_G] c~=[XA;XG]是当前上下文, ⊙ \odot ⊙表示逐元素乘法(Hadamard积)。

代码实现:

python 复制代码
class GatedLinearUnit(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_g = nn.Linear(d_model, d_model)
        self.W_i = nn.Linear(d_model, d_model)
        
    def forward(self, context, attention_output):
        """
        context: 上下文特征 [batch, N, d_model]
        attention_output: 注意力输出 [batch, N, d_model]
        """
        gate = torch.sigmoid(self.W_g(context))
        output = gate * self.W_i(attention_output)
        return output

3.2 位置感知LSTM解码器

3.2.1 设计动机

传统Transformer使用正弦/余弦函数进行位置编码,虽然有效但不够灵活。GAT采用LSTM来建模词序,有两个优势:

  1. 记忆已生成内容:LSTM的隐藏状态可以记住已生成的词序列
  2. 动态位置感知:能够自适应地指导解码器关注相关的图像区域
3.2.2 Position-LSTM结构

在每个时间步 t t t,Position-LSTM的输入为:

x t = [ w t , v ˉ ] x_t = [w_t, \bar{v}] xt=[wt,vˉ]

其中:

  • w t w_t wt:当前词的嵌入向量
  • v ˉ = 1 k ∑ i v i \bar{v} = \frac{1}{k}\sum_i v_i vˉ=k1∑ivi:图像特征的平均池化

LSTM更新:

h t , c t = L S T M ( x t , ( h t − 1 , c t − 1 ) ) h_t, c_t = LSTM(x_t, (h_{t-1}, c_{t-1})) ht,ct=LSTM(xt,(ht−1,ct−1))

隐藏状态 h t h_t ht作为位置编码,传递给后续的解码器层。

代码实现:

python 复制代码
class PositionLSTM(nn.Module):
    def __init__(self, word_dim, visual_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTMCell(word_dim + visual_dim, hidden_dim)
        self.hidden_dim = hidden_dim
        
    def forward(self, word_embed, visual_mean, prev_hidden, prev_cell):
        """
        word_embed: 词嵌入 [batch, word_dim]
        visual_mean: 图像特征均值 [batch, visual_dim]
        prev_hidden: 上一步隐藏状态 [batch, hidden_dim]
        prev_cell: 上一步细胞状态 [batch, hidden_dim]
        """
        # 拼接词嵌入和视觉特征
        lstm_input = torch.cat([word_embed, visual_mean], dim=-1)
        
        # LSTM前向传播
        h_t, c_t = self.lstm(lstm_input, (prev_hidden, prev_cell))
        
        return h_t, c_t
3.2.3 解码器整体流程

解码器采用多层结构,底层解码器使用Position-LSTM输出的 h t h_t ht作为Query,使用编码器输出的精炼特征 X r X_r Xr计算Key和Value:

python 复制代码
class GATDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads):
        super().__init__()
        self.word_embed = nn.Embedding(vocab_size, d_model)
        self.position_lstm = PositionLSTM(d_model, d_model, d_model)
        
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads) for _ in range(n_layers)
        ])
        
        self.output_proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, X_r, prev_words, visual_mean, hidden_states):
        """
        X_r: 编码器输出的精炼特征
        prev_words: 已生成的词序列
        visual_mean: 视觉特征均值
        hidden_states: LSTM隐藏状态
        """
        # 词嵌入
        word_embeds = self.word_embed(prev_words)
        
        # Position-LSTM编码位置
        h_t, c_t = self.position_lstm(word_embeds, visual_mean, *hidden_states)
        
        # 多层解码
        query = h_t.unsqueeze(1)
        for layer in self.decoder_layers:
            query = layer(query, X_r)
        
        # 输出词分布
        logits = self.output_proj(query.squeeze(1))
        return logits, (h_t, c_t)

3.3 完整的GAT编码器

结合GSR和GLU,完整的编码器处理流程如下:

python 复制代码
class GATEncoder(nn.Module):
    def __init__(self, d_model, d_geo, n_heads, n_layers, d_ff):
        super().__init__()
        self.layers = nn.ModuleList()
        
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                'gsr': GeometrySelfAttention(d_model, d_geo, n_heads),
                'glu': GatedLinearUnit(d_model * 2),  # 外观+几何
                'norm1': nn.LayerNorm(d_model),
                'ffn': nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model)
                ),
                'norm2': nn.LayerNorm(d_model)
            }))
    
    def forward(self, X_A, X_g):
        """
        X_A: 外观特征 [batch, N, d_model]
        X_g: 几何特征 [batch, N, 5]
        """
        X_G = self.layers[0]['gsr'].geo_embed(X_g)
        
        for layer in self.layers:
            # 几何自注意力
            attn_out = layer['gsr'](X_A, X_g)
            
            # 门控单元精炼
            context = torch.cat([X_A, X_G], dim=-1)
            glu_out = layer['glu'](context, attn_out)
            
            # 残差连接 + LayerNorm
            X_A = layer['norm1'](X_A + glu_out)
            
            # 前馈网络
            ffn_out = layer['ffn'](X_A)
            X_A = layer['norm2'](X_A + ffn_out)
        
        return X_A

4. 实验结果分析

4.1 消融实验

在MS COCO数据集上的消融实验验证了各模块的有效性:

模型配置 BLEU-1 BLEU-4 CIDEr SPICE
Base (原始Transformer) 75.0 32.8 109.0 20.6
Base + GSR 76.9 35.6 115.1 21.4
Base + Position-LSTM 76.5 34.5 114.9 21.3
完整GAT 77.5 37.8 119.8 21.8

关键发现:

  1. GSR的效果:CIDEr从109.0提升到115.1(+6.1),说明几何关系对描述准确性至关重要
  2. Position-LSTM的效果:CIDEr提升5.9,证明了动态位置编码优于静态正弦编码
  3. 两者结合:最终CIDEr达到119.8,整体提升10.8

4.2 特征融合策略对比

融合方式 BLEU-1 BLEU-4 CIDEr
相加 (add) 76.0 35.1 116.4
拼接 (concat) 77.5 37.8 119.8

拼接策略明显优于相加,因为它保留了更丰富的特征交互信息。

4.3 与SOTA方法对比

在MS COCO数据集上(CIDEr优化后):

方法 BLEU-4 METEOR CIDEr SPICE
Up-Down 36.3 27.7 120.1 21.4
AoANet 38.9 29.0 129.8 22.4
ORT 38.6 28.7 128.3 22.6
GAT (本文) 39.7 29.1 130.5 22.9

GAT在几乎所有指标上都取得了最优结果。

4.4 定性分析

GAT生成的描述能够准确捕捉空间关系:

  • 基线模型:"人们坐在餐厅外的桌子旁"
  • GAT :"一群人坐在雨伞下 ,在餐厅前面"

GAT能够正确识别"雨伞下"和"餐厅前面"这样的空间关系,这正是GSR模块的功劳。

5. 总结与展望

5.1 核心贡献

  1. 几何自注意力精炼器(GSR):将物体的空间几何信息融入自注意力计算,让模型知道"物体在哪里"
  2. 位置感知LSTM:用LSTM替代静态位置编码,动态建模词序关系
  3. 门控精炼机制:进一步优化注意力输出

5.2 关键参数设置

参数 取值
图像特征维度 2048 → 512
LSTM隐藏层大小 1024
注意力头数 8
编码器/解码器层数 3
Dropout (LSTM) 0.5
Dropout (注意力) 0.1
初始学习率 5×10⁻⁴

5.3 未来方向

  • 探索更复杂的几何关系建模(如物体间的相对位置、遮挡关系)
  • 将GAT扩展到视频描述生成任务
  • 结合预训练视觉-语言模型进一步提升性能

参考代码GitHub - UESTC-nnLab/GAT

论文来源:Chi Wang et al. "Geometry Attention Transformer with Position-aware LSTMs for Image Captioning", Expert Systems with Applications, 2021.


如果觉得本文对你有帮助,欢迎点赞收藏!有问题欢迎在评论区讨论~

相关推荐
Coovally AI模型快速验证2 小时前
Meta ShapeR重磅开源:多模态3D生成,从真实杂乱视频中稳健重建
人工智能·学习·算法·yolo·3d·人机交互
菩提树下的凡夫2 小时前
强化学习和深度学习的区别与联系
人工智能·深度学习
九尾狐ai2 小时前
从九尾狐AI案例拆解智能矩阵技术架构:如何实现AI获客300万播放?
人工智能
wasp5202 小时前
Hudi 客户端实现分析
java·开发语言·人工智能·hudi
秦苒&2 小时前
【脉脉】AI 创作者 xAMA 知无不言:在浪潮里,做会发光的造浪者
大数据·c语言·数据库·c++·人工智能·ai·操作系统
chinesegf2 小时前
嵌入模型和大语言模型的关系
人工智能·语言模型·自然语言处理
啊阿狸不会拉杆2 小时前
《计算机操作系统》 第十一章 -多媒体操作系统
开发语言·c++·人工智能·os·计算机操作系统
_ziva_2 小时前
分布式(三)深入浅出理解PyTorch分布式训练:nn.parallel.DistributedDataParallel详解
人工智能·pytorch·分布式
江南小书生2 小时前
非标制造行业装配报工工时不准?缺料干扰+标准缺失如何破局?
大数据·人工智能