本文介绍一种将Transformer与CNN结合的视觉显著性预测模型------TranSalNet。通过在多尺度特征图上引入Transformer编码器,有效捕获长距离上下文信息,使显著性预测更接近人类视觉注意力机制。
1. 引言
1.1 什么是视觉显著性预测?
视觉显著性预测(Visual Saliency Prediction)旨在模拟人类视觉注意力机制,预测图像中哪些区域最容易吸引人的注意力。这项技术在多个领域有广泛应用:
- 图像压缩:对显著区域分配更多比特
- 图像重定向:保留显著内容进行裁剪
- 视频摘要:基于显著性选取关键帧
- 图像质量评估:加权显著区域的失真
1.2 现有方法的局限
传统方法:基于颜色、亮度、纹理等底层特征,忽略了物体等高层语义信息。
CNN方法 :虽然能学习高层特征,但卷积操作的感受野有限,难以建模长距离依赖关系。
例如,当图像中有两个相距较远的人脸时,CNN难以捕捉它们之间的关联性。
1.3 Transformer的优势
Transformer通过自注意力机制可以直接建模任意两个位置之间的关系,天然适合捕获全局上下文信息。
TranSalNet的核心思想:CNN提取多尺度特征 + Transformer增强长距离上下文
2. TranSalNet模型架构
2.1 整体结构
TranSalNet采用编码器-解码器架构,由三部分组成:
输入图像 (3×384×288)
↓
┌─────────────────────────────────────┐
│ CNN编码器 (ResNet-50) │
│ 提取多尺度特征: conv3_x, conv4_x, │
│ conv5_x │
└─────────────────────────────────────┘
↓ x1, x2, x3
┌─────────────────────────────────────┐
│ Transformer编码器 ×3 │
│ 增强各尺度特征的长距离上下文信息 │
└─────────────────────────────────────┘
↓ x1_c, x2_c, x3_c
┌─────────────────────────────────────┐
│ CNN解码器 │
│ 跳跃连接 + 逐步上采样 → 显著性图 │
└─────────────────────────────────────┘
↓
输出显著性图 (1×384×288)
2.2 CNN编码器
使用预训练的ResNet-50作为骨干网络,提取三个不同尺度的特征图:
| 特征 | 来源 | 通道数 | 空间尺寸 |
|---|---|---|---|
| x1 | conv5_x | 2048 | H/32 × W/32 |
| x2 | conv4_x | 1024 | H/16 × W/16 |
| x3 | conv3_x | 512 | H/8 × W/8 |
为什么不用conv1和conv2_x?
实验发现,使用浅层特征会导致显著性图出现边界线和伪影,因为浅层特征包含过多的低级纹理信息。
python
import torch
import torch.nn as nn
from torchvision.models import resnet50
class CNNEncoder(nn.Module):
"""CNN编码器 - 基于ResNet-50"""
def __init__(self, pretrained=True):
super().__init__()
# 加载预训练ResNet-50
resnet = resnet50(pretrained=pretrained)
# 提取各个stage
self.conv1 = resnet.conv1
self.bn1 = resnet.bn1
self.relu = resnet.relu
self.maxpool = resnet.maxpool
self.layer1 = resnet.layer1 # conv2_x
self.layer2 = resnet.layer2 # conv3_x -> x3
self.layer3 = resnet.layer3 # conv4_x -> x2
self.layer4 = resnet.layer4 # conv5_x -> x1
def forward(self, x):
"""
Args:
x: 输入图像 [B, 3, H, W]
Returns:
x1: conv5_x特征 [B, 2048, H/32, W/32]
x2: conv4_x特征 [B, 1024, H/16, W/16]
x3: conv3_x特征 [B, 512, H/8, W/8]
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x) # conv2_x
x3 = self.layer2(x) # conv3_x: [B, 512, H/8, W/8]
x2 = self.layer3(x3) # conv4_x: [B, 1024, H/16, W/16]
x1 = self.layer4(x2) # conv5_x: [B, 2048, H/32, W/32]
return x1, x2, x3
2.3 Transformer编码器
每个Transformer编码器包含2层,用于增强特征的长距离上下文信息。
2.3.1 处理流程
对于输入特征 xix_ixi:
- 通道压缩:用1×1卷积将通道数对齐到Transformer可接受的维度
- 位置编码:添加可学习的位置嵌入
- 自注意力:多头自注意力 + 残差连接
- 前馈网络:MLP + 残差连接
数学表达:
z0=Conv1×1(xi)⊕POSiz_0 = Conv_{1\times1}(x_i) \oplus POS_iz0=Conv1×1(xi)⊕POSi
zl′=MSA(LN(zl−1))⊕zl−1,l=1,2z'l = MSA(LN(z{l-1})) \oplus z_{l-1}, \quad l=1,2zl′=MSA(LN(zl−1))⊕zl−1,l=1,2
zl=MLP(LN(zl′))⊕zl′,l=1,2z_l = MLP(LN(z'_l)) \oplus z'_l, \quad l=1,2zl=MLP(LN(zl′))⊕zl′,l=1,2
2.3.2 配置参数
| Transformer | 输入通道 | 压缩后通道 | 注意力头数 |
|---|---|---|---|
| Encoder 1 | 2048 | 768 | 12 |
| Encoder 2 | 1024 | 768 | 12 |
| Encoder 3 | 512 | 512 | 8 |
代码实现:
python
class TransformerEncoder(nn.Module):
"""Transformer编码器"""
def __init__(self, in_channels, embed_dim, num_heads,
num_layers=2, mlp_ratio=4.0):
super().__init__()
self.embed_dim = embed_dim
# 通道压缩(如果需要)
if in_channels != embed_dim:
self.proj = nn.Conv2d(in_channels, embed_dim, 1)
else:
self.proj = nn.Identity()
# 位置编码(可学习)
self.pos_embed = None # 动态初始化
# Transformer层
self.layers = nn.ModuleList([
TransformerLayer(embed_dim, num_heads, mlp_ratio)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
"""
Args:
x: 输入特征 [B, C, H, W]
Returns:
out: 上下文增强的特征 [B, embed_dim, H, W]
"""
B, C, H, W = x.shape
# 通道投影
x = self.proj(x) # [B, embed_dim, H, W]
# 展平为序列
x = x.flatten(2).transpose(1, 2) # [B, H*W, embed_dim]
# 添加位置编码
if self.pos_embed is None or self.pos_embed.size(1) != H * W:
self.pos_embed = nn.Parameter(
torch.zeros(1, H * W, self.embed_dim, device=x.device)
)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
x = x + self.pos_embed
# Transformer层
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# 重塑回特征图
x = x.transpose(1, 2).view(B, self.embed_dim, H, W)
return x
class TransformerLayer(nn.Module):
"""单层Transformer"""
def __init__(self, dim, num_heads, mlp_ratio=4.0):
super().__init__()
# 多头自注意力
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# MLP
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim)
)
def forward(self, x):
# 自注意力 + 残差
x_norm = self.norm1(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_out
# MLP + 残差
x = x + self.mlp(self.norm2(x))
return x
2.4 CNN解码器
解码器通过跳跃连接融合多尺度上下文增强特征,逐步上采样生成显著性图。
2.4.1 关键设计
- 跳跃连接 :使用逐元素乘法(而非加法或拼接)融合特征
- 逐步上采样:每个block进行2倍最近邻上采样
- 激活函数:中间层用ReLU,输出层用Sigmoid
2.4.2 解码流程
block_1: x1_c → Conv3×3 → BN → ReLU
↓ (×2 upsample)
block_2: ⊙ x2_c → Conv3×3 → BN → ReLU (逐元素乘法)
↓ (×2 upsample)
block_3: ⊙ x3_c → Conv3×3 → BN → ReLU
↓ (×2 upsample)
block_4: → Conv3×3 → BN → ReLU
↓ (×2 upsample)
block_5: → Conv3×3 → BN → ReLU
↓ (×2 upsample)
block_6: → Conv3×3 → BN → ReLU
↓
block_7: → Conv3×3 → Sigmoid → 显著性图
代码实现:
python
class CNNDecoder(nn.Module):
"""CNN解码器"""
def __init__(self):
super().__init__()
# Block 1: 处理x1_c (768通道)
self.block1 = self._make_block(768, 768)
# Block 2: 融合x2_c (768通道)
self.block2 = self._make_block(768, 768)
# Block 3: 融合x3_c (512通道)
self.block3 = self._make_block(512, 512)
# Block 4-6: 逐步降维
self.block4 = self._make_block(512, 256)
self.block5 = self._make_block(256, 128)
self.block6 = self._make_block(128, 64)
# Block 7: 输出层
self.block7 = nn.Sequential(
nn.Conv2d(64, 1, 3, 1, 1),
nn.Sigmoid()
)
# 上采样
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
def _make_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x1_c, x2_c, x3_c):
"""
Args:
x1_c: Transformer增强的conv5_x特征 [B, 768, H/32, W/32]
x2_c: Transformer增强的conv4_x特征 [B, 768, H/16, W/16]
x3_c: Transformer增强的conv3_x特征 [B, 512, H/8, W/8]
Returns:
saliency: 显著性图 [B, 1, H, W]
"""
# Block 1
x = self.block1(x1_c)
# Block 2: 上采样 + 逐元素乘法融合
x = self.upsample(x)
x = torch.relu(x * x2_c) # 逐元素乘法
x = self.block2(x)
# Block 3: 上采样 + 逐元素乘法融合
x = self.upsample(x)
x = torch.relu(x * x3_c)
x = self.block3(x)
# Block 4-6: 逐步上采样
x = self.upsample(x)
x = self.block4(x)
x = self.upsample(x)
x = self.block5(x)
x = self.upsample(x)
x = self.block6(x)
# Block 7: 输出
saliency = self.block7(x)
return saliency
2.5 完整TranSalNet
python
class TranSalNet(nn.Module):
"""TranSalNet: Transformer增强的显著性预测网络"""
def __init__(self, pretrained=True):
super().__init__()
# CNN编码器
self.encoder = CNNEncoder(pretrained=pretrained)
# Transformer编码器
self.trans_enc1 = TransformerEncoder(
in_channels=2048, embed_dim=768, num_heads=12, num_layers=2
)
self.trans_enc2 = TransformerEncoder(
in_channels=1024, embed_dim=768, num_heads=12, num_layers=2
)
self.trans_enc3 = TransformerEncoder(
in_channels=512, embed_dim=512, num_heads=8, num_layers=2
)
# CNN解码器
self.decoder = CNNDecoder()
def forward(self, x):
"""
Args:
x: 输入图像 [B, 3, 384, 288]
Returns:
saliency: 显著性图 [B, 1, 384, 288]
"""
# CNN编码
x1, x2, x3 = self.encoder(x)
# Transformer增强
x1_c = self.trans_enc1(x1)
x2_c = self.trans_enc2(x2)
x3_c = self.trans_enc3(x3)
# CNN解码
saliency = self.decoder(x1_c, x2_c, x3_c)
return saliency
3. 损失函数设计
TranSalNet采用多指标组合损失,将显著性评价指标直接用于训练。
3.1 四个损失分量
3.1.1 NSS损失(Normalized Scanpath Saliency)
衡量预测图在真实注视点位置的响应:
LNSS=1∑iyif∑iy^i−μ(y^)σ(y^)⋅yif\mathcal{L}_{NSS} = \frac{1}{\sum_i y_i^f} \sum_i \frac{\hat{y}_i - \mu(\hat{y})}{\sigma(\hat{y})} \cdot y_i^fLNSS=∑iyif1i∑σ(y^)y^i−μ(y^)⋅yif
其中 yfy^fyf 是注视点图(fixation map),y^\hat{y}y^ 是预测图。
3.1.2 KLD损失(KL散度)
衡量预测分布与真实分布的差异:
LKLD=∑iyislog(yis+ϵy^i+ϵ)\mathcal{L}_{KLD} = \sum_i y_i^s \log\left(\frac{y_i^s + \epsilon}{\hat{y}_i + \epsilon}\right)LKLD=i∑yislog(y^i+ϵyis+ϵ)
3.1.3 CC损失(相关系数)
衡量预测图与真实图的线性相关性:
LCC=cov(ys,y^)σ(ys)⋅σ(y^)\mathcal{L}_{CC} = \frac{cov(y^s, \hat{y})}{\sigma(y^s) \cdot \sigma(\hat{y})}LCC=σ(ys)⋅σ(y^)cov(ys,y^)
3.1.4 SIM损失(相似度)
衡量两个归一化分布的重叠程度:
LSIM=∑imin(yis,y^i)\mathcal{L}_{SIM} = \sum_i \min(y_i^s, \hat{y}_i)LSIM=i∑min(yis,y^i)
3.2 组合损失
L=λ1LNSS+λ2LKLD+λ3LCC+λ4LSIM\mathcal{L} = \lambda_1 \mathcal{L}{NSS} + \lambda_2 \mathcal{L}{KLD} + \lambda_3 \mathcal{L}{CC} + \lambda_4 \mathcal{L}{SIM}L=λ1LNSS+λ2LKLD+λ3LCC+λ4LSIM
权重设置 :λ1=−1,λ2=10,λ3=−2,λ4=−1\lambda_1=-1, \lambda_2=10, \lambda_3=-2, \lambda_4=-1λ1=−1,λ2=10,λ3=−2,λ4=−1
(负号是因为NSS、CC、SIM越大越好,而损失需要最小化)
代码实现:
python
class SaliencyLoss(nn.Module):
"""显著性预测组合损失"""
def __init__(self, lambda_nss=-1, lambda_kld=10,
lambda_cc=-2, lambda_sim=-1, eps=2.2204e-16):
super().__init__()
self.lambda_nss = lambda_nss
self.lambda_kld = lambda_kld
self.lambda_cc = lambda_cc
self.lambda_sim = lambda_sim
self.eps = eps
def forward(self, pred, saliency_gt, fixation_gt):
"""
Args:
pred: 预测显著性图 [B, 1, H, W]
saliency_gt: 真实显著性图 [B, 1, H, W]
fixation_gt: 注视点图 [B, 1, H, W]
"""
pred = pred.squeeze(1)
saliency_gt = saliency_gt.squeeze(1)
fixation_gt = fixation_gt.squeeze(1)
# NSS损失
nss_loss = self._nss_loss(pred, fixation_gt)
# KLD损失
kld_loss = self._kld_loss(pred, saliency_gt)
# CC损失
cc_loss = self._cc_loss(pred, saliency_gt)
# SIM损失
sim_loss = self._sim_loss(pred, saliency_gt)
# 组合损失
total_loss = (self.lambda_nss * nss_loss +
self.lambda_kld * kld_loss +
self.lambda_cc * cc_loss +
self.lambda_sim * sim_loss)
return total_loss
def _nss_loss(self, pred, fixation):
"""NSS损失"""
# 标准化预测图
pred_mean = pred.mean(dim=[1, 2], keepdim=True)
pred_std = pred.std(dim=[1, 2], keepdim=True) + self.eps
pred_norm = (pred - pred_mean) / pred_std
# 在注视点位置计算响应
nss = (pred_norm * fixation).sum(dim=[1, 2]) / (fixation.sum(dim=[1, 2]) + self.eps)
return nss.mean()
def _kld_loss(self, pred, saliency):
"""KL散度损失"""
# 归一化为概率分布
pred_norm = pred / (pred.sum(dim=[1, 2], keepdim=True) + self.eps)
sal_norm = saliency / (saliency.sum(dim=[1, 2], keepdim=True) + self.eps)
kld = (sal_norm * torch.log((sal_norm + self.eps) / (pred_norm + self.eps))).sum(dim=[1, 2])
return kld.mean()
def _cc_loss(self, pred, saliency):
"""相关系数损失"""
# 归一化
pred_norm = pred / (pred.sum(dim=[1, 2], keepdim=True) + self.eps)
sal_norm = saliency / (saliency.sum(dim=[1, 2], keepdim=True) + self.eps)
# 去均值
pred_mean = pred_norm.mean(dim=[1, 2], keepdim=True)
sal_mean = sal_norm.mean(dim=[1, 2], keepdim=True)
pred_centered = pred_norm - pred_mean
sal_centered = sal_norm - sal_mean
# 相关系数
cov = (pred_centered * sal_centered).sum(dim=[1, 2])
pred_std = (pred_centered ** 2).sum(dim=[1, 2]).sqrt()
sal_std = (sal_centered ** 2).sum(dim=[1, 2]).sqrt()
cc = cov / (pred_std * sal_std + self.eps)
return cc.mean()
def _sim_loss(self, pred, saliency):
"""相似度损失"""
# 归一化
pred_norm = pred / (pred.sum(dim=[1, 2], keepdim=True) + self.eps)
sal_norm = saliency / (saliency.sum(dim=[1, 2], keepdim=True) + self.eps)
sim = torch.minimum(pred_norm, sal_norm).sum(dim=[1, 2])
return sim.mean()
4. 训练策略
4.1 数据集
| 数据集 | 图像数 | 用途 | 特点 |
|---|---|---|---|
| SALICON | 10000+5000 | 预训练 | 鼠标点击模拟眼动 |
| MIT1003 | 1003 | 微调+测试 | 自然场景,15人眼动 |
| CAT2000 | 2000 | 微调+测试 | 20类场景,24人眼动 |
| MIT300 | 300 | 测试 | 官方基准测试集 |
4.2 训练流程
python
def train_transalnet():
# 配置
config = {
'input_size': (384, 288), # 4:3比例
'batch_size': 4,
'epochs': 30,
'lr': 1e-5,
'lr_decay_step': 3,
'lr_decay_gamma': 0.1,
'patience': 5 # 早停
}
# 模型
model = TranSalNet(pretrained=True)
# 损失函数
criterion = SaliencyLoss()
# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
# 学习率调度
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=config['lr_decay_step'],
gamma=config['lr_decay_gamma']
)
# 训练循环
for epoch in range(config['epochs']):
model.train()
for images, saliency_maps, fixation_maps in train_loader:
optimizer.zero_grad()
# 前向传播
pred = model(images)
# 计算损失
loss = criterion(pred, saliency_maps, fixation_maps)
# 反向传播
loss.backward()
optimizer.step()
scheduler.step()
# 验证和早停...
4.3 两阶段训练
- 预训练:在SALICON数据集上训练
- 微调:选择验证集上最优模型,在MIT1003/CAT2000上微调
使用10折交叉验证评估最终性能。
5. 实验结果
5.1 消融实验
| 模型 | 配置 | sAUC↑ | AUC↑ | NSS↑ | CC↑ | SIM↑ | KLD↓ |
|---|---|---|---|---|---|---|---|
| BaseNet | 无Transformer,无跳跃连接 | 0.746 | 0.902 | 2.71 | 0.736 | 0.586 | 0.789 |
| BaseNet+ | +1个Transformer | 0.747 | 0.905 | 2.74 | 0.745 | 0.600 | 0.837 |
| SkipNet | +跳跃连接 | 0.751 | 0.901 | 2.76 | 0.739 | 0.538 | 0.733 |
| TranSalNet_BCE | 完整结构+BCE损失 | 0.751 | 0.909 | 2.83 | 0.759 | 0.612 | 0.779 |
| TranSalNet | 完整结构+组合损失 | 0.755 | 0.909 | 2.85 | 0.760 | 0.615 | 0.778 |
关键发现:
- Transformer编码器有效提升性能
- 跳跃连接进一步改善多尺度融合
- 组合损失优于单一BCE损失
5.2 与SOTA方法对比
MIT1003数据集:
| 方法 | sAUC↑ | AUC↑ | NSS↑ | CC↑ | SIM↑ | KLD↓ |
|---|---|---|---|---|---|---|
| ML-Net | 0.722 | 0.862 | 2.33 | 0.598 | 0.496 | 1.350 |
| SAM-VGG | 0.726 | 0.900 | 2.75 | 0.726 | 0.598 | 1.220 |
| SAM-ResNet | 0.737 | 0.902 | 2.80 | 0.747 | 0.607 | 1.247 |
| DVA | 0.736 | 0.900 | 2.65 | 0.711 | 0.553 | 0.723 |
| MSI-NET | 0.745 | 0.907 | 2.80 | 0.747 | 0.608 | 0.816 |
| TranSalNet | 0.755 | 0.909 | 2.85 | 0.760 | 0.615 | 0.778 |
5.3 MIT300基准测试
TranSalNet在MIT/Tübingen官方基准上取得优异成绩:
- CC: 0.7991 (第1)
- SIM: 0.6852 (第1)
- NSS: 2.3758 (第2)
- sAUC: 0.7471 (第2)
CC和SIM是最接近人类感知的指标,TranSalNet在这两个指标上排名第一,说明模型预测的显著性图最"像人类"。
6. 为什么Transformer有效?
6.1 长距离依赖建模
CNN的感受野受卷积核大小限制,而Transformer可以直接建模任意距离的像素关系。
例如:图像左上角的人脸和右下角的人脸,Transformer可以直接关联它们的显著性。
6.2 全局上下文理解
自注意力机制让每个位置都能"看到"整幅图像,从而理解全局语义上下文。
6.3 多尺度上下文增强
TranSalNet在三个不同尺度上都应用Transformer,分别增强:
- 高层语义(conv5_x):物体级别的显著性
- 中层特征(conv4_x):部件级别的显著性
- 低层细节(conv3_x):边缘和纹理的显著性
7. 关键参数总结
| 参数 | 值 |
|---|---|
| 输入尺寸 | 384×288 |
| CNN骨干 | ResNet-50 (预训练) |
| Transformer层数 | 每个编码器2层 |
| 注意力头数 | 12/12/8 |
| 嵌入维度 | 768/768/512 |
| 学习率 | 1×10⁻⁵ |
| 批大小 | 4 |
| 训练轮数 | 30 |
| 损失权重 | NSS:-1, KLD:10, CC:-2, SIM:-1 |
8. 总结
8.1 核心贡献
- 首次将Transformer引入显著性预测:有效捕获长距离上下文信息
- 多尺度Transformer增强:在不同语义层级增强上下文
- 组合损失函数:直接优化显著性评价指标
- 跳跃连接设计:逐元素乘法融合多尺度特征
8.2 适用场景
- 图像/视频质量评估
- 图像压缩与重定向
- 视觉注意力分析
- 广告效果评估
8.3 局限与展望
- 计算开销较大(Transformer的平方复杂度)
- 可探索更高效的注意力机制(如线性注意力)
- 可扩展到视频显著性预测
如果觉得本文对你有帮助,欢迎点赞收藏!有问题欢迎在评论区讨论~