深度学习基本模块:MultiheadAttention 多头注意力机制

多头注意力机制(MultiheadAttention)是Transformer架构中的核心组件,通过并行计算多个注意力头,使模型能够同时关注输入序列中不同表示子空间的信息 ,从而更有效地捕捉长距离依赖和复杂模式。MultiheadAttention通过将输入特征分割到多个子空间并独立计算注意力,然后合并结果,实现了更强大的序列建模能力

一、MultiheadAttention介绍

1.1 结构

  • 输入层 :序列数据,形状为(batch_size, seq_len, embed_dim)的张量

  • MultiheadAttention层

    • 核心机制

      • 查询投影(Query Projection):将输入转换为查询向量
      • 键投影(Key Projection):将输入转换为键向量
      • 值投影(Value Projection):将输入转换为值向量
      • 多头分割:将投影后的特征分割到多个注意力头
      • 缩放点积注意力:每个头独立计算注意力权重
      • 头合并:将多个头的输出拼接并投影回原始维度
    • 可学习参数

      • 投影权重in_proj_weight,形状为(3 * embed_dim, embed_dim)
      • 输出投影权重out_proj.weight,形状为(embed_dim, embed_dim)
      • 偏置项in_proj_biasout_proj.bias,形状均为(embed_dim,)
  • 激活函数

    • 注意力权重计算 :使用Softmax函数,确保权重和为1
    • 值变换:通常使用线性变换,但可结合非线性激活函数
  • 多头机制的意义

    • 并行处理:多个注意力头同时计算不同的表示子空间
    • 多样化捕捉:每个头关注序列中不同类型的关系模式
    • 计算效率:将高维注意力计算分解为多个低维计算
    • 表示能力增强:通过多头机制获得更丰富的序列表示

1.2 参数

  • embed_dim:模型的总维度,必须是num_heads的倍数
  • num_heads:并行注意力头的数量,每个头的维度为embed_dim // num_heads
  • dropout:注意力权重的dropout概率,默认0.0
  • bias:是否在输入/输出投影层添加偏置,默认True
  • add_bias_kv:是否为key和value序列在dim=0添加偏置,默认False
  • add_zero_attn:是否在key和value序列的dim=1添加零批次,默认False
  • kdim:key的特征维度,默认None(使用embed_dim)
  • vdim:value的特征维度,默认None(使用embed_dim)
  • batch_first :输入输出张量格式,False为(seq, batch, feature)True为(batch, seq, feature)

1.3 输入输出维度

batch_first=True 模式下,输入张量维度如下:

  • Query(查询)(Batch Size, Query Sequence Length, Query Embedding Dimension)
  • Key(键)(Batch Size, Key Sequence Length, Key Embedding Dimension)
  • Value(值)(Batch Size, Value Sequence Length, Value Embedding Dimension)

输出维度如下:

  • 注意力输出 (attn_output)(Batch Size, Query Sequence Length, Output Embedding Dimension)
  • 注意力权重 (attn_output_weights)
    • average_attn_weights=True(默认):(Batch Size, Query Sequence Length, Key Sequence Length),表示每个批次中每个查询位置对所有键位置的平均注意力权重
    • average_attn_weights=False(Batch Size, Number of Attention Heads, Query Sequence Length, Key Sequence Length),提供每个注意力头的独立权重,便于分析不同头的注意力模式
参数类型 维度格式 关键含义
Query 输入 (N, L, E_q) 批量大小 × 查询序列长度 × 查询嵌入维度
Key 输入 (N, S, E_k) 批量大小 × 键序列长度 × 键嵌入维度
Value 输入 (N, S, E_v) 批量大小 × 值序列长度 × 值嵌入维度
注意力输出 (N, L, E) 批量大小 × 查询序列长度 × 输出嵌入维度
注意力权重(平均) (N, L, S) 批量大小 × 查询序列长度 × 键序列长度
注意力权重(分头) (N, H, L, S) 批量大小 × 注意力头数 × 查询序列长度 × 键序列长度
  • 维度一致性 :Query Embedding Dimension (E_q) 必须等于模块的 embed_dim 参数
  • 灵活配置 :Key 和 Value 的嵌入维度(E_k, E_v)可通过 kdimvdim 参数独立设置,默认等于 embed_dim 参数
  • 序列长度独立性:Query Sequence Length (L) 和 Key Sequence Length (S) 可以不同,支持跨序列注意力计算
  • 输出维度稳定 :注意力输出维度始终与 embed_dim 保持一致,无论输入维度如何变化
python 复制代码
import torch.nn as nn
attention_layer = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
output, weights = attention_layer(query, key, value)

attention_layer(query, key, value) 调用中,Query、Key、Value 的来源取决于具体的注意力类型和应用场景。以下是常见来源的总结表格:

注意力类型 Query 来源 Key 来源 Value 来源 典型应用场景
自注意力 (Self-Attention) 输入序列 输入序列 输入序列 Transformer 编码器、BERT、文本分类
交叉注意力 (Cross-Attention) 目标序列 源序列 源序列 Transformer 解码器、机器翻译、文本生成
编码器-解码器注意力 解码器输出 编码器输出 编码器输出 Seq2Seq 模型、语音识别
记忆检索注意力 查询向量 记忆键 记忆值 问答系统、知识检索
多头键值注意力 外部查询 键序列 值序列 推荐系统、个性化建模

在 PyTorch 的 MultiheadAttention 中:

  1. Query、Key、Value 的来源是灵活的,取决于具体任务和注意力类型
  2. 自注意力是最常见的形式,Q、K、V 都来自同一输入
  3. 交叉注意力广泛用于序列到序列任务,Q 来自目标序列,K/V 来自源序列
  4. 模块不关心输入的具体来源,只负责根据给定的 Q、K、V 计算注意力
  5. 正确的来源选择对模型性能至关重要,需要根据任务需求设计

1.4 计算过程

见这两篇文章:
深度学习:注意力机制(Attention Mechanism)
论文解读:Attention Is All You Need

二、代码示例

通过多头注意力层处理一段音频频谱,打印每层的输出形状、参数形状,并可视化特征图。

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib
import librosa
import numpy as np

matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号 '-' 显示为方块的问题
matplotlib.rcParams['font.family'] = 'Kaiti SC'  # 可以替换为其他字体


# 定义多头注意力模型
class MultiheadAttentionModel(nn.Module):
    def __init__(self, input_size, embed_dim, num_heads):
        super(MultiheadAttentionModel, self).__init__()
        # 输入投影层(可选,用于调整输入维度)
        self.input_proj = nn.Linear(input_size, embed_dim)
        # 多头注意力层
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x):
        # 输入投影
        projected = self.input_proj(x)
        # 多头注意力(自注意力模式)
        attn_output, attn_weights = self.attention(projected, projected, projected)
        return projected, attn_output, attn_weights


# 读取音频文件并处理
file_path = 'test.wav'  # 替换为您的音频文件路径
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True)

# 选取3秒的数据
start_sample = int(1.5 * sample_rate)
end_sample = int(4.5 * sample_rate)
audio_segment = waveform[start_sample:end_sample]

# 转换为频谱
n_fft = 512
hop_length = 256
spectrogram = librosa.stft(audio_segment, n_fft=n_fft, hop_length=hop_length)
spectrogram_db = librosa.amplitude_to_db(np.abs(spectrogram))
spectrogram_tensor = torch.tensor(spectrogram_db, dtype=torch.float32).unsqueeze(0)
spectrogram_tensor = spectrogram_tensor.permute(0, 2, 1)  # (batch, seq_len, input_size)
print(f"Spectrogram tensor shape: {spectrogram_tensor.shape}")

# 创建多头注意力模型实例
input_size = spectrogram_tensor.shape[2]
embed_dim = 128  # 嵌入维度
num_heads = 8  # 注意力头数量
model = MultiheadAttentionModel(input_size, embed_dim, num_heads)

# 前向传播
projected, attn_output, attn_weights = model(spectrogram_tensor)

# 打印输出形状
print("\n=== 输出形状 ===")
print(f"投影后输出形状: {projected.shape}")
print(f"注意力输出形状: {attn_output.shape}")
print(f"注意力权重形状: {attn_weights.shape}")

# 打印参数形状
print("\n=== 参数形状 ===")
print("输入投影层:")
print(f"  权重: {model.input_proj.weight.shape}")
print(f"  偏置: {model.input_proj.bias.shape}")

print("\n多头注意力层:")
print(f"  输入投影权重: {model.attention.in_proj_weight.shape}")
print(f"  输入投影偏置: {model.attention.in_proj_bias.shape}")
print(f"  输出投影权重: {model.attention.out_proj.weight.shape}")
print(f"  输出投影偏置: {model.attention.out_proj.bias.shape}")

# 可视化
plt.figure(figsize=(15, 10))

# 原始频谱
plt.subplot(2, 2, 1)
plt.imshow(spectrogram_db, aspect='auto', origin='lower', cmap='inferno')
plt.title("原始频谱")
plt.xlabel("时间帧")
plt.ylabel("频率点")
plt.colorbar(format='%+2.0f dB')

# 投影后特征
plt.subplot(2, 2, 2)
plt.imshow(projected[0].detach().numpy().T, aspect='auto', origin='lower', cmap='viridis')
plt.title("投影后特征")
plt.xlabel("时间步")
plt.ylabel("特征维度")
plt.colorbar(label='特征值')

# 注意力输出
plt.subplot(2, 2, 3)
plt.imshow(attn_output[0].detach().numpy().T, aspect='auto', origin='lower', cmap='plasma')
plt.title("注意力输出")
plt.xlabel("时间步")
plt.ylabel("特征维度")
plt.colorbar(label='特征值')

# 注意力权重(平均)
plt.subplot(2, 2, 4)
plt.imshow(attn_weights[0].detach().numpy(), aspect='auto', cmap='viridis')
plt.title("注意力权重(平均)")
plt.xlabel("键位置")
plt.ylabel("查询位置")
plt.colorbar(label='注意力权重')

plt.tight_layout()
plt.savefig('attention_spectrogram_features.png', dpi=300)
plt.show()
python 复制代码
Spectrogram tensor shape: torch.Size([1, 188, 257])

=== 输出形状 ===
投影后输出形状: torch.Size([1, 188, 128])
注意力输出形状: torch.Size([1, 188, 128])
注意力权重形状: torch.Size([1, 188, 188])

=== 参数形状 ===
输入投影层:
  权重: torch.Size([128, 257])
  偏置: torch.Size([128])

多头注意力层:
  输入投影权重: torch.Size([384, 128])
  输入投影偏置: torch.Size([384])
  输出投影权重: torch.Size([128, 128])
  输出投影偏置: torch.Size([128])
python 复制代码
attn_output, attn_weights = self.attention(projected, projected, projected)
  1. self.attention是一个多头注意力层,通常是通过 nn.MultiheadAttention 创建的。它实现了多头自注意力机制,能够处理输入序列并计算注意力权重。

  2. 这一行代码的输入参数为 projected,这是经过输入投影层处理后的特征张量,形状为 (batch_size, seq_length, embed_dim)。它被传递了三次,分别作为查询(Query)、键(Key)和值(Value)。

  3. 自注意力机制,在这里,使用相同的张量作为查询、键和值,表示模型在处理输入序列时关注自身的不同部分。这种方式称为自注意力(Self-Attention),允许模型在序列的不同位置之间建立关系。

python 复制代码
在多头注意力机制中,即使使用相同的张量作为查询(Query)、键(Key)和值(Value),模型仍然能够学习到丰富的特征表示,这主要归功于以下关键机制:
1. 可学习的线性变换(核心机制)
虽然输入是相同的张量,但在内部计算过程中,模型会应用不同的可学习线性变换:
# 内部计算过程(简化表示)
Q = projected @ W_q  # 查询变换
K = projected @ W_k  # 键变换
V = projected @ W_v  # 值变换
• W_q, W_k, W_v:三个独立的可学习权重矩阵
• 这些权重矩阵在训练过程中学习不同的特征表示
• 即使输入相同,不同的变换会产生不同的特征表示

2. 多头机制(多视角学习)
多头注意力将嵌入维度分割为多个"头",每个头有自己的变换:
# 多头分割
Q_heads = split_heads(Q)  # [batch, num_heads, seq_len, head_dim]
K_heads = split_heads(K)
V_heads = split_heads(V)
• 每个头学习不同的表示子空间
• 不同头关注输入的不同方面
• 即使输入相同,多头机制提供多视角分析

3. 注意力计算(动态加权)
注意力机制的核心是计算权重并加权求和:
# 注意力计算
attention_scores = softmax(Q @ K.T / sqrt(d_k))
attn_output = attention_scores @ V
• 动态权重分配:根据输入内容动态计算权重
• 内容相关:权重取决于输入特征之间的关系
• 位置无关:权重基于内容相似度,而非固定位置

4. 特征增强过程
使用相同输入的自注意力实际上是一种特征增强:
	1. 特征变换:通过W_q, W_k, W_v学习不同的特征表示
	2. 关系建模:计算元素之间的相似度关系
	3. 信息聚合:基于相似度加权聚合特征
	4. 特征精炼:输出包含全局上下文的新特征

5. 类比理解
想象一位艺术家画自画像:
• 相同输入:艺术家本人(就像相同的输入张量)
• 不同视角:
  • 镜子反射(W_q变换)
  • 照片参考(W_k变换)
  • 自我感知(W_v变换)
• 最终作品:融合多视角的独特自画像(就像attn_output)

6. 实际学习过程
在训练中:
1. 随机初始化:W_q, W_k, W_v初始化为不同值
2. 梯度下降:通过反向传播独立更新三个矩阵
3. 专业化:每个矩阵学习不同的特征提取模式
   • W_q:学习如何构建有效的查询
   • W_k:学习如何创建可匹配的键
   • W_v:学习如何编码有价值的信息
  1. attn_output :这是多头注意力层的输出,形状为 (batch_size, seq_length, embed_dim)。它是通过对值(Value)进行加权求和得到的,权重由查询(Query)和键(Key)之间的相似度计算得出。
  2. attn_weights :这是注意力权重,形状为 (batch_size, num_heads, seq_length, seq_length)。它表示每个查询位置对所有键位置的注意力分配。每个头的权重矩阵显示了模型在计算输出时关注的不同部分。
相关推荐
清风吹过3 小时前
LSTM新架构论文分享3:LSTM+卡尔曼滤波
人工智能·rnn·lstm
MPCTHU3 小时前
Deep Learning|03 Overview of Machine Learning
人工智能·深度学习·机器学习
Rock_yzh3 小时前
AI学习日记——卷积神经网络(CNN):完整实现与可视化分析
人工智能·python·深度学习·神经网络·学习·cnn
空白到白3 小时前
RNN-seq2seq 英译法案例
人工智能·rnn·深度学习
生信小白菜儿3 小时前
深度学习(DL)概念及实例操作
人工智能·python·深度学习
DatGuy3 小时前
Week 20: 深度学习补遗:Transformer Decoder架构
人工智能·深度学习·transformer
金融街小单纯4 小时前
从蓝军建设中学习颠覆性质疑思维
人工智能·算法·机器学习
Dongsheng_20194 小时前
【汽车篇】AI深度学习在汽车零部件外观检测——连接器X-Ray透视检测
人工智能·深度学习·汽车
没有梦想的咸鱼185-1037-16634 小时前
SCI论文写作:从实验设计到发表(选题、文献调研、实验设计、数据分析、论文结构及语言规范)
人工智能·信息可视化·数据分析·ai写作