PyTorch 实现图像版多头注意力(Multi-Head Attention)和自注意力(Self-Attention)

本文提供一个适用于图像输入的多头注意力机制(Multi-Head Attention)PyTorch 实现,适用于 ViT、MAE 等视觉 Transformer 中的注意力计算。


模块说明

  • 输入支持图像格式 (B, C, H, W)
  • 内部转换为序列 (B, N, C),其中 N = H * W
  • 多头注意力计算:查询(Q)、键(K)、值(V)使用线性层投影
  • 结果 reshape 回原图维度 (B, C, H, W)

多头注意力机制代码(适用于图像输入)

python 复制代码
import torch
import torch.nn as nn

class ImageMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(ImageMultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Q, K, V 的线性映射
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # 输出映射层
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** 0.5

    def forward(self, x):
        # 输入 x: (B, C, H, W),需要 reshape 为 (B, N, C)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, N, C)

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # 拆成多头 (B, num_heads, N, head_dim)
        Q = Q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 注意力分数计算
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, V)

        # 合并多头
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, H * W, self.embed_dim)

        # 输出映射
        out = self.out_proj(attn_out)

        # 恢复回原图维度 (B, C, H, W)
        out = out.permute(0, 2, 1).view(B, C, H, W)
        return out

# 测试示例
# 假设输入是一张 14x14 的特征图(类似 patch embedding 后)
img = torch.randn(4, 64, 14, 14)  # (B, C, H, W)

mha = ImageMultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(img)

print(out.shape)  # 输出应为 (4, 64, 14, 14)

PyTorch 实现自注意力机制(Self-Attention)

本节补充自注意力机制(Self-Attention)的核心代码实现,适用于 ViT 等模型中 patch token 的注意力操作。

自注意力机制代码(Self-Attention)

python 复制代码
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        # 输入 x: (B, N, C)
        B, N, C = x.shape

        # 一次性生成 Q, K, V
        qkv = self.qkv_proj(x)  # (B, N, 3C)
        Q, K, V = torch.chunk(qkv, chunks=3, dim=-1)  # 各自为 (B, N, C)

        # 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, N, N)
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # 得到注意力加权输出
        attn_out = torch.matmul(attn_probs, V)  # (B, N, C)

        # 映射回原维度
        out = self.out_proj(attn_out)  # (B, N, C)
        return out
        
#  测试示例
# 假设输入为 196 个 patch,每个 patch 的嵌入维度为 64
x = torch.randn(2, 196, 64)  # (B, N, C)

attn = SelfAttention(embed_dim=64)
out = attn(x)

print(out.shape)  # 输出应为 (2, 196, 64)

📎 拓展说明

• 本实现为单头自注意力机制

• 可用于 NLP 中的序列特征或 ViT 图像 patch 序列

• 若需改为多头注意力,只需将 embed_dim 拆成 num_heads × head_dim 并分别计算后合并


PyTorch 实现图像输入的自注意力机制(Self-Attention)

本节介绍一种适用于图像输入 (B, C, H, W) 的自注意力机制实现,适合卷积神经网络与 Transformer 的融合模块,如 Self-Attention ConvNet、BAM、CBAM、ViT 前层等。

自注意力机制(图像维度)代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(ImageSelfAttention, self).__init__()
        self.in_channels = in_channels
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv   = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # 可学习缩放因子

    def forward(self, x):
        # 输入 x: (B, C, H, W)
        B, C, H, W = x.size()

        # 生成 Q, K, V
        proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C//8)
        proj_key   = self.key_conv(x).view(B, -1, H * W)                      # (B, C//8, N)
        proj_value = self.value_conv(x).view(B, -1, H * W)                    # (B, C, N)

        # 注意力矩阵:Q * K^T
        energy = torch.bmm(proj_query, proj_key)         # (B, N, N)
        attention = F.softmax(energy, dim=-1)             # (B, N, N)

        # 加权求和 V
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(B, C, H, W)

        # 残差连接 + 缩放因子
        out = self.gamma * out + x
        return out
        
#测试用例
x = torch.randn(2, 64, 32, 32)  # 输入一张图像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)

print(out.shape)  # 输出形状应为 (2, 64, 32, 32)

• 本模块基于图像 (B, C, H, W) 进行自注意力计算

• 使用卷积进行 Q/K/V 提取,保持局部感知力

• gamma 是可学习缩放因子,用于残差连接控制注意力贡献度


自注意力中**缩放因子(scale factor)的处理,在序列维度(如 ViT)和图片维度(如 Self-Attention Conv)**中有点不一样。下面我们来详细解释一下原因,并对两种写法做一个统一和对比分析

两种缩放因子的区别
  1. 序列维度的缩放因子

    scale = head_dim ** 0.5 # 或者 embed_dim ** 0.5
    attn = (Q @ K.T) / scale

• 来源:Transformer 原始论文(Attention is All You Need)

• 原因:在高维向量内积中,为了避免 dot product 的结果数值过大导致梯度不稳定,需要除以 sqrt(d_k)

• 使用场景:多头注意力机制,输入是 (B, N, C),应用在 NLP、ViT 等序列结构

  1. 图片维度(C, H, W)的注意力机制中没有缩放,或者使用 softmax 平衡

    attn = softmax(Q @ K.T) # 无 scale,或者手动调节

• 来源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法

• 原因:Q 和 K 都通过 1x1 conv 压缩成 C//8 或更小的维度,内积的值本身不会太大;同时图像 attention 主要用 softmax 控制权重范围

• 缩放因子的控制通常用 γ(gamma)作为残差通道缩放,不是 QK 内部的数值缩放


💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!

相关推荐
Mr_LeeCZ4 分钟前
PyTorch 深度学习 || 7. Unet | Ch7.1 Unet 框架
人工智能·深度学习·机器学习
James. 常德 student7 分钟前
多GPU训练
人工智能·pytorch·深度学习
梦回阑珊10 分钟前
《QT从基础到进阶·七十四》Qt+C++开发一个python编译器,能够编写,运行python程序改进版
c++·python·qt
前端开发张小七14 分钟前
13.Python Socket服务端开发指南
前端·python
前端开发张小七15 分钟前
14.Python Socket客户端开发指南
前端·python
Jozky8618 分钟前
大语言模型在端到端智驾中的应用
人工智能·语言模型·自然语言处理
Y1nhl24 分钟前
搜广推校招面经六十六
pytorch·python·深度学习·机器学习·广告算法·推荐算法·搜索算法
脑洞专家39 分钟前
基于改进的点线融合和关键帧选择的视觉SLAM 方法
人工智能·机器学习·计算机视觉
码界筑梦坊1 小时前
基于Django的二手交易校园购物系统
大数据·后端·python·信息可视化·django
明月看潮生2 小时前
青少年编程与数学 02-015 大学数学知识点 09课题、专业相关性分析
人工智能·青少年编程·数据科学·编程与数学·大学数学