【大语言模型】—— 自注意力机制及其变体(交叉注意力、因果注意力、多头注意力)的代码实现

【大语言模型】------ 注意力机制及其变体的代码实现

摘要

本文介绍了注意力机制的几种变体及其PyTorch代码实现。主要包括:

  1. Self-Attention:基础自注意力机制,通过Q、K、V计算注意力权重,适用于序列内部建模。
  2. CrossAttention:让一个序列关注另一个序列,典型应用于Transformer解码器-编码器交互和多模态任务。
  3. CausalAttention:通过三角掩码实现因果性,确保只能关注当前位置之前的token,适用于自回归生成任务。
  4. MultiHeadSelfAttention:多头注意力机制,将输入分割到多个子空间并行计算注意力,最后合并结果。

代码实现中详细展示了各注意力的关键操作,包括线性变换、注意力分数计算、softmax归一化和掩码处理等。特别解释了dim=-1的作用、unsqueeze(0)的广播机制等实现细节。

Self-Attention

python 复制代码
class SelfAttention(nn.module):
# torch.matmul     专用于批量矩阵乘法,适用于形状为 (batch_size, n, m) 和 (batch_size, m, p) 的 3D 张量。
# torch.matmul  支持更灵活的张量乘法运算
	def __init__(self, input_dim, dim_k,dim_v):
		super().__init__()
		self.q = nn.Linear(input_dim, dim_k)
		self.k = nn.Linear(input_dim, dim_k)
		selv.v = nn.Linear(input_dim, dim_v)
		self.scale = np.sqrt(dim_k)
		
	def forward(self, x):
		Q = self.q(x)
		K = self.k(x)
		V = self.v(x)
	
		atten = torch.softmax(torch.matmul(Q, K.permute(0,2,1))/self.scale, dim=-1)
		return torch.matmul(atten, V)

为什么 dim=-1

在自注意力机制中,nn.Softmax(dim=-1)的作用是对 ​​注意力分数矩阵​​ 进行归一化,使得每一行的权重之和为 1。这里 dim=-1表示在最后一个维度(即 seq_len维度)上进行 Softmax 计算。

在 Q K T QK^T QKT计算后,得到的注意力分数矩阵的形状是 [batch_size, seq_len, seq_len],其中:

​​第 1 个 seq_len(dim=1)​​:代表 Q的序列长度(即当前 token 的位置)。

​​第 2 个 seq_len(dim=2)​​:代表 K的序列长度(即被计算注意力的 token 的位置)。
dim=-1(即 dim=2)表示 ​​对每个 token 计算它对所有 token 的注意力权重​​,即 ​​对每一行进行 Softmax​​,使得:

每一行的所有值加起来等于 1(概率分布)。

这样,每个 token 的注意力权重是独立计算的。

假设 Q K T QK^T QKT的结果是:

bash 复制代码
[
  [[1.0, 0.5, 0.2],  # Token 0 对所有 token 的注意力分数
   [0.3, 1.2, 0.7],  # Token 1 对所有 token 的注意力分数
   [0.1, 0.4, 1.5]]  # Token 2 对所有 token 的注意力分数
]

应用 nn.Softmax(dim=-1)后:

bash 复制代码
[
  [[0.55, 0.27, 0.18],  # Token 0 的注意力权重(总和=1)
   [0.16, 0.58, 0.26],  # Token 1 的注意力权重(总和=1)
   [0.07, 0.20, 0.73]]  # Token 2 的注意力权重(总和=1)
]

这样,每个 token 的注意力权重是独立的,且所有 token 对它的影响权重之和为 1。

CrossAttention

python 复制代码
# 查询通常来自解码器,键和值通常来自编码器
import torch
import torch.nn as nn
import numpy as np

class CrossAttention(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v):
        super().__init__()  # 必须调用父类初始化
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.scale = np.sqrt(dim_k)
    
    def forward(self, x1, x2):
        Q1 = self.q(x1)  # [batch_size, seq_len1, dim_k]
        K2 = self.k(x2)  # [batch_size, seq_len2, dim_k]
        V2 = self.v(x2)  # [batch_size, seq_len2, dim_v]
        
        # 计算注意力分数
        atten = torch.softmax(torch.matmul(Q1, K2.permute(0, 2, 1)) / self.scale, dim=-1)  # [batch_size, seq_len1, seq_len2]
        
        # 加权求和
        return torch.matmul(atten, V2)  # [batch_size, seq_len1, dim_v]

交叉注意力的作用

交叉注意力用于 让一个序列 x 1 x1 x1关注另一个序列 x 2 x2 x2,典型应用包括:

  1. Transformer 解码器:
    • x1= 解码器的输入(当前生成的 token)
    • x2= 编码器的输出(源序列的表示)
    • 解码器通过交叉注意力关注编码器的信息。
  2. 多模态任务(如视觉-语言模型):
    • x1= 文本序列
    • x2= 图像特征
    • 文本通过交叉注意力关注图像的关键区域。

CausalAttention

python 复制代码
class CausalAttention(nn.Module):
	def __init__(self,input_dim, dim_k,dim_v):
		super().__init__()
		self.q = nn.Linear(input_dim, dim_k)
		self.k = nn.Linear(input_dim, dim_k)
		self.v = nn.Linear(input_dim, dim_v)
		self.scale = np.sqrt(dim_k)
    def forward(self, x):
        # x: [batch_size, seq_len, input_dim]
        Q = self.q(x)  # [batch_size, seq_len, dim_k]
        K = self.k(x)  # [batch_size, seq_len, dim_k]
        V = self.v(x)  # [batch_size, seq_len, dim_v]

        # 注意力分数
        atten = torch.matmul(Q, K.permute(0, 2, 1)) / self.scale  # [batch, seq, seq]

        # 下三角 mask,确保因果性(只能看到之前的token)
        seq_len = atten.size(-1)
        mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
        atten = atten.masked_fill(mask == 0, float('-inf'))

        # softmax 归一化
        atten = nn.Softmax(dim=-1)(atten)

        # 输出加权和
        return torch.matmul(atten, V)  # [batch, seq, dim_v]

unsqueeze(0)的作用​​

假设我们有一个 2D 张量 mask,形状是 [seq_len, seq_len]

python 复制代码
mask = torch.tril(torch.ones(seq_len, seq_len))  # 形状 [seq_len, seq_len]

如果我们想让它变成 [1, seq_len, seq_len](即增加一个 batch 维度),可以使用:

python 复制代码
mask = mask.unsqueeze(0)  # 形状变为 [1, seq_len, seq_len]

这样做的目的是:

  • 匹配注意力分数矩阵的形状(atten的形状是 [batch_size, seq_len, seq_len])。
  • 支持批量计算,因为 mask需要广播到所有 batch 样本。

这样:

mask的形状变成 [1, seq_len, seq_len]

PyTorch 会自动广播 mask到 [batch_size, seq_len, seq_len],使其与 atten形状匹配。

MultiHeadSelfAttention

python 复制代码
class MultiHeadAttention(nn.Module):
	def __init__(self, heads, input_dim, dim_k, dim_v):
		super().__init__()
		self.heads = heads
		
		self.dim_k_per_head = dim_k // heads
		self.dim_v_per_head = dim_v // heads
		
		self.q = nn.Linear(input_dim, dim_k)
		self.k = nn.Linear(input_dim, dim_k)
		self.v = nn.Linear(input_dim, dim_v)
		
		self.scale = np.sqrt(self.dim_k_per_head)
		
		self.out = nn.Linear(dim_v, input_dim)
	def forward(self, x):
		batch_size = x.size(0)
		
		Q = self.q(x)#[batch_size, seq_len, dim_k]
		K = self.k(x)
		V = self.v(x)
		
		#[batch_size, seq_len, heads, dim_k_per_head] 
		#		--> [batch_size, heads, seq_len, dim_k_per_head]
		Q = Q.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)
		#[batch_size, seq_len, heads, dim_k_per_head] 
		#		--> [batch_size, heads, seq_len, dim_k_per_head]
		K = K.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)
		#[batch_size, seq_len, heads, dim_v_per_head] 
		#		--> [batch_size, heads, seq_len, dim_v_per_head]
		V = V.view(batch_size, -1, self.heads, self.dim_v_per_head).permute(0,2,1,3)
		
		#转置[batch_size, heads, seq_len, dim_k_per_head] 
		#		--> [batch_size, heads, dim_k_per_head, seq_len]
		K = K.permute(0, 1, 3, 2)
		
		# [batch_size, heads, seq_len, seq_len]
		atten = torch.softmax(torch.matmul(Q,K) / self.scale, dim = -1)
		# [batch_size, heads, seq_len, dim_v_per_head]
		out = torch.matmul(atten, V)
		# [batch_size, seq_len, heads, dim_v_per_head]
		out = out.permute(0, 2, 1, 3).contiguous()
		# [batch_size, seq_len, heads* dim_v_per_head]
		out = out.view(batch_size, -1, self.heads * self.dim_v_per_head)
		return self.out(out) # [batch_size, seq_len, input_dim]

多头自注意力(Multi-Head Attention)的核心思想

多头自注意力(Multi-Head Attention)的核心思想是将输入向量分别映射为查询 Q、键 K、值 V,再按照头数切分到多个子空间中;每个头独立计算注意力分数并得到加权表示,最后拼接各头的结果,通过线性层 out 映射回输入维度,从而捕捉序列中多角度的相关性。

在使用时需要注意以下几点:

  1. 维度整除 :要确保 dim_kdim_v 能被 heads 整除,否则 view 时会报错。
  2. 缩放因子 :缩放应该基于每个头的维度 sqrt(dim_k_per_head),而不是整体的 dim_k
  3. Softmax 顺序 :正确做法是 Softmax(QK^T / scale),不要写成 Softmax(QK^T) / scale
  4. 张量连续性permute 之后用 .contiguous().view(),否则可能报错;或者用 reshape 自动处理。
  5. 输出层作用self.out 的作用是把多头拼接后的结果重新映射回输入维度,保持层间维度一致。
相关推荐
从孑开始4 小时前
ManySpeech.MoonshineAsr 使用指南
人工智能·ai·c#·.net·私有化部署·语音识别·onnx·asr·moonshine
涛涛讲AI4 小时前
一段音频多段字幕,让音频能够流畅自然对应字幕 AI生成视频,扣子生成剪映视频草稿
人工智能·音视频·语音识别
可触的未来,发芽的智生4 小时前
新奇特:黑猫警长的纳米世界,忆阻器与神经网络的智慧
javascript·人工智能·python·神经网络·架构
WWZZ20255 小时前
快速上手大模型:机器学习2(一元线性回归、代价函数、梯度下降法)
人工智能·算法·机器学习·计算机视觉·机器人·大模型·slam
AKAMAI5 小时前
数据孤岛破局之战 :跨业务分析的难题攻坚
运维·人工智能·云计算
Chicheng_MA5 小时前
算能 CV184 智能相机整体方案介绍
人工智能·数码相机·算能
Element_南笙5 小时前
吴恩达新课程:Agentic AI(笔记2)
数据库·人工智能·笔记·python·深度学习·ui·自然语言处理
倔强青铜三5 小时前
苦练Python第69天:subprocess模块从入门到上瘾,手把手教你驯服系统命令!
人工智能·python·面试
Antonio9155 小时前
【图像处理】rgb和srgb
图像处理·人工智能·数码相机