多模态大模型学习笔记(十四)——transformer学习之Self-Attention

多模态大模型学习笔记(十四)------Transformer学习之Self-Attention

Self-Attention(自注意力机制)是Transformer架构的核心引擎,它解决了RNN类模型"长距离依赖建模困难"和"并行计算效率低"的痛点,让模型能同时捕捉序列中任意两个Token的语义关联。

本文将结合核心示意图,由浅入深地拆解Scaled Dot-Product Attention(缩放点积注意力)、Mask机制、Multi-Head Attention(多头注意力)的核心逻辑,配套修正后的数学原理与可运行代码实现,彻底吃透Self-Attention的工作机制。


1. 核心概念铺垫:Q、K、V的通俗隐喻与本质

在进入技术细节前,先理解Q、K、V的核心角色,这是掌握Self-Attention的关键。

1.1 通俗隐喻(地图-经纬度-物品)

概念 隐喻含义 核心作用
Query(Q,查询) 一张模糊的地图 代表当前Token想要"找什么",是发起检索的"需求向量"
Key(K,键) 准确的经纬度地址 代表序列中每个Token"提供什么",是被检索的"特征向量"
Value(V,值) 某空间内的贵重物品 代表序列中每个Token的"核心语义内容",是最终要提取的信息

1.2 技术本质

在Transformer中,Q、K、V并非天然存在,而是通过输入嵌入向量(Token Embedding + 位置编码) 经过3个独立的可学习线性层投影得到:

Q=X⋅WQ,K=X⋅WK,V=X⋅WV Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V Q=X⋅WQ,K=X⋅WK,V=X⋅WV

其中:

  • X∈RB×L×DmodelX \in \mathbb{R}^{B \times L \times D_{\text{model}}}X∈RB×L×Dmodel:输入嵌入向量(BBB为批次大小,LLL为序列长度,DmodelD_{\text{model}}Dmodel为模型隐藏层维度);
  • WQ,WK,WV∈RDmodel×DkW_Q, W_K, W_V \in \mathbb{R}^{D_{\text{model}} \times D_k}WQ,WK,WV∈RDmodel×Dk:可学习的投影矩阵(DkD_kDk为单个注意力头的维度)。

2. Scaled Dot-Product Attention:自注意力的基础单元

Scaled Dot-Product Attention(缩放点积注意力)是Self-Attention的最小可执行单元。

2.1 核心流程

按执行顺序拆解每一步的作用:

  1. MatMul(Q×Kᵀ) :计算Q与每个K的相似度(注意力分数),衡量当前Token与序列中其他Token的关联程度;
  2. Scale(缩放) :除以Dk\sqrt{D_k}Dk ,解决高维向量点积导致的"梯度消失"问题;
  3. Mask(可选,遮罩):对无效位置(如padding填充位、生成式任务的未来Token)赋值为负无穷,避免模型关注这些位置;
  4. SoftMax :将注意力分数归一化为0~1的概率分布,总和为1,代表对每个Token的"关注权重";
  5. MatMul(权重×V) :用归一化的注意力权重对V加权求和,得到融合了全局语义的当前Token表示

2.2 数学原理

(1)核心公式

Scaled Dot-Product Attention的完整数学表达式为:

Attention(Q,K,V)=SoftMax(QKTDk+M)V \text{Attention}(Q, K, V) = \text{SoftMax}\left( \frac{QK^T}{\sqrt{D_k}} + M \right) V Attention(Q,K,V)=SoftMax(Dk QKT+M)V

各参数维度说明:

  • Q∈RB×Lq×DkQ \in \mathbb{R}^{B \times L_q \times D_k}Q∈RB×Lq×Dk:查询序列矩阵(LqL_qLq为查询序列长度);
  • K∈RB×Lk×DkK \in \mathbb{R}^{B \times L_k \times D_k}K∈RB×Lk×Dk:键序列矩阵(LkL_kLk为键序列长度,Self-Attention中Lq=LkL_q=L_kLq=Lk);
  • V∈RB×Lk×DvV \in \mathbb{R}^{B \times L_k \times D_v}V∈RB×Lk×Dv:值序列矩阵(通常Dk=DvD_k=D_vDk=Dv);
  • M∈RB×Lq×LkM \in \mathbb{R}^{B \times L_q \times L_k}M∈RB×Lq×Lk:Mask矩阵(无效位置为−∞-\infty−∞,有效位置为0);
  • 输出:RB×Lq×Dv\mathbb{R}^{B \times L_q \times D_v}RB×Lq×Dv(融合全局语义的查询序列表示)。
(2)为什么要"缩放"?

当DkD_kDk较大时,QKTQK^TQKT的点积结果方差会随DkD_kDk线性增大,导致SoftMax输出极度趋近于0或1(梯度消失)。除以Dk\sqrt{D_k}Dk 可将方差归一化为1,保证梯度稳定:

Var(qi⋅kj)=Dk(假设 qi,kj∼N(0,1))qi⋅kjDk  ⟹  Var(qi⋅kjDk)=1 \begin{align} \text{Var}(q_i \cdot k_j) &= D_k \quad (\text{假设} \ q_i,k_j \sim \mathcal{N}(0,1)) \\ \frac{q_i \cdot k_j}{\sqrt{D_k}} &\implies \text{Var}\left( \frac{q_i \cdot k_j}{\sqrt{D_k}} \right) = 1 \end{align} Var(qi⋅kj)Dk qi⋅kj=Dk(假设 qi,kj∼N(0,1))⟹Var(Dk qi⋅kj)=1

(3)Mask的两种类型
  1. Padding Mask :针对不等长序列,屏蔽padding填充位:
    MPadding[b,i,j]={−∞,若Tokenj是padding位0,其他 M_{\text{Padding}}[b, i, j] = \begin{cases} -\infty, & \text{若Token}_j \text{是padding位} \\ 0, & \text{其他} \end{cases} MPadding[b,i,j]={−∞,0,若Tokenj是padding位其他
  2. Look-ahead Mask :针对生成式任务(如GPT),屏蔽"当前Token之后的所有位置":
    MLook-ahead[b,i,j]={−∞,若j>i0,其他 M_{\text{Look-ahead}}[b, i, j] = \begin{cases} -\infty, & \text{若} j > i \\ 0, & \text{其他} \end{cases} MLook-ahead[b,i,j]={−∞,0,若j>i其他

2.3 代码实现(PyTorch版)

python 复制代码
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    实现Scaled Dot-Product Attention,与数学公式严格对应
    参数:
        q: [batch_size, seq_len_q, d_k]  查询矩阵
        k: [batch_size, seq_len_k, d_k]  键矩阵
        v: [batch_size, seq_len_k, d_v]  值矩阵
        mask: [batch_size, seq_len_q, seq_len_k]  Mask矩阵(可选)
    返回:
        output: [batch_size, seq_len_q, d_v]  注意力输出
        attn_weights: [batch_size, seq_len_q, seq_len_k]  注意力权重
    """
    # 1. 计算Q×K^T(对应公式中的QK^T)
    d_k = q.size(-1)
    attn_scores = torch.matmul(q, k.transpose(-2, -1))  # [B, L_q, L_k]
    
    # 2. 缩放(对应公式中的/√D_k)
    attn_scores = attn_scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # 3. 应用Mask(对应公式中的+M)
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 1, -1e9)  # Mask位设为-∞
    
    # 4. SoftMax归一化(对应公式中的SoftMax(·))
    attn_weights = F.softmax(attn_scores, dim=-1)  # [B, L_q, L_k]
    
    # 5. 权重×V(对应公式中的SoftMax(·)V)
    output = torch.matmul(attn_weights, v)  # [B, L_q, D_v]
    
    return output, attn_weights

# 测试代码
if __name__ == "__main__":
    # 模拟输入:B=2, L=5, D_k=64
    batch_size, seq_len, d_k = 2, 5, 64
    q = torch.randn(batch_size, seq_len, d_k)
    k = torch.randn(batch_size, seq_len, d_k)
    v = torch.randn(batch_size, seq_len, d_k)
    
    # 模拟Padding Mask:第2个样本的后2个Token是padding
    mask = torch.zeros(batch_size, seq_len, seq_len)
    mask[1, :, 3:] = 1  # [2,5,5]
    
    # 执行注意力计算
    output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
    
    print(f"Q/K/V形状: {q.shape}")
    print(f"注意力权重形状: {attn_weights.shape}")  # [2,5,5]
    print(f"注意力输出形状: {output.shape}")        # [2,5,64]

3. Multi-Head Attention:多头注意力机制

Multi-Head Attention是Scaled Dot-Product Attention的升级版本,解决了"单一注意力头无法捕捉多维度语义"的问题。

3.1 核心逻辑

多头注意力的核心思想是:将Q、K、V拆分为hhh个独立的"注意力头",每个头学习不同维度的语义关联,最后拼接并线性投影,融合所有头的信息

执行步骤:

  1. Linear投影:输入Q、K、V分别经过独立线性层,映射到高维空间;
  2. 拆分多头 :将投影后的Q、K、V按维度拆分为hhh个头;
  3. 单头注意力计算:每个头独立执行Scaled Dot-Product Attention;
  4. 拼接多头输出 :将hhh个头的输出按维度拼接;
  5. 最终线性投影:融合多头语义信息,得到最终结果。

3.2 数学原理

(1)多头拆分与投影

假设模型隐藏层维度为DmodelD_{\text{model}}Dmodel,注意力头数为hhh,则每个头的维度Dk=Dmodel/hD_k = D_{\text{model}} / hDk=Dmodel/h(必须整除):

总维度: Dmodel=h×Dk单头投影: Qi=Q⋅WQi, Ki=K⋅WKi, Vi=V⋅WVi(i=1,2,...,h)其中: WQi,WKi,WVi∈RDmodel×Dk \begin{align} & \text{总维度:} \ D_{\text{model}} = h \times D_k \\ & \text{单头投影:} \ Q_i = Q \cdot W_{Q_i}, \ K_i = K \cdot W_{K_i}, \ V_i = V \cdot W_{V_i} \quad (i=1,2,...,h) \\ & \text{其中:} \ W_{Q_i}, W_{K_i}, W_{V_i} \in \mathbb{R}^{D_{\text{model}} \times D_k} \end{align} 总维度: Dmodel=h×Dk单头投影: Qi=Q⋅WQi, Ki=K⋅WKi, Vi=V⋅WVi(i=1,2,...,h)其中: WQi,WKi,WVi∈RDmodel×Dk

(2)单头注意力与拼接

headi=Attention(Qi,Ki,Vi)(i=1,2,...,h)MultiHead(Q,K,V)=Concat(head1,head2,...,headh)⋅WO其中: WO∈RDmodel×Dmodel,输出∈RB×L×Dmodel \begin{align} \text{head}_i &= \text{Attention}(Q_i, K_i, V_i) \quad (i=1,2,...,h) \\ \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}1, \text{head}2, ..., \text{head}h) \cdot W_O \\ \text{其中:} & \ W_O \in \mathbb{R}^{D{\text{model}} \times D{\text{model}}}, \quad \text{输出} \in \mathbb{R}^{B \times L \times D{\text{model}}} \end{align} headiMultiHead(Q,K,V)其中:=Attention(Qi,Ki,Vi)(i=1,2,...,h)=Concat(head1,head2,...,headh)⋅WO WO∈RDmodel×Dmodel,输出∈RB×L×Dmodel

  • headi∈RB×L×Dk\text{head}_i \in \mathbb{R}^{B \times L \times D_k}headi∈RB×L×Dk:第iii个头的注意力输出;
  • Concat(⋅)\text{Concat}(\cdot)Concat(⋅):按最后一维拼接(将hhh个DkD_kDk维度拼接为DmodelD_{\text{model}}Dmodel);
  • WOW_OWO:最终投影矩阵,融合多头语义信息。

3.3 代码实现(PyTorch版)

python 复制代码
import torch
import torch.nn as nn
from typing import Optional

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        """
        实现Multi-Head Attention,与数学公式严格对应
        参数:
            d_model: 模型总维度(如768),需满足 d_model % num_heads == 0
            num_heads: 注意力头数(如12)
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 单头维度(对应公式中的D_k)
        
        # 1. 线性投影层(对应公式中的W_Q/W_K/W_V)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        
        # 5. 最终投影层(对应公式中的W_O)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        将投影后的向量拆分为多头(对应公式中的拆分步骤)
        输入:x [B, L, D_model]
        输出:x [B, num_heads, L, D_k]
        """
        batch_size, seq_len, _ = x.shape
        # 拆分:[B, L, num_heads, D_k] → 转置:[B, num_heads, L, D_k]
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def _concat_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        拼接多头输出(对应公式中的Concat步骤)
        输入:x [B, num_heads, L, D_k]
        输出:x [B, L, D_model]
        """
        batch_size, _, seq_len, _ = x.shape
        # 转置:[B, L, num_heads, D_k] → 拼接:[B, L, D_model]
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        前向传播(与数学公式严格对应)
        参数:
            q/k/v: [B, L, D_model]  输入矩阵
            mask: [B, L, L]  Mask矩阵(可选)
        返回:
            output: [B, L, D_model]  多头注意力输出
        """
        batch_size = q.size(0)
        
        # Step 1: 线性投影(对应公式中的Q·W_Q等)
        q_proj = self.w_q(q)  # [B, L, D_model]
        k_proj = self.w_k(k)  # [B, L, D_model]
        v_proj = self.w_v(v)  # [B, L, D_model]
        
        # Step 2: 拆分多头(对应公式中的Q_i等)
        q_heads = self._split_heads(q_proj)  # [B, h, L, D_k]
        k_heads = self._split_heads(k_proj)  # [B, h, L, D_k]
        v_heads = self._split_heads(v_proj)  # [B, h, L, D_k]
        
        # Step 3: 单头注意力计算(对应公式中的head_i)
        # 扩展Mask维度以匹配多头:[B, L, L] → [B, 1, L, L]
        mask_expanded = mask.unsqueeze(1) if mask is not None else None
        attn_output, _ = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask_expanded)
        # attn_output: [B, h, L, D_k]
        
        # Step 4: 拼接多头输出(对应公式中的Concat)
        attn_concat = self._concat_heads(attn_output)  # [B, L, D_model]
        
        # Step 5: 最终线性投影(对应公式中的·W_O)
        output = self.w_o(attn_concat)  # [B, L, D_model]
        
        return output

# 测试代码
if __name__ == "__main__":
    # 初始化:D_model=768,h=12(BERT-base配置)
    mha = MultiHeadAttention(d_model=768, num_heads=12)
    
    # 模拟输入:B=2, L=10, D_model=768(Self-Attention中Q=K=V)
    batch_size, seq_len, d_model = 2, 10, 768
    x = torch.randn(batch_size, seq_len, d_model)
    
    # 模拟Look-ahead Mask(生成式任务)
    look_ahead_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # [10,10]
    look_ahead_mask = look_ahead_mask.unsqueeze(0).repeat(batch_size, 1, 1)  # [2,10,10]
    
    # 执行多头注意力
    output = mha(x, x, x, mask=look_ahead_mask)
    
    print(f"输入形状: {x.shape}")
    print(f"多头注意力输出形状: {output.shape}")  # [2,10,768](与输入维度一致)

4. Self-Attention vs 普通Attention:关键区别

Self-Attention是Attention机制的一个特例,其核心公式为:

Self-Attention(X)=MultiHead(XWQ,XWK,XWV) \text{Self-Attention}(X) = \text{MultiHead}(XW_Q, XW_K, XW_V) Self-Attention(X)=MultiHead(XWQ,XWK,XWV)

与普通Attention的区别:

  • 普通Attention(如机器翻译的Encoder-Decoder Attention):Q来自Decoder,K、V来自Encoder,用于"目标序列对齐源序列";
  • Self-AttentionQ=K=V,均来自同一序列(如Encoder的输入),用于"序列内部Token之间的语义关联建模"。

这也是为什么Self-Attention能高效捕捉长文本的上下文依赖------它能同时计算序列中任意两个Token的注意力权重,无需像RNN那样逐词遍历。


5. 总结

  1. 基础单元:Scaled Dot-Product Attention通过"Q×Kᵀ相似度计算→缩放→Mask→SoftMax归一化→加权求和V",实现单个Token的全局语义融合;
  2. 升级版本:Multi-Head Attention通过"拆分多头→独立注意力计算→拼接→线性投影",捕捉多维度语义关联,是Transformer的核心;
  3. 核心优势:并行计算效率高、长距离依赖建模能力强,是大模型处理文本、图像等序列数据的基础。
相关推荐
MoRanzhi12032 小时前
Pillow 图像算术运算与通道计算
图像处理·人工智能·python·计算机视觉·pillow·图像差异检测·图像算术运算
小超同学你好2 小时前
Langgraph 4. 反思 Reflection
人工智能·语言模型·langchain
带娃的IT创业者2 小时前
神经形态意识模块理论基础详解:六大核心理论支柱
人工智能·深度学习·脑科学·神经科学·认知科学·意识理论·ai 架构
北京阿法龙科技有限公司2 小时前
解放双手,透视数据:AR+AI技术正在如何解决新能源储能行业的老大难问题
人工智能·ar
产品人卫朋2 小时前
AI硬件产品怎么做?——桌面机器人
人工智能·机器人
K姐研究社2 小时前
阿里QoderWork实测 – 打工人桌面AI助手,零配置替代OpenClaw
人工智能·aigc
机器觉醒时代2 小时前
DreamZero:从语言理解到世界建模——具身智能的WAM新范式
人工智能·具身智能·人形机器人·世界模型
FluxMelodySun2 小时前
机器学习(二十一) 集成学习-结合策略与多样性
人工智能·机器学习·集成学习
WangUnionpub2 小时前
别只盯着MDPI,又贵还卡单位,平替SCI/EI,免收版面费,这本15天录用!
大数据·人工智能·深度学习·物联网·计算机视觉